Source code for optimal_cutoffs.cv

"""Cross-validation helpers for threshold optimization."""

from typing import Any

import numpy as np
from numpy.typing import ArrayLike
from sklearn.model_selection import (
    KFold,  # type: ignore[import-untyped]
    StratifiedKFold,
)

from .core import get_optimal_threshold
from .metrics import (
    METRICS,
    confusion_matrix_at_threshold,
    multiclass_confusion_matrices_at_thresholds,
    multiclass_metric_ovr,
    multiclass_metric_single_label,
)
from .validation import (
    _validate_averaging_method,
    _validate_comparison_operator,
    _validate_metric_name,
    _validate_optimization_method,
)


def cv_threshold_optimization(
    true_labs: ArrayLike,
    pred_prob: ArrayLike,
    metric: str = "f1",
    method: str = "auto",
    cv: int | Any = 5,
    random_state: int | None = None,
    sample_weight: ArrayLike | None = None,
    *,
    comparison: str = ">",
    average: str = "macro",
    **opt_kwargs: Any,
) -> tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]]:
    """Estimate optimal threshold(s) using cross-validation.

    Supports both binary and multiclass classification with proper handling
    of all threshold return formats (scalar, array, dict from expected mode).
    Uses StratifiedKFold by default for better class balance preservation.

    Parameters
    ----------
    true_labs : ArrayLike
        Array of true labels (binary or multiclass).
    pred_prob : ArrayLike
        Predicted probabilities. For binary: 1D array. For multiclass: 2D array.
    metric : str, default="f1"
        Metric name to optimize; must exist in the metric registry.
    method : OptimizationMethod, default="auto"
        Optimization strategy passed to get_optimal_threshold.
    cv : int or cross-validator, default=5
        Number of folds or custom cross-validator object.
    random_state : int, optional
        Seed for the cross-validator shuffling.
    sample_weight : ArrayLike, optional
        Sample weights for handling imbalanced datasets.
    comparison : ComparisonOperator, default=">"
        Comparison operator for threshold application.
    average : str, default="macro"
        Averaging strategy for multiclass metrics.
    **opt_kwargs : Any
        Additional arguments passed to get_optimal_threshold.

    Returns
    -------
    tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]]
        Arrays of per-fold thresholds and scores.
    """
    # Validate parameters early for better user experience
    _validate_metric_name(metric)
    _validate_comparison_operator(comparison)
    _validate_averaging_method(average)
    _validate_optimization_method(method)

    true_labs = np.asarray(true_labs)
    pred_prob = np.asarray(pred_prob)
    if sample_weight is not None:
        sample_weight = np.asarray(sample_weight)

    # Choose splitter: stratify by default for classification when possible
    if hasattr(cv, "split"):
        splitter = cv  # custom splitter provided
    else:
        n_splits = int(cv)
        if true_labs.ndim == 1 and np.unique(true_labs).size > 1:
            splitter = StratifiedKFold(
                n_splits=n_splits, shuffle=True, random_state=random_state
            )
        else:
            splitter = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    thresholds = []
    scores = []
    for train_idx, test_idx in splitter.split(true_labs, true_labs):
        # Extract training data and weights
        train_weights = None if sample_weight is None else sample_weight[train_idx]
        test_weights = None if sample_weight is None else sample_weight[test_idx]

        result = get_optimal_threshold(
            true_labs[train_idx],
            pred_prob[train_idx],
            metric=metric,
            method=method,
            sample_weight=train_weights,
            comparison=comparison,
            average=average,
            **opt_kwargs,
        )
        thr = _extract_thresholds(result)
        thresholds.append(thr)

        scores.append(
            _evaluate_threshold_on_fold(
                true_labs[test_idx],
                pred_prob[test_idx],
                thr,
                metric=metric,
                average=average,
                sample_weight=test_weights,
                comparison=comparison,
            )
        )

    return np.array(thresholds, dtype=object), np.array(scores, dtype=float)
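

# Usage sketch (illustrative only, not part of the module): one way to call
# cv_threshold_optimization on a small synthetic binary problem. The data and
# seed below are hypothetical and exist purely for demonstration.
#
#     import numpy as np
#     from optimal_cutoffs.cv import cv_threshold_optimization
#
#     rng = np.random.default_rng(0)
#     y = rng.integers(0, 2, size=200)  # synthetic labels
#     p = np.clip(0.25 + 0.5 * y + rng.normal(0.0, 0.2, size=200), 0.0, 1.0)  # noisy scores
#
#     thresholds, scores = cv_threshold_optimization(
#         y, p, metric="f1", cv=5, random_state=0
#     )
#     # thresholds: object array with one threshold per fold (scalars for binary input)
#     # scores: float array with each fold's held-out F1 at its training-fold threshold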


def nested_cv_threshold_optimization(
    true_labs: ArrayLike,
    pred_prob: ArrayLike,
    metric: str = "f1",
    method: str = "auto",
    inner_cv: int = 5,
    outer_cv: int = 5,
    random_state: int | None = None,
    sample_weight: ArrayLike | None = None,
    *,
    comparison: str = ">",
    average: str = "macro",
    **opt_kwargs: Any,
) -> tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]]:
    """Nested cross-validation for unbiased threshold optimization.

    Inner CV estimates robust thresholds by averaging across folds, while
    outer CV evaluates performance. Uses StratifiedKFold by default for
    better class balance. Threshold selection uses statistically sound
    averaging rather than cherry-picking the best-performing fold.

    Parameters
    ----------
    true_labs : ArrayLike
        Array of true labels (binary or multiclass).
    pred_prob : ArrayLike
        Predicted probabilities. For binary: 1D array. For multiclass: 2D array.
    metric : str, default="f1"
        Metric name to optimize.
    method : OptimizationMethod, default="auto"
        Optimization strategy passed to get_optimal_threshold.
    inner_cv : int, default=5
        Number of folds in the inner loop used to estimate thresholds.
    outer_cv : int, default=5
        Number of outer folds for unbiased performance assessment.
    random_state : int, optional
        Seed for the cross-validators.
    sample_weight : ArrayLike, optional
        Sample weights for handling imbalanced datasets.
    comparison : ComparisonOperator, default=">"
        Comparison operator for threshold application.
    average : str, default="macro"
        Averaging strategy for multiclass metrics.
    **opt_kwargs : Any
        Additional arguments passed to get_optimal_threshold.

    Returns
    -------
    tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]]
        Arrays of outer-fold thresholds and scores.
    """
    # Validate parameters early for better user experience
    _validate_metric_name(metric)
    _validate_comparison_operator(comparison)
    _validate_averaging_method(average)
    _validate_optimization_method(method)

    true_labs = np.asarray(true_labs)
    pred_prob = np.asarray(pred_prob)
    if sample_weight is not None:
        sample_weight = np.asarray(sample_weight)

    # Stratify in the outer loop when possible
    if true_labs.ndim == 1 and np.unique(true_labs).size > 1:
        outer = StratifiedKFold(
            n_splits=outer_cv, shuffle=True, random_state=random_state
        )
    else:
        outer = KFold(n_splits=outer_cv, shuffle=True, random_state=random_state)

    outer_thresholds = []
    outer_scores = []
    for train_idx, test_idx in outer.split(true_labs, true_labs):
        # Extract training and test data with weights
        train_weights = None if sample_weight is None else sample_weight[train_idx]
        test_weights = None if sample_weight is None else sample_weight[test_idx]

        inner_thresholds, inner_scores = cv_threshold_optimization(
            true_labs[train_idx],
            pred_prob[train_idx],
            metric=metric,
            method=method,
            cv=inner_cv,
            random_state=random_state,
            sample_weight=train_weights,
            comparison=comparison,
            average=average,
            **opt_kwargs,
        )

        # Average thresholds across inner folds for a robust estimate.
        # This is statistically sound: it considers all folds rather than cherry-picking.
        if isinstance(inner_thresholds[0], float | np.floating):
            # Binary case: simple averaging
            thr = float(np.mean(inner_thresholds))
        elif isinstance(inner_thresholds[0], np.ndarray):
            # Multiclass: average each class threshold
            thr = np.mean(np.vstack(inner_thresholds), axis=0)
        elif isinstance(inner_thresholds[0], dict):
            # Dict format (e.g., from expected mode)
            thr = _average_threshold_dicts(inner_thresholds)
        else:
            # Fallback: try converting to array and averaging
            try:
                thr = np.mean(np.array(inner_thresholds))
            except (ValueError, TypeError):
                # If averaging fails, use mean score to select a representative threshold
                mean_score = np.mean(inner_scores)
                closest_idx = np.argmin(np.abs(inner_scores - mean_score))
                thr = inner_thresholds[closest_idx]

        outer_thresholds.append(thr)

        score = _evaluate_threshold_on_fold(
            true_labs[test_idx],
            pred_prob[test_idx],
            thr,
            metric=metric,
            average=average,
            sample_weight=test_weights,
            comparison=comparison,
        )
        outer_scores.append(score)

    return np.array(outer_thresholds, dtype=object), np.array(outer_scores, dtype=float)
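

# Usage sketch (illustrative only): nested evaluation reusing the synthetic
# y and p from the sketch after cv_threshold_optimization. Inner folds
# estimate thresholds, the per-fold thresholds are averaged, and each averaged
# threshold is scored on its held-out outer fold; the mean of the outer scores
# serves as the (approximately unbiased) performance estimate.
#
#     outer_thresholds, outer_scores = nested_cv_threshold_optimization(
#         y, p, metric="f1", inner_cv=3, outer_cv=5, random_state=0
#     )
#     cv_estimate = float(outer_scores.mean())  # cross-validated F1 at averaged thresholds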


# -------------------- helpers --------------------


def _extract_thresholds(thr_result: Any) -> Any:
    """Extract thresholds from OptimizationResult objects.

    Now primarily handles OptimizationResult objects since get_optimal_threshold
    returns a unified OptimizationResult. Maintains backward compatibility for
    legacy formats.
    """
    from .types_minimal import OptimizationResult

    # OptimizationResult (new unified format)
    if isinstance(thr_result, OptimizationResult):
        return thr_result.thresholds

    # Legacy formats for backward compatibility
    # (thr, score)
    if isinstance(thr_result, tuple) and len(thr_result) == 2:
        return thr_result[0]

    # dict from expected/micro or macro/weighted
    if isinstance(thr_result, dict):
        if "thresholds" in thr_result:
            return thr_result["thresholds"]
        if "threshold" in thr_result:
            return thr_result["threshold"]
        # Bayes with decisions has no thresholds; raise clearly
        if "decisions" in thr_result:
            raise ValueError("Bayes decisions cannot be used for threshold CV scoring.")

    return thr_result


def _average_threshold_dicts(threshold_dicts: list[dict[str, Any]]) -> dict[str, Any]:
    """Average dictionary-based thresholds from multiple CV folds.

    Parameters
    ----------
    threshold_dicts : list[dict[str, Any]]
        List of threshold dictionaries from inner CV folds.

    Returns
    -------
    dict[str, Any]
        Averaged threshold dictionary with the same structure as the input.
    """
    if not threshold_dicts:
        raise ValueError("Cannot average empty list of threshold dictionaries")

    # Check for consistent structure
    first_dict = threshold_dicts[0]
    for i, d in enumerate(threshold_dicts[1:], 1):
        if set(d.keys()) != set(first_dict.keys()):
            raise ValueError(f"Inconsistent dict keys between folds 0 and {i}")

    result = {}

    # Average numerical values
    for key in first_dict:
        values = [d[key] for d in threshold_dicts]

        if key in ("threshold", "thresholds"):
            # These are the actual threshold values to average
            if isinstance(values[0], float | np.floating):
                result[key] = float(np.mean(values))
            elif isinstance(values[0], np.ndarray):
                result[key] = np.mean(np.vstack(values), axis=0)
            else:
                # Try to convert to array and average
                result[key] = np.mean(np.array(values))
        elif key == "score" or key.endswith("_score"):
            # Don't average scores - they're fold-specific performance
            # Use mean as representative value
            result[key] = float(np.mean(values))
        else:
            # For other keys (like per_class arrays), average if numeric
            try:
                if isinstance(values[0], np.ndarray):
                    result[key] = np.mean(np.vstack(values), axis=0)
                elif isinstance(values[0], int | float | np.number):
                    result[key] = float(np.mean(values))
                else:
                    # Non-numeric data - keep first fold's value
                    result[key] = values[0]
            except (TypeError, ValueError):
                # If averaging fails, keep first fold's value
                result[key] = values[0]

    return result


def _evaluate_threshold_on_fold(
    y_true: ArrayLike,
    pred_prob: ArrayLike,
    thr: Any,
    *,
    metric: str,
    average: str,
    sample_weight: ArrayLike | None,
    comparison: str,
) -> float:
    """Compute the chosen metric on the test fold for a given threshold object."""
    y_true = np.asarray(y_true)
    pred_prob = np.asarray(pred_prob)
    sw = None if sample_weight is None else np.asarray(sample_weight)

    if pred_prob.ndim == 1:
        # scalar threshold required
        if isinstance(thr, dict):
            t = float(thr.get("threshold", thr))
        else:
            # Handle both scalar and array cases
            thr_array = np.asarray(thr)
            t = (
                float(thr_array.item())
                if thr_array.ndim == 0
                else float(thr_array.flat[0])
            )
        tp, tn, fp, fn = confusion_matrix_at_threshold(
            y_true, pred_prob, t, sample_weight=sw, comparison=comparison
        )
        # Metric validation happens early in CV functions - no need to validate again
        metric_fn = METRICS[metric].fn
        return float(metric_fn(tp, tn, fp, fn))

    # Multiclass / multilabel (n, K)
    K = pred_prob.shape[1]
    if isinstance(thr, dict):
        if "thresholds" in thr:
            thresholds = np.asarray(thr["thresholds"], dtype=float)
        elif "threshold" in thr:
            # micro: single global threshold - broadcast per class
            thresholds = np.full(K, float(thr["threshold"]), dtype=float)
        else:
            raise ValueError("Unexpected threshold dict shape for multiclass.")
    elif np.isscalar(thr):
        thresholds = np.full(K, float(thr), dtype=float)  # type: ignore[arg-type]
    else:
        thresholds = np.asarray(thr, dtype=float)

    if thresholds.shape != (K,):
        raise ValueError(
            f"Per-class thresholds must have shape ({K},), got {thresholds.shape}."
        )

    if metric == "accuracy":
        # Exclusive accuracy uses the margin-based single-label decision rule
        return float(
            multiclass_metric_single_label(
                y_true, pred_prob, thresholds, "accuracy", comparison, sw
            )
        )

    cms = multiclass_confusion_matrices_at_thresholds(
        y_true, pred_prob, thresholds, sample_weight=sw, comparison=comparison
    )
    return float(multiclass_metric_ovr(cms, metric, average))
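

# Behaviour sketch for _average_threshold_dicts (hypothetical dict contents,
# shown only to illustrate the averaging rules implemented above):
#
#     folds = [
#         {"thresholds": np.array([0.40, 0.60]), "score": 0.80},
#         {"thresholds": np.array([0.50, 0.70]), "score": 0.90},
#     ]
#     averaged = _average_threshold_dicts(folds)
#     # averaged["thresholds"] -> array([0.45, 0.65])  (element-wise mean per class)
#     # averaged["score"]      -> ~0.85                (mean kept as a representative value)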