Source code for optimal_cutoffs.optimizers

"""Threshold search strategies for optimizing classification metrics."""

from typing import Literal

import numpy as np
from scipy import optimize

from .metrics import (
    METRIC_REGISTRY,
    get_confusion_matrix,
    get_multiclass_confusion_matrix,
    get_vectorized_metric,
    has_vectorized_implementation,
    is_piecewise_metric,
    multiclass_metric,
)
from .multiclass_coord import optimal_multiclass_thresholds_coord_ascent
from .piecewise import optimal_threshold_sortscan
from .types import ArrayLike, AveragingMethod, ComparisonOperator, OptimizationMethod
from .validation import (
    _validate_comparison_operator,
    _validate_inputs,
    _validate_metric_name,
    _validate_optimization_method,
)


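# Objective helpers for ``optimize.brute`` below: scipy minimizes, so each
# helper returns ``1 - score`` to turn metric maximization into minimization.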
def _accuracy(
    prob: np.ndarray, true_labs: ArrayLike, pred_prob: ArrayLike, verbose: bool = False
) -> float:
    tp, tn, fp, fn = get_confusion_matrix(true_labs, pred_prob, prob[0])
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    if verbose:
        print(f"Probability: {prob[0]:0.4f} Accuracy: {accuracy:0.4f}")
    return 1 - accuracy


def _f1(
    prob: np.ndarray, true_labs: ArrayLike, pred_prob: ArrayLike, verbose: bool = False
) -> float:
    tp, tn, fp, fn = get_confusion_matrix(true_labs, pred_prob, prob[0])
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall) > 0
        else 0.0
    )
    if verbose:
        print(f"Probability: {prob[0]:0.4f} F1 score: {f1:0.4f}")
    return 1 - f1


def get_probability(
    true_labs: ArrayLike,
    pred_prob: ArrayLike,
    objective: Literal["accuracy", "f1"] = "accuracy",
    verbose: bool = False,
) -> float:
    """Brute-force search for a simple metric's best threshold.

    .. deprecated:: 1.0.0
        :func:`get_probability` is deprecated and will be removed in a future
        version. Use :func:`get_optimal_threshold` instead, which provides a
        unified API for both binary and multiclass classification with more
        optimization methods and additional features like sample weights.

    Parameters
    ----------
    true_labs:
        Array of true binary labels.
    pred_prob:
        Predicted probabilities from a classifier.
    objective:
        Metric to optimize. Supported values are ``"accuracy"`` and ``"f1"``.
    verbose:
        If ``True``, print intermediate metric values during the search.

    Returns
    -------
    float
        Threshold that maximizes the specified metric.
    """
    import warnings

    warnings.warn(
        "get_probability is deprecated and will be removed in a future version. "
        "Use get_optimal_threshold instead, which provides a unified API for "
        "both binary and multiclass classification with more optimization methods "
        "and additional features like sample weights.",
        DeprecationWarning,
        stacklevel=2,
    )

    if objective == "accuracy":
        prob = optimize.brute(
            _accuracy,
            (slice(0.1, 0.9, 0.1),),
            args=(true_labs, pred_prob, verbose),
            disp=verbose,
        )
    elif objective == "f1":
        prob = optimize.brute(
            _f1,
            (slice(0.1, 0.9, 0.1),),
            args=(true_labs, pred_prob, verbose),
            disp=verbose,
        )
    else:
        raise ValueError(f"Unknown objective: {objective}")
    return float(prob[0] if isinstance(prob, np.ndarray) else prob)
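
# Migration sketch for the deprecation above (illustrative only; the arrays
# are made-up data and the returned thresholds are not verified here):
#
#     >>> import numpy as np
#     >>> y = np.array([0, 0, 1, 1, 1])
#     >>> p = np.array([0.1, 0.4, 0.35, 0.8, 0.9])
#     >>> get_probability(y, p, objective="f1")     # deprecated, warns
#     >>> get_optimal_threshold(y, p, metric="f1")  # preferred replacement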

def _metric_score(
    true_labs: ArrayLike,
    pred_prob: ArrayLike,
    threshold: float,
    metric: str = "f1",
    sample_weight: ArrayLike | None = None,
    comparison: ComparisonOperator = ">",
) -> float:
    """Compute a metric score for a given threshold using registry metrics.

    Parameters
    ----------
    true_labs:
        Array of true labels.
    pred_prob:
        Array of predicted probabilities.
    threshold:
        Decision threshold.
    metric:
        Name of metric from registry.
    sample_weight:
        Optional array of sample weights.
    comparison:
        Comparison operator for thresholding: ">" (exclusive) or ">="
        (inclusive).

    Returns
    -------
    float
        Computed metric score.
    """
    tp, tn, fp, fn = get_confusion_matrix(
        true_labs, pred_prob, threshold, sample_weight, comparison
    )
    try:
        metric_func = METRIC_REGISTRY[metric]
    except KeyError as exc:
        raise ValueError(f"Unknown metric: {metric}") from exc
    return float(metric_func(tp, tn, fp, fn))


def _multiclass_metric_score(
    true_labs: ArrayLike,
    pred_prob: ArrayLike,
    thresholds: ArrayLike,
    metric: str = "f1",
    average: AveragingMethod = "macro",
    sample_weight: ArrayLike | None = None,
) -> float:
    """Compute a multiclass metric score for given per-class thresholds.

    Parameters
    ----------
    true_labs:
        Array of true class labels.
    pred_prob:
        Array of predicted probabilities.
    thresholds:
        Array of per-class thresholds.
    metric:
        Name of metric from registry.
    average:
        Averaging strategy for multiclass.
    sample_weight:
        Optional array of sample weights.

    Returns
    -------
    float
        Computed multiclass metric score.
    """
    confusion_matrices = get_multiclass_confusion_matrix(
        true_labs, pred_prob, thresholds, sample_weight
    )
    return multiclass_metric(confusion_matrices, metric, average)


def _optimal_threshold_piecewise(
    true_labs: ArrayLike,
    pred_prob: ArrayLike,
    metric: str = "f1",
    sample_weight: ArrayLike | None = None,
    comparison: ComparisonOperator = ">",
) -> float:
    """Find optimal threshold using O(n log n) algorithm for piecewise metrics.

    This function provides a backward-compatible interface to the optimized
    sort-and-scan implementation for piecewise-constant metrics.

    Parameters
    ----------
    true_labs:
        Array of true binary labels.
    pred_prob:
        Array of predicted probabilities.
    metric:
        Name of metric to optimize from METRIC_REGISTRY.
    sample_weight:
        Optional array of sample weights.
    comparison:
        Comparison operator for thresholding: ">" (exclusive) or ">="
        (inclusive).

    Returns
    -------
    float
        Optimal threshold that maximizes the metric.
    """
    # Check if we have a vectorized implementation
    if has_vectorized_implementation(metric):
        try:
            vectorized_metric = get_vectorized_metric(metric)
            threshold, _, _ = optimal_threshold_sortscan(
                true_labs,
                pred_prob,
                vectorized_metric,
                sample_weight=sample_weight,
                inclusive=comparison,
            )
            return threshold
        except Exception:
            # Fall back to the original implementation if the vectorized
            # path fails
            pass

    # Fall back to the original implementation
    return _optimal_threshold_piecewise_fallback(
        true_labs, pred_prob, metric, sample_weight, comparison
    )


def _optimal_threshold_piecewise_fallback(
    true_labs: ArrayLike,
    pred_prob: ArrayLike,
    metric: str = "f1",
    sample_weight: ArrayLike | None = None,
    comparison: ComparisonOperator = ">",
) -> float:
    """Fallback implementation for metrics without a vectorized kernel.

    This is the original implementation, which evaluates the metric once at
    each unique predicted probability.
""" true_labs = np.asarray(true_labs) pred_prob = np.asarray(pred_prob) if len(true_labs) == 0: raise ValueError("true_labs cannot be empty") if len(true_labs) != len(pred_prob): raise ValueError( f"Length mismatch: true_labs ({len(true_labs)}) vs " f"pred_prob ({len(pred_prob)})" ) # Get metric function try: metric_func = METRIC_REGISTRY[metric] except KeyError as exc: raise ValueError(f"Unknown metric: {metric}") from exc # Handle edge case: single prediction if len(pred_prob) == 1: return float(pred_prob[0]) # Sort by predicted probability in descending order for efficiency sort_idx = np.argsort(-pred_prob) sorted_probs = pred_prob[sort_idx] sorted_labels = true_labs[sort_idx] # Handle sample weights if sample_weight is not None: sample_weight = np.asarray(sample_weight) if len(sample_weight) != len(true_labs): raise ValueError( f"Length mismatch: sample_weight ({len(sample_weight)}) vs " f"true_labs ({len(true_labs)})" ) # Sort weights along with labels and probabilities weights_sorted = sample_weight[sort_idx] else: weights_sorted = np.ones(len(true_labs)) # Compute total positives and negatives (weighted) P = float(np.sum(weights_sorted * sorted_labels)) N = float(np.sum(weights_sorted * (1 - sorted_labels))) # Handle edge case: all same class if P == 0 or N == 0: return 0.5 # Find unique probabilities to use as threshold candidates unique_probs = np.unique(pred_prob) best_score = -np.inf best_threshold = 0.5 # Cumulative sums for TP and FP (weighted) cum_tp = np.cumsum(weights_sorted * sorted_labels) cum_fp = np.cumsum(weights_sorted * (1 - sorted_labels)) # Evaluate at each unique threshold for threshold in unique_probs: # Find position where we switch from positive to negative predictions # Apply comparison operator for thresholding if comparison == ">": pos_mask = sorted_probs > threshold else: # ">=" pos_mask = sorted_probs >= threshold if np.any(pos_mask): # Find last position where condition is satisfied last_pos = np.where(pos_mask)[0][-1] tp = float(cum_tp[last_pos]) fp = float(cum_fp[last_pos]) else: # No predictions above threshold -> all negative tp = fp = 0.0 fn = P - tp tn = N - fp # Compute metric score score = float(metric_func(int(tp), int(tn), int(fp), int(fn))) if score > best_score: best_score = score best_threshold = threshold return float(best_threshold)

def get_optimal_threshold(
    true_labs: ArrayLike,
    pred_prob: ArrayLike,
    metric: str = "f1",
    method: OptimizationMethod = "auto",
    sample_weight: ArrayLike | None = None,
    comparison: ComparisonOperator = ">",
) -> float | np.ndarray:
    """Find the threshold that optimizes a metric.

    Parameters
    ----------
    true_labs:
        Array of true binary labels or multiclass labels
        (0, 1, 2, ..., n_classes-1).
    pred_prob:
        Predicted probabilities from a classifier. For binary: 1D array
        (n_samples,). For multiclass: 2D array (n_samples, n_classes).
    metric:
        Name of a metric registered in
        :data:`~optimal_cutoffs.metrics.METRIC_REGISTRY`.
    method:
        Strategy used for optimization:

        - ``"auto"``: Automatically selects the best method (default)
        - ``"sort_scan"``: O(n log n) algorithm for piecewise metrics with a
          vectorized implementation
        - ``"smart_brute"``: Evaluates all unique probabilities
        - ``"minimize"``: Uses ``scipy.optimize.minimize_scalar``
        - ``"gradient"``: Simple gradient ascent
    sample_weight:
        Optional array of sample weights for handling imbalanced datasets.
    comparison:
        Comparison operator for thresholding: ">" (exclusive) or ">="
        (inclusive).

    Returns
    -------
    float | np.ndarray
        For binary: the threshold that maximizes the chosen metric.
        For multiclass: array of per-class thresholds.
    """
    # Validate inputs
    true_labs, pred_prob, sample_weight = _validate_inputs(
        true_labs, pred_prob, sample_weight=sample_weight
    )
    _validate_metric_name(metric)
    _validate_optimization_method(method)
    _validate_comparison_operator(comparison)

    # Check if this is multiclass
    if pred_prob.ndim == 2:
        return get_optimal_multiclass_thresholds(
            true_labs,
            pred_prob,
            metric,
            method,
            average="macro",
            sample_weight=sample_weight,
            comparison=comparison,
        )

    # Binary case: implement method routing with auto detection
    if method == "auto":
        # Auto routing: prefer sort_scan for piecewise metrics with a
        # vectorized implementation
        if is_piecewise_metric(metric) and has_vectorized_implementation(metric):
            method = "sort_scan"
        else:
            method = "smart_brute"

    if method == "sort_scan":
        # O(n log n) sort-and-scan optimization for vectorized piecewise metrics
        if not has_vectorized_implementation(metric):
            raise ValueError(
                f"sort_scan method requires vectorized implementation for "
                f"metric '{metric}'"
            )
        vectorized_metric = get_vectorized_metric(metric)
        threshold, _, _ = optimal_threshold_sortscan(
            true_labs,
            pred_prob,
            vectorized_metric,
            sample_weight=sample_weight,
            inclusive=comparison,
        )
        return threshold

    if method == "smart_brute":
        # Use fast piecewise optimization for piecewise-constant metrics
        if is_piecewise_metric(metric):
            return _optimal_threshold_piecewise(
                true_labs, pred_prob, metric, sample_weight, comparison
            )
        else:
            # Fall back to brute force over unique probabilities for
            # non-piecewise metrics
            thresholds = np.unique(pred_prob)
            scores = [
                _metric_score(
                    true_labs, pred_prob, t, metric, sample_weight, comparison
                )
                for t in thresholds
            ]
            return float(thresholds[int(np.argmax(scores))])

    if method == "minimize":
        res = optimize.minimize_scalar(
            lambda t: -_metric_score(
                true_labs, pred_prob, t, metric, sample_weight, comparison
            ),
            bounds=(0, 1),
            method="bounded",
        )
        # ``minimize_scalar`` may return a suboptimal threshold for
        # piecewise-constant metrics like F1, so cross-check its result
        # against the piecewise optimum (or, for non-piecewise metrics,
        # against the unique-probability candidates).
        if is_piecewise_metric(metric) and has_vectorized_implementation(metric):
            # For piecewise metrics, use the same optimal threshold as smart_brute
            piecewise_threshold = _optimal_threshold_piecewise(
                true_labs, pred_prob, metric, sample_weight, comparison
            )

            # Compare the scipy result with the piecewise result
            scipy_score = _metric_score(
                true_labs, pred_prob, res.x, metric, sample_weight, comparison
            )
            piecewise_score = _metric_score(
                true_labs,
                pred_prob,
                piecewise_threshold,
                metric,
                sample_weight,
                comparison,
            )

            if piecewise_score >= scipy_score:
                return float(piecewise_threshold)
            else:
                return float(res.x)
        else:
            # Fall back to candidate evaluation for non-piecewise metrics
            candidates = np.unique(np.append(pred_prob, res.x))
            scores = [
                _metric_score(
                    true_labs, pred_prob, t, metric, sample_weight, comparison
                )
                for t in candidates
            ]
            return float(candidates[int(np.argmax(scores))])

    if method == "gradient":
        threshold = 0.5
        lr = 0.1
        eps = 1e-5
        for _ in range(100):
            # Keep the finite-difference evaluation points within bounds
            thresh_plus = np.clip(threshold + eps, 0.0, 1.0)
            thresh_minus = np.clip(threshold - eps, 0.0, 1.0)
            grad = (
                _metric_score(
                    true_labs, pred_prob, thresh_plus, metric, sample_weight, comparison
                )
                - _metric_score(
                    true_labs,
                    pred_prob,
                    thresh_minus,
                    metric,
                    sample_weight,
                    comparison,
                )
            ) / (2 * eps)
            threshold = np.clip(threshold + lr * grad, 0.0, 1.0)
        # Final safety clip so numerical precision cannot push the result
        # outside [0, 1]
        return float(np.clip(threshold, 0.0, 1.0))

    raise ValueError(f"Unknown method: {method}")
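
# Usage sketch for the binary path (illustrative only; the data is synthetic
# and the returned thresholds are not verified here):
#
#     >>> import numpy as np
#     >>> rng = np.random.default_rng(0)
#     >>> y = rng.integers(0, 2, size=200)
#     >>> p = np.clip(0.3 * y + 0.7 * rng.random(200), 0.0, 1.0)
#     >>> get_optimal_threshold(y, p, metric="f1")                      # auto routing
#     >>> get_optimal_threshold(y, p, metric="accuracy", method="minimize")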

def get_optimal_multiclass_thresholds(
    true_labs: ArrayLike,
    pred_prob: ArrayLike,
    metric: str = "f1",
    method: OptimizationMethod = "auto",
    average: AveragingMethod = "macro",
    sample_weight: ArrayLike | None = None,
    vectorized: bool = False,
    comparison: ComparisonOperator = ">",
) -> np.ndarray | float:
    """Find optimal per-class thresholds for multiclass classification
    using One-vs-Rest.

    Parameters
    ----------
    true_labs:
        Array of true class labels (0, 1, 2, ..., n_classes-1).
    pred_prob:
        Array of predicted probabilities with shape (n_samples, n_classes).
    metric:
        Name of a metric registered in
        :data:`~optimal_cutoffs.metrics.METRIC_REGISTRY`.
    method:
        Strategy used for optimization:

        - ``"auto"``: Automatically selects the best method (default)
        - ``"sort_scan"``: O(n log n) algorithm for piecewise metrics with a
          vectorized implementation
        - ``"smart_brute"``: Evaluates all unique probabilities
        - ``"minimize"``: Uses ``scipy.optimize.minimize_scalar``
        - ``"gradient"``: Simple gradient ascent
        - ``"coord_ascent"``: Coordinate ascent for coupled multiclass
          optimization (single-label consistent)
    average:
        Averaging strategy that affects optimization:

        - ``"macro"``/``"none"``: Optimize each class independently (default)
        - ``"micro"``: Optimize to maximize the micro-averaged metric across
          all classes
        - ``"weighted"``: Optimize each class independently, same as macro
    sample_weight:
        Optional array of sample weights for handling imbalanced datasets.
    vectorized:
        If True, use a vectorized implementation for better performance when
        possible.
    comparison:
        Comparison operator for thresholding: ">" (exclusive) or ">="
        (inclusive).

    Returns
    -------
    np.ndarray | float
        For "macro"/"weighted"/"none": array of optimal thresholds, one per
        class. For "micro" with a single-threshold strategy: a single optimal
        threshold.
    """
    true_labs = np.asarray(true_labs)
    pred_prob = np.asarray(pred_prob)

    # Input validation
    if len(true_labs) == 0:
        raise ValueError("true_labs cannot be empty")
    if pred_prob.ndim != 2:
        raise ValueError(f"pred_prob must be 2D for multiclass, got {pred_prob.ndim}D")
    if len(true_labs) != pred_prob.shape[0]:
        raise ValueError(
            f"Length mismatch: true_labs ({len(true_labs)}) vs "
            f"pred_prob ({pred_prob.shape[0]})"
        )
    if np.any(np.isnan(pred_prob)) or np.any(np.isinf(pred_prob)):
        raise ValueError("pred_prob contains NaN or infinite values")

    # Check class labels are valid for One-vs-Rest
    unique_labels = np.unique(true_labs)
    expected_labels = np.arange(len(unique_labels))
    if not np.array_equal(np.sort(unique_labels), expected_labels):
        raise ValueError(
            f"Class labels must be consecutive integers starting from 0. "
            f"Got {unique_labels}, expected {expected_labels}"
        )

    n_classes = pred_prob.shape[1]

    if average == "micro":
        # For micro-averaging, we can either:
        # 1. Pool all OvR problems and optimize a single threshold
        # 2. Optimize per-class thresholds to maximize the micro-averaged metric
        # We implement approach 2 for more flexibility.
        return _optimize_micro_averaged_thresholds(
            true_labs, pred_prob, metric, method, sample_weight, vectorized, comparison
        )
    elif method == "coord_ascent":
        # Coordinate ascent for coupled multiclass optimization
        if sample_weight is not None:
            raise NotImplementedError(
                "coord_ascent method does not yet support sample weights"
            )
        if comparison != ">":
            raise NotImplementedError(
                "coord_ascent method currently only supports '>' comparison"
            )
        if metric != "f1":
            raise NotImplementedError(
                "coord_ascent method currently only supports the F1 metric"
            )

        # Use the vectorized F1 metric for sort-scan initialization
        if has_vectorized_implementation(metric):
            vectorized_metric = get_vectorized_metric(metric)
        else:
            raise ValueError(
                f"coord_ascent requires vectorized implementation for "
                f"metric '{metric}'"
            )

        tau, _, _ = optimal_multiclass_thresholds_coord_ascent(
            true_labs,
            pred_prob,
            sortscan_metric_fn=vectorized_metric,
            sortscan_kernel=optimal_threshold_sortscan,
            max_iter=20,
            init="ovr_sortscan",
            tol_stops=1,
        )
        return tau
    else:
        # For macro, weighted, none: optimize each class independently
        if vectorized and method == "smart_brute" and is_piecewise_metric(metric):
            return _optimize_thresholds_vectorized(
                true_labs, pred_prob, metric, sample_weight, comparison
            )
        else:
            # Standard per-class optimization
            optimal_thresholds = np.zeros(n_classes)
            for class_idx in range(n_classes):
                # One-vs-Rest: current class vs all others
                true_binary = (true_labs == class_idx).astype(int)
                pred_binary_prob = pred_prob[:, class_idx]

                # Optimize the threshold for this class
                optimal_thresholds[class_idx] = get_optimal_threshold(
                    true_binary,
                    pred_binary_prob,
                    metric,
                    method,
                    sample_weight,
                    comparison,
                )
            return optimal_thresholds
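
# Usage sketch for the One-vs-Rest path (illustrative only; the probabilities
# are random and not normalized per row, which this function does not require):
#
#     >>> import numpy as np
#     >>> rng = np.random.default_rng(0)
#     >>> y = rng.integers(0, 3, size=300)
#     >>> P = rng.random((300, 3))
#     >>> get_optimal_multiclass_thresholds(y, P, metric="f1")  # macro: 3 thresholds
#     >>> get_optimal_multiclass_thresholds(y, P, metric="f1", average="micro")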

def _optimize_micro_averaged_thresholds(
    true_labs: ArrayLike,
    pred_prob: ArrayLike,
    metric: str,
    method: OptimizationMethod,
    sample_weight: ArrayLike | None,
    vectorized: bool,
    comparison: ComparisonOperator = ">",
) -> np.ndarray:
    """Optimize thresholds to maximize a micro-averaged metric.

    For micro-averaging, per-class thresholds are optimized jointly to
    maximize the micro-averaged metric score across all classes.
    """
    true_labs = np.asarray(true_labs)
    pred_prob = np.asarray(pred_prob)
    n_classes = pred_prob.shape[1]

    def objective(thresholds):
        """Objective function: negative micro-averaged metric."""
        cms = get_multiclass_confusion_matrix(
            true_labs, pred_prob, thresholds, sample_weight, comparison
        )
        score = multiclass_metric(cms, metric, "micro")
        return -float(score)

    if method == "smart_brute":
        # Micro-averaging with smart_brute would require searching over
        # combinations of thresholds. Start with independent per-class
        # optimization as the initial guess.
        initial_thresholds = np.zeros(n_classes)
        for class_idx in range(n_classes):
            true_binary = (true_labs == class_idx).astype(int)
            pred_binary_prob = pred_prob[:, class_idx]
            initial_thresholds[class_idx] = get_optimal_threshold(
                true_binary,
                pred_binary_prob,
                metric,
                "smart_brute",
                sample_weight,
                comparison,
            )

        # For now, return the independent optimization result
        # TODO: Implement joint optimization for micro-averaging
        return initial_thresholds
    elif method in ["minimize", "gradient"]:
        # Use scipy optimization for joint threshold optimization
        from scipy.optimize import minimize

        # Initial guess: independent optimization per class
        initial_guess = np.zeros(n_classes)
        for class_idx in range(n_classes):
            true_binary = (true_labs == class_idx).astype(int)
            pred_binary_prob = pred_prob[:, class_idx]
            initial_guess[class_idx] = get_optimal_threshold(
                true_binary,
                pred_binary_prob,
                metric,
                "minimize",
                sample_weight,
                comparison,
            )

        # Joint optimization
        result = minimize(
            objective,
            initial_guess,
            method="L-BFGS-B",
            bounds=[(0, 1) for _ in range(n_classes)],
        )
        return result.x
    else:
        raise ValueError(f"Unknown method: {method}")


def _optimize_thresholds_vectorized(
    true_labs: ArrayLike,
    pred_prob: ArrayLike,
    metric: str,
    sample_weight: ArrayLike | None,
    comparison: ComparisonOperator = ">",
) -> np.ndarray:
    """Vectorized optimization for piecewise metrics.

    This function vectorizes the piecewise threshold optimization across all
    classes for better performance.
    """
    true_labs = np.asarray(true_labs)
    pred_prob = np.asarray(pred_prob)
    n_samples, n_classes = pred_prob.shape

    # Create binary labels for all classes at once: (n_samples, n_classes)
    true_binary_all = (true_labs[:, None] == np.arange(n_classes)).astype(int)

    optimal_thresholds = np.zeros(n_classes)

    # For now, fall back to per-class optimization
    # TODO: Implement fully vectorized version
    for class_idx in range(n_classes):
        optimal_thresholds[class_idx] = _optimal_threshold_piecewise(
            true_binary_all[:, class_idx],
            pred_prob[:, class_idx],
            metric,
            sample_weight,
            comparison,
        )

    return optimal_thresholds


__all__ = [
    "get_probability",
    "get_optimal_threshold",
    "get_optimal_multiclass_thresholds",
]
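
# Minimal end-to-end smoke test for this module (illustrative only; run as a
# script it exercises the public API on synthetic data, and the printed
# thresholds depend entirely on the random draws):
if __name__ == "__main__":
    rng = np.random.default_rng(42)

    # Binary: noisy probabilities correlated with the labels
    y_bin = rng.integers(0, 2, size=500)
    p_bin = np.clip(0.25 * y_bin + 0.75 * rng.random(500), 0.0, 1.0)
    print("binary f1 threshold:", get_optimal_threshold(y_bin, p_bin, metric="f1"))

    # Multiclass: per-class thresholds via One-vs-Rest
    y_mc = rng.integers(0, 3, size=500)
    p_mc = rng.random((500, 3))
    p_mc[np.arange(500), y_mc] += 0.5  # make the true class more probable
    print(
        "multiclass f1 thresholds:",
        get_optimal_multiclass_thresholds(y_mc, p_mc, metric="f1"),
    )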