"""High-level wrapper for threshold optimization."""
from typing import Self
import numpy as np
from .multiclass_coord import _assign_labels_shifted
from .optimizers import get_optimal_threshold, get_probability
from .types import ArrayLike, ComparisonOperator, OptimizationMethod, SampleWeightLike
[docs]
class ThresholdOptimizer:
"""Optimizer for classification thresholds supporting both binary and multiclass.
The class wraps threshold optimization functions and exposes a scikit-learn
style ``fit``/``predict`` API. For multiclass, uses One-vs-Rest strategy.
"""
[docs]
def __init__(
self,
objective: str = "accuracy",
verbose: bool = False,
method: OptimizationMethod = "auto",
comparison: ComparisonOperator = ">",
) -> None:
"""Create a new optimizer.
Parameters
----------
objective:
Metric to optimize, e.g. ``"accuracy"``, ``"f1"``, ``"precision"``,
``"recall"``.
verbose:
If ``True``, print progress during threshold search.
method:
Optimization method:
- ``"auto"``: Automatically selects best method (default)
- ``"sort_scan"``: O(n log n) algorithm for piecewise metrics with
vectorized implementation
- ``"smart_brute"``: Evaluates all unique probabilities
- ``"minimize"``: Uses ``scipy.optimize.minimize_scalar``
- ``"gradient"``: Simple gradient ascent
- ``"coord_ascent"``: Coordinate ascent for coupled multiclass
optimization (single-label consistent)
comparison:
Comparison operator for thresholding: ">" (exclusive) or ">=" (inclusive).
"""
self.objective = objective
self.verbose = verbose
self.method = method
self.comparison = comparison
self.threshold_: float | np.ndarray | None = None
self.is_multiclass_: bool = False
[docs]
def fit(
self,
true_labs: ArrayLike,
pred_prob: ArrayLike,
sample_weight: SampleWeightLike = None,
) -> Self:
"""Estimate the optimal threshold(s) from labeled data.
Parameters
----------
true_labs:
Array of true labels. For binary: (0, 1). For multiclass:
(0, 1, 2, ..., n_classes-1).
pred_prob:
Predicted probabilities from a classifier. For binary: 1D array
(n_samples,).
For multiclass: 2D array (n_samples, n_classes).
sample_weight:
Optional array of sample weights for handling imbalanced datasets.
Returns
-------
Self
Fitted instance with ``threshold_`` attribute set.
"""
pred_prob = np.asarray(pred_prob)
# Check if multiclass
self.is_multiclass_ = pred_prob.ndim == 2
if (
self.is_multiclass_
or self.objective not in ["accuracy", "f1"]
or sample_weight is not None
):
# Use the more general optimizer
self.threshold_ = get_optimal_threshold(
true_labs,
pred_prob,
self.objective,
self.method,
sample_weight,
self.comparison,
)
else:
# Use legacy optimizer for backward compatibility (only when no sample
# weights)
self.threshold_ = get_probability(
true_labs, pred_prob, self.objective, self.verbose
)
return self
[docs]
def predict(self, pred_prob: ArrayLike) -> np.ndarray:
"""Convert probabilities to class predictions using the learned threshold(s).
Parameters
----------
pred_prob:
Array of predicted probabilities to be thresholded.
Returns
-------
np.ndarray
For binary: Boolean array of predicted class labels.
For multiclass: Integer array of predicted class labels.
"""
if self.threshold_ is None:
raise RuntimeError("ThresholdOptimizer has not been fitted.")
pred_prob = np.asarray(pred_prob)
if self.is_multiclass_:
# Multiclass prediction strategy depends on optimization method
if self.method == "coord_ascent":
# Coordinate ascent uses argmax(P - tau) for single-label consistency
return _assign_labels_shifted(pred_prob, self.threshold_)
else:
# One-vs-Rest prediction using per-class thresholds
n_samples, n_classes = pred_prob.shape
if self.comparison == ">":
binary_predictions = pred_prob > self.threshold_
else: # ">="
binary_predictions = pred_prob >= self.threshold_
# For each sample, predict the class with highest probability among
# those above threshold
# If no classes above threshold, predict the class with highest
# probability
predictions = np.zeros(n_samples, dtype=int)
for i in range(n_samples):
above_threshold = np.where(binary_predictions[i])[0]
if len(above_threshold) > 0:
# Among classes above threshold, pick the one with highest
# probability
predictions[i] = above_threshold[
np.argmax(pred_prob[i, above_threshold])
]
else:
# No class above threshold, pick highest probability class
predictions[i] = np.argmax(pred_prob[i])
return predictions
else:
# Binary prediction
if self.comparison == ">":
return pred_prob > self.threshold_
else: # ">="
return pred_prob >= self.threshold_