Source code for optimal_cutoffs.piecewise

"""Optimized O(n log n) sort-and-scan kernel for piecewise-constant metrics.

This module provides an exact optimizer for binary classification metrics that are
piecewise-constant with respect to the decision threshold. The algorithm sorts
predictions once and scans all n + 1 cuts (k = 0..n) in a single pass, achieving
O(n log n) complexity overall with vectorized operations.

Notes on `require_proba`:
- If `require_proba=True`, inputs are validated to lie in [0, 1].
- The returned threshold is *usually* in [0, 1]; however, in boundary or tie cases,
  we may nudge it by one floating-point ULP beyond the range to correctly realize
  strict inclusivity/exclusivity (e.g., to ensure “predict none” with '>=' when max p == 1.0).
"""

from __future__ import annotations

from collections.abc import Callable
from typing import Any

import numpy as np

from .metrics import (
    apply_metric_to_confusion_counts,
    compute_vectorized_confusion_matrices,
    confusion_matrix_from_predictions,
)
from .types_minimal import OptimizationResult
from .validation import validate_binary_classification, validate_weights

Array = np.ndarray[Any, Any]

# NOTE: NUMERICAL_TOLERANCE moved to function parameters for user control
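
# For illustration only (not part of the library's API): a metric callable
# passed to this module is expected to be vectorized over confusion counts,
# mapping (tp, tn, fp, fn) arrays to an array of scores. The F1 sketch below
# is an assumption for demonstration; the real implementations live in
# .metrics.
def _example_f1_vectorized(tp: Array, tn: Array, fp: Array, fn: Array) -> Array:
    """Illustrative vectorized F1: 2*tp / (2*tp + fp + fn), 0 where undefined."""
    denom = 2.0 * tp + fp + fn
    return np.divide(2.0 * tp, denom, out=np.zeros_like(denom), where=denom > 0)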


def _evaluate_metric_scalar_efficient(
    metric_fn: Callable, tp: float, tn: float, fp: float, fn: float
) -> float:
    """Efficiently evaluate metric function on scalar confusion matrix values.

    This avoids the inefficient pattern of converting scalars to single-element
    arrays just to call vectorized functions and extract the first element.

    Parameters
    ----------
    metric_fn : callable
        Vectorized metric function that expects arrays
    tp, tn, fp, fn : float
        Scalar confusion matrix values

    Returns
    -------
    float
        Metric score
    """
    # Call vectorized function with single-element arrays and extract result
    return float(
        metric_fn(
            np.array([tp], dtype=float),
            np.array([tn], dtype=float),
            np.array([fp], dtype=float),
            np.array([fn], dtype=float),
        )[0]
    )
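

# Worked example (assumed counts, for illustration): with tp=8, tn=5, fp=2,
# fn=1 and the F1 sketch above, the helper returns
# 2*8 / (2*8 + 2 + 1) = 16/19 ≈ 0.842:
#
#     _evaluate_metric_scalar_efficient(_example_f1_vectorized, 8, 5, 2, 1)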


def _compute_threshold_midpoint(
    p_sorted: Array, k_star: int, inclusive: bool = False, tolerance: float = 1e-10
) -> float:
    """Compute threshold as midpoint between adjacent sorted scores.

    With p_sorted in descending order and k counting predicted positives:
      k = 0: predict none positive (threshold at or above the max score)
      k = 1..n-1: predict exactly the top-k items positive
      k = n: predict all positive (threshold at or below the min score)
    """
    n = p_sorted.size

    # k == 0: threshold must exclude every score
    if k_star == 0:
        max_prob = float(p_sorted[0])
        # For '>=', make threshold strictly greater than max_prob to exclude ties
        return float(np.nextafter(max_prob, np.inf)) if inclusive else max_prob

    # k == n: threshold must include every score
    if k_star == n:
        min_prob = float(p_sorted[-1])
        # For '>', make threshold strictly smaller than min_prob to include ties
        return float(np.nextafter(min_prob, -np.inf)) if not inclusive else min_prob

    # General case: separate p_sorted[k_star-1] (included) and p_sorted[k_star] (excluded)
    inc = float(p_sorted[k_star - 1])
    exc = float(p_sorted[k_star])

    if inc - exc > tolerance:
        thr = 0.5 * (inc + exc)
        # For '>=', bias one ULP downward so equal scores land on the included side
        return float(np.nextafter(thr, -np.inf)) if inclusive else thr

    # Tied (or nearly tied) scores: choose side per operator
    tied = exc
    # For '>', place threshold just above tied to exclude equals.
    # For '>=', place just below tied to include equals.
    return (
        float(np.nextafter(tied, np.inf))
        if not inclusive
        else float(np.nextafter(tied, -np.inf))
    )
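

# Worked example (hypothetical scores): with p_sorted = [0.9, 0.7, 0.7, 0.1],
# k_star = 1 separates 0.9 from 0.7, so the midpoint 0.8 realizes top-1 under
# either operator. k_star = 2 falls between the tied 0.7s and no threshold can
# split them: '>' gets nextafter(0.7, inf) and realizes k = 1, while '>=' gets
# nextafter(0.7, -inf) and realizes k = 3. The tie-probing step in
# optimal_threshold_sortscan below exists precisely for this case.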


def _realized_k(p_sorted: Array, threshold: float, inclusive: bool) -> int:
    """Given a threshold and comparison mode, return #positives among p_sorted (desc)."""
    # Convert to ascending by negation and use searchsorted with the right side
    q = -p_sorted
    t = -threshold
    side = "right" if inclusive else "left"
    return int(np.searchsorted(q, t, side=side))
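

# Example (hypothetical): for p_sorted = [0.9, 0.7, 0.7, 0.1] and threshold
# 0.7, searchsorted on the negated scores counts 3 positives under '>='
# (0.9 and both 0.7s) but only 1 under '>' (just 0.9).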


def _predict_from_threshold(probs: Array, threshold: float, inclusive: bool) -> Array:
    """Predict labels (0/1) from probabilities and threshold."""
    p = np.asarray(probs)
    if p.ndim == 2 and p.shape[1] == 2:
        p = p[:, 1]
    elif p.ndim == 2 and p.shape[1] == 1:
        p = p.ravel()
    return (
        (p >= threshold).astype(np.int32)
        if inclusive
        else (p > threshold).astype(np.int32)
    )
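

# Example (hypothetical): a predict_proba-style (n, 2) input is reduced to its
# positive-class column, so these two calls produce identical labels:
#
#     _predict_from_threshold(np.array([[0.8, 0.2], [0.3, 0.7]]), 0.5, True)
#     _predict_from_threshold(np.array([0.2, 0.7]), 0.5, True)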


def optimal_threshold_sortscan(
    y_true: Array,
    pred_prob: Array,
    metric: str | Callable[[Array, Array, Array, Array], Array],
    *,
    sample_weight: Array | None = None,
    inclusive: bool = False,  # True for ">=", False for ">"
    require_proba: bool = True,
    tolerance: float = 1e-10,
) -> OptimizationResult:
    """Exact optimizer for piecewise-constant metrics using O(n log n) sort-and-scan.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        Binary labels in {0, 1}.
    pred_prob : array-like of shape (n_samples,)
        Predicted probabilities in [0, 1], or arbitrary scores if
        require_proba=False.
    metric : str or callable
        Metric name (e.g., "f1", "precision") or vectorized function.
        A string is resolved to the corresponding vectorized implementation.
        A callable must map (tp_vec, tn_vec, fp_vec, fn_vec) -> score_vec.
    sample_weight : array-like, optional
        Non-negative sample weights of shape (n_samples,).
    inclusive : bool, default=False
        If True, use ">="; if False, use ">".
    require_proba : bool, default=True
        Validate inputs in [0, 1]. The threshold may still be nudged by
        ±1 ULP outside [0, 1] to exactly realize inclusivity/exclusivity
        in boundary/tie cases.
    tolerance : float, default=1e-10
        Numerical tolerance for floating-point comparisons when computing
        threshold midpoints and handling ties between scores.

    Returns
    -------
    OptimizationResult
        thresholds : array([optimal_threshold])
        scores : array([achieved_score])
        predict : callable(probs) -> {0, 1}^n
        metric : str, set to "piecewise_metric"
        n_classes : 2
        diagnostics : dict with keys:
            - k_argmax: theoretical best cut index (0..n) from the sweep
            - k_realized: positives realized by the returned threshold
            - score_theoretical: score at k_argmax
            - score_actual: score achieved by the returned threshold
            - tie_discrepancy: abs(theoretical - actual)
            - inclusive: bool
            - require_proba: bool
    """
    # 0) Resolve metric to a vectorized function
    if isinstance(metric, str):
        from .metrics import get_metric_function

        metric_fn = get_metric_function(metric)
    else:
        metric_fn = metric

    # 1) Validate inputs
    y, p, _ = validate_binary_classification(
        y_true, pred_prob, require_proba=require_proba
    )
    n = y.shape[0]
    weights = (
        validate_weights(sample_weight, n)
        if sample_weight is not None
        else np.ones(n, dtype=np.float64)
    )

    # 2) Sort once by descending score (stable)
    order = np.argsort(-p, kind="mergesort")
    y_sorted = y[order]
    p_sorted = p[order]
    w_sorted = weights[order]

    # 3) Vectorized confusion counts at all n+1 cuts (k = 0..n)
    tp_vec, tn_vec, fp_vec, fn_vec = compute_vectorized_confusion_matrices(
        y_sorted, w_sorted
    )

    # 4) Vectorized metric over all cuts; take the argmax
    score_vec = apply_metric_to_confusion_counts(
        metric_fn, tp_vec, tn_vec, fp_vec, fn_vec
    )
    k_star = int(np.argmax(score_vec))
    score_theoretical = float(score_vec[k_star])

    # 5) Convert k* to a concrete threshold with correct > / >= semantics
    threshold = _compute_threshold_midpoint(p_sorted, k_star, inclusive, tolerance)

    # 6) Evaluate the achieved score at that threshold (handles ties & numerics)
    pred_labels = _predict_from_threshold(p, threshold, inclusive)
    tp, tn, fp, fn = confusion_matrix_from_predictions(
        y, pred_labels, sample_weight=weights
    )
    score_actual = _evaluate_metric_scalar_efficient(metric_fn, tp, tn, fp, fn)

    # 7) If the realized score differs meaningfully (e.g., due to ties), probe a few
    #    locally optimal alternatives (extremes and one-ULP nudges around the boundary).
    tie_discrepancy = abs(score_actual - score_theoretical)
    if tie_discrepancy > max(1e-6, 100 * tolerance):
        best_thr = threshold
        best_score = score_actual
        min_s = float(p_sorted[-1])
        max_s = float(p_sorted[0])

        # Include-all vs. exclude-all extremes for the active operator
        candidates: list[float] = []
        if inclusive:
            candidates.extend([min_s, float(np.nextafter(max_s, np.inf))])
        else:
            candidates.extend([float(np.nextafter(min_s, -np.inf)), max_s])

        if 0 < k_star < n:
            inc = float(p_sorted[k_star - 1])  # last included by k*
            exc = float(p_sorted[k_star])  # first excluded by k*
            candidates.extend(
                [
                    float(np.nextafter(inc, -np.inf)),  # just below included
                    float(np.nextafter(exc, np.inf)),  # just above excluded
                ]
            )

        # Evaluate candidates
        for t in candidates:
            # If require_proba, clamp only if it does not alter the intended
            # decision boundary; we accept tiny excursions beyond [0, 1] when
            # necessary for the semantics.
            t_eval = t
            pred_labels_alt = _predict_from_threshold(p, t_eval, inclusive)
            tp2, tn2, fp2, fn2 = confusion_matrix_from_predictions(
                y, pred_labels_alt, sample_weight=weights
            )
            s2 = _evaluate_metric_scalar_efficient(metric_fn, tp2, tn2, fp2, fn2)
            if s2 > best_score:
                best_score = s2
                best_thr = t_eval

        threshold = best_thr
        score_actual = best_score

    # 8) Diagnostics and final return
    k_real = _realized_k(p_sorted, threshold, inclusive)

    def predict_binary(probs: Array) -> Array:
        return _predict_from_threshold(probs, threshold, inclusive)

    diagnostics = {
        "k_argmax": k_star,
        "k_realized": k_real,
        "score_theoretical": score_theoretical,
        "score_actual": score_actual,
        "tie_discrepancy": abs(score_actual - score_theoretical),
        "inclusive": inclusive,
        "require_proba": require_proba,
    }

    return OptimizationResult(
        thresholds=np.array([float(threshold)], dtype=float),
        scores=np.array([float(score_actual)], dtype=float),
        predict=predict_binary,
        metric="piecewise_metric",
        n_classes=2,
        diagnostics=diagnostics,
    )
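

# Minimal usage sketch (illustrative; the toy data below is an assumption, not
# a library fixture). Run with `python -m optimal_cutoffs.piecewise`, since the
# module uses relative imports.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    y_demo = rng.integers(0, 2, size=200)
    # Scores loosely correlated with the labels, clipped into [0, 1].
    p_demo = np.clip(0.35 * y_demo + rng.uniform(0.0, 0.65, size=200), 0.0, 1.0)
    result = optimal_threshold_sortscan(y_demo, p_demo, "f1", inclusive=True)
    print("threshold:", float(result.thresholds[0]))
    print("F1 score:", float(result.scores[0]))
    print("diagnostics:", result.diagnostics)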