Source code for optimal_cutoffs.multiclass

"""Multi-class classification threshold optimization.

This module implements threshold optimization for multi-class classification
where we have K mutually exclusive classes and must predict exactly one class.

Key approaches:
1. OvR Independent: Treat each class as an independent binary problem (multi-label style)
2. Margin Rule: Use argmax(p_j - τ_j) for coupled single-label predictions
3. Micro averaging: Single threshold applied to all classes

The margin rule is Bayes-optimal when costs have OvR structure but requires
coordinate ascent for general metrics like F1.
"""

from __future__ import annotations

import numpy as np
from numpy.typing import ArrayLike

from .core import OptimizationResult, Task
from .validation import validate_multiclass_classification


def optimize_ovr_independent(
    true_labels: ArrayLike,
    pred_proba: ArrayLike,
    *,
    metric: str = "f1",
    method: str = "auto",
    sample_weight: ArrayLike | None = None,
    comparison: str = ">",
    tolerance: float = 1e-10,
) -> OptimizationResult:
    """Optimize multiclass metrics using independent per-class thresholds (OvR).

    Treats each class as an independent binary problem (class vs rest). This
    does NOT enforce single-label predictions: a sample can receive 0, 1, or
    multiple classes. Use this for macro-averaged metrics when you want exact
    per-class optimization.

    Decision rule: ŷ_j = 1 if p_j ≥ τ_j (independent for each class)

    Parameters
    ----------
    true_labels : array-like of shape (n_samples,)
        True class labels in {0, 1, ..., K-1}
    pred_proba : array-like of shape (n_samples, n_classes)
        Predicted probabilities for each class
    metric : str, default="f1"
        Metric to optimize per class
    method : str, default="auto"
        Binary optimization method
    sample_weight : array-like of shape (n_samples,), optional
        Sample weights
    comparison : str, default=">"
        Comparison operator
    tolerance : float, default=1e-10
        Numerical tolerance

    Returns
    -------
    OptimizationResult
        Result with per-class thresholds optimized independently

    Examples
    --------
    >>> y_true = [0, 1, 2, 0, 1]
    >>> y_prob = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8], ...]
    >>> result = optimize_ovr_independent(y_true, y_prob, metric="f1")
    >>> predictions = result.predict(y_prob)  # Can predict multiple classes
    """
    from .binary import optimize_metric_binary

    # Validate inputs
    true_labels, pred_proba, sample_weight = validate_multiclass_classification(
        true_labels, pred_proba, sample_weight, require_proba=True
    )
    n_samples, n_classes = pred_proba.shape

    # Optimize each class vs rest independently
    optimal_thresholds = np.zeros(n_classes, dtype=np.float64)
    optimal_scores = np.zeros(n_classes, dtype=np.float64)

    for k in range(n_classes):
        # Create binary problem: class k vs rest
        y_true_k = (true_labels == k).astype(int)
        y_prob_k = pred_proba[:, k]

        # Optimize threshold for class k
        result_k = optimize_metric_binary(
            y_true_k,
            y_prob_k,
            metric=metric,
            method=method,
            sample_weight=sample_weight,
            comparison=comparison,
            tolerance=tolerance,
        )

        optimal_thresholds[k] = result_k.thresholds[0]
        optimal_scores[k] = result_k.scores[0]

    # Macro average score
    macro_score = np.mean(optimal_scores)

    def predict_multiclass_independent(probs: ArrayLike) -> np.ndarray:
        """Independent per-class predictions (can predict multiple classes)."""
        p = np.asarray(probs, dtype=np.float64)
        if p.ndim != 2 or p.shape[1] != n_classes:
            raise ValueError(f"Expected probabilities shape (n_samples, {n_classes})")

        if comparison == ">=":
            predictions = (p >= optimal_thresholds[None, :]).astype(np.int32)
        else:
            predictions = (p > optimal_thresholds[None, :]).astype(np.int32)

        return predictions

    return OptimizationResult(
        thresholds=optimal_thresholds,
        scores=np.array([macro_score]),
        predict=predict_multiclass_independent,
        task=Task.MULTICLASS,
        metric=f"macro_{metric}_ovr_independent",
        n_classes=n_classes,
    )


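# --- Illustrative sketch (not part of the library) --------------------------
# The demo below is an assumption added for documentation: the helper name
# `_demo_ovr_independent` and its toy data do not exist in the package. It
# shows how the independent OvR result behaves: `result.predict` returns a
# 0/1 indicator matrix of shape (n_samples, n_classes), so a single row may
# contain zero, one, or several positive classes.
def _demo_ovr_independent() -> None:
    y_true = np.array([0, 1, 2, 0, 1])
    y_prob = np.array(
        [
            [0.70, 0.20, 0.10],
            [0.10, 0.80, 0.10],
            [0.10, 0.10, 0.80],
            [0.55, 0.35, 0.10],
            [0.30, 0.45, 0.25],
        ]
    )
    result = optimize_ovr_independent(y_true, y_prob, metric="f1")
    indicator = result.predict(y_prob)  # shape (5, 3), values in {0, 1}
    print(result.thresholds)            # one threshold per class
    print(indicator.sum(axis=1))        # per-sample counts: 0, 1, or more

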
def optimize_ovr_margin(
    true_labels: ArrayLike,
    pred_proba: ArrayLike,
    *,
    metric: str = "f1",
    max_iter: int = 30,
    sample_weight: ArrayLike | None = None,
    comparison: str = ">",
    tolerance: float = 1e-12,
) -> OptimizationResult:
    """Optimize multiclass metrics using the margin rule with coordinate ascent.

    Uses margin-based prediction: ŷ = argmax_j (p_j - τ_j)

    This ensures exactly one class is predicted per sample (single-label).
    The thresholds are coupled because changing τ_j affects which samples are
    assigned to class j, which in turn affects the confusion matrices of all
    classes. Coordinate ascent is used to find a local optimum.

    Parameters
    ----------
    true_labels : array-like of shape (n_samples,)
        True class labels in {0, 1, ..., K-1}
    pred_proba : array-like of shape (n_samples, n_classes)
        Predicted probabilities for each class
    metric : str, default="f1"
        Metric to optimize (currently supports "f1" only)
    max_iter : int, default=30
        Maximum coordinate ascent iterations
    sample_weight : array-like of shape (n_samples,), optional
        Sample weights
    comparison : str, default=">"
        Comparison operator (only ">" is supported for the margin rule)
    tolerance : float, default=1e-12
        Convergence tolerance

    Returns
    -------
    OptimizationResult
        Result with per-class thresholds optimized via coordinate ascent

    Examples
    --------
    >>> result = optimize_ovr_margin(y_true, y_prob, metric="f1")
    >>> predictions = result.predict(y_prob)  # Exactly one class per sample

    Notes
    -----
    The margin rule is Bayes-optimal when costs have OvR structure:
        C(i, j) = -r_j if i = j, else c_j
    In this case, the optimal thresholds have the closed form:
        τ_j = c_j / (c_j + r_j)
    """
    from .optimize import coordinate_ascent_kernel

    # Validate inputs
    true_labels, pred_proba, sample_weight = validate_multiclass_classification(
        true_labels, pred_proba, sample_weight, require_proba=True
    )
    n_samples, n_classes = pred_proba.shape

    if metric != "f1":
        raise NotImplementedError(
            f"optimize_ovr_margin currently supports metric='f1' only, got {metric!r}"
        )
    if comparison != ">":
        raise NotImplementedError("optimize_ovr_margin requires comparison='>'")

    # Prepare data for coordinate ascent kernel
    true_labels_int32 = np.asarray(true_labels, dtype=np.int32)
    pred_proba_float64 = np.asarray(pred_proba, dtype=np.float64, order="C")
    weights = (
        None if sample_weight is None else np.asarray(sample_weight, dtype=np.float64)
    )

    # Run coordinate ascent
    thresholds, best_score, history = coordinate_ascent_kernel(
        true_labels_int32,
        pred_proba_float64,
        weights,
        max_iter=max_iter,
        tol=tolerance,
    )

    def predict_multiclass_margin(probs: ArrayLike) -> np.ndarray:
        """Margin-based prediction: argmax(p_j - τ_j)."""
        p = np.asarray(probs, dtype=np.float64)
        if p.ndim != 2 or p.shape[1] != n_classes:
            raise ValueError(f"Expected probabilities shape (n_samples, {n_classes})")

        # Compute margins and predict the class with the highest margin
        margins = p - thresholds[None, :]
        predictions = np.argmax(margins, axis=1).astype(np.int32)

        return predictions

    return OptimizationResult(
        thresholds=thresholds.astype(np.float64),
        scores=np.array([best_score]),
        predict=predict_multiclass_margin,
        task=Task.MULTICLASS,
        metric=f"macro_{metric}_margin_rule",
        n_classes=n_classes,
    )


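# --- Illustrative sketch (not part of the library) --------------------------
# The helpers below are assumptions added to illustrate the Notes above: under
# OvR costs with false-positive cost c_j and correct-prediction reward r_j,
# the closed-form thresholds are tau_j = c_j / (c_j + r_j), and prediction
# then follows the same margin rule argmax_j (p_j - tau_j) used by
# `optimize_ovr_margin`.
def _ovr_cost_thresholds(fp_costs: ArrayLike, tp_rewards: ArrayLike) -> np.ndarray:
    c = np.asarray(fp_costs, dtype=np.float64)
    r = np.asarray(tp_rewards, dtype=np.float64)
    return c / (c + r)


# Example with assumed costs/rewards: c = [1, 2, 1], r = [3, 2, 4] gives
# tau = [0.25, 0.5, 0.2]; each sample gets the class with the largest margin.
def _demo_margin_rule(pred_proba: np.ndarray) -> np.ndarray:
    tau = _ovr_cost_thresholds([1.0, 2.0, 1.0], [3.0, 2.0, 4.0])
    return np.argmax(np.asarray(pred_proba, dtype=np.float64) - tau[None, :], axis=1)

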
def optimize_micro_multiclass(
    true_labels: ArrayLike,
    pred_proba: ArrayLike,
    *,
    metric: str = "f1",
    method: str = "auto",
    sample_weight: ArrayLike | None = None,
    comparison: str = ">",
    tolerance: float = 1e-10,
) -> OptimizationResult:
    """Optimize micro-averaged multiclass metrics using a single threshold.

    For micro averaging, a single threshold is applied to all classes, and the
    class with the highest valid probability is predicted. This reduces to one
    binary optimization problem on the flattened (sample, class) data.

    Decision rule: ŷ = argmax{j: p_j ≥ τ} p_j (or argmax p_j if no class is valid)

    Parameters
    ----------
    true_labels : array-like of shape (n_samples,)
        True class labels in {0, 1, ..., K-1}
    pred_proba : array-like of shape (n_samples, n_classes)
        Predicted probabilities for each class
    metric : str, default="f1"
        Metric to optimize
    method : str, default="auto"
        Binary optimization method
    sample_weight : array-like of shape (n_samples,), optional
        Sample weights
    comparison : str, default=">"
        Comparison operator
    tolerance : float, default=1e-10
        Numerical tolerance

    Returns
    -------
    OptimizationResult
        Result with a single threshold applied to all classes

    Examples
    --------
    >>> result = optimize_micro_multiclass(y_true, y_prob, metric="f1")
    >>> result.thresholds  # Same threshold for all classes
    array([0.3, 0.3, 0.3])
    """
    from .binary import optimize_metric_binary

    # Validate inputs
    true_labels, pred_proba, sample_weight = validate_multiclass_classification(
        true_labels, pred_proba, sample_weight, require_proba=True
    )
    n_samples, n_classes = pred_proba.shape

    # Flatten to a single binary problem for micro averaging:
    # each (sample, class) pair becomes one binary prediction.
    classes = np.arange(n_classes)
    true_binary_flat = (
        np.repeat(true_labels, n_classes) == np.tile(classes, n_samples)
    ).astype(int)
    pred_proba_flat = pred_proba.ravel()

    # Replicate sample weights if provided
    sample_weight_flat = (
        None if sample_weight is None else np.repeat(sample_weight, n_classes)
    )

    # Optimize a single threshold on the flattened problem
    result = optimize_metric_binary(
        true_binary_flat,
        pred_proba_flat,
        metric=metric,
        method=method,
        sample_weight=sample_weight_flat,
        comparison=comparison,
        tolerance=tolerance,
    )

    optimal_threshold = result.thresholds[0]

    def predict_multiclass_micro(probs: ArrayLike) -> np.ndarray:
        """Predict using a single threshold across all classes."""
        p = np.asarray(probs, dtype=np.float64)
        if p.ndim != 2 or p.shape[1] != n_classes:
            raise ValueError(f"Expected probabilities shape (n_samples, {n_classes})")

        # Apply the threshold to get valid classes
        if comparison == ">=":
            valid = p >= optimal_threshold
        else:
            valid = p > optimal_threshold

        # Predict the class with the highest valid probability
        masked_probs = np.where(valid, p, -np.inf)
        predictions = np.argmax(masked_probs, axis=1)

        # Fall back to plain argmax when no class is valid
        no_valid = ~np.any(valid, axis=1)
        if np.any(no_valid):
            predictions[no_valid] = np.argmax(p[no_valid], axis=1)

        return predictions.astype(np.int32)

    # Return the same threshold for all classes
    thresholds = np.full(n_classes, optimal_threshold, dtype=np.float64)

    return OptimizationResult(
        thresholds=thresholds,
        scores=result.scores,
        predict=predict_multiclass_micro,
        task=Task.MULTICLASS,
        metric=f"micro_{metric}",
        n_classes=n_classes,
    )


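# --- Illustrative sketch (not part of the library) --------------------------
# The demo below is an assumption added for documentation. It spells out the
# flattening used by `optimize_micro_multiclass`: every (sample, class) pair
# becomes one binary example, so a single threshold is tuned over
# n_samples * n_classes rows.
def _demo_micro_flatten() -> None:
    y_true = np.array([0, 2, 1])
    n_classes = 3
    classes = np.arange(n_classes)
    # Row-major order matches pred_proba.ravel() on a (3, 3) probability matrix.
    flat_truth = (
        np.repeat(y_true, n_classes) == np.tile(classes, len(y_true))
    ).astype(int)
    print(flat_truth)  # [1 0 0 0 0 1 0 1 0]

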
def optimize_multiclass(
    true_labels: ArrayLike,
    pred_proba: ArrayLike,
    *,
    metric: str = "f1",
    average: str = "macro",
    method: str = "auto",
    sample_weight: ArrayLike | None = None,
    comparison: str = ">",
    tolerance: float = 1e-10,
) -> OptimizationResult:
    """General multiclass threshold optimization with automatic method selection.

    Routes to the appropriate algorithm based on averaging strategy and method:

    - Macro + auto/coord_ascent: Margin rule with coordinate ascent (single-label)
    - Macro + independent: Independent OvR optimization (can predict multiple classes)
    - Micro: Single-threshold optimization (single-label)

    Parameters
    ----------
    true_labels : array-like of shape (n_samples,)
        True class labels in {0, 1, ..., K-1}
    pred_proba : array-like of shape (n_samples, n_classes)
        Predicted probabilities for each class
    metric : str, default="f1"
        Metric to optimize
    average : {"macro", "micro"}, default="macro"
        Averaging strategy
    method : {"auto", "coord_ascent", "independent"}, default="auto"
        Optimization method:

        - "auto": For macro, uses coord_ascent (margin rule) when supported
          (F1 with ">"), otherwise falls back to independent
        - "coord_ascent": Margin rule with coordinate ascent
        - "independent": Independent per-class optimization (OvR)
    sample_weight : array-like of shape (n_samples,), optional
        Sample weights
    comparison : str, default=">"
        Comparison operator
    tolerance : float, default=1e-10
        Numerical tolerance

    Returns
    -------
    OptimizationResult
        Result with optimal thresholds and prediction function

    Examples
    --------
    >>> # Margin rule (single-label, coordinate ascent)
    >>> result = optimize_multiclass(y_true, y_prob, method="coord_ascent")
    >>>
    >>> # Independent optimization (can predict multiple classes)
    >>> result = optimize_multiclass(y_true, y_prob, method="independent")
    >>>
    >>> # Micro averaging (single threshold)
    >>> result = optimize_multiclass(y_true, y_prob, average="micro")
    """
    match average:
        case "micro":
            return optimize_micro_multiclass(
                true_labels,
                pred_proba,
                metric=metric,
                method=method,
                sample_weight=sample_weight,
                comparison=comparison,
                tolerance=tolerance,
            )
        case "macro":
            match method:
                case "auto":
                    # Auto: pick the method based on metric/comparison support.
                    if metric == "f1" and comparison == ">":
                        # F1 with ">" is supported by coordinate ascent; use it
                        # so the per-class thresholds are optimized jointly.
                        return optimize_ovr_margin(
                            true_labels,
                            pred_proba,
                            metric=metric,
                            max_iter=30,
                            sample_weight=sample_weight,
                            comparison=comparison,
                            tolerance=tolerance,
                        )
                    else:
                        # Other metrics/comparisons are not supported by
                        # coord_ascent; fall back to independent optimization.
                        return optimize_ovr_independent(
                            true_labels,
                            pred_proba,
                            metric=metric,
                            method="auto",
                            sample_weight=sample_weight,
                            comparison=comparison,
                            tolerance=tolerance,
                        )
                case "coord_ascent":
                    return optimize_ovr_margin(
                        true_labels,
                        pred_proba,
                        metric=metric,
                        max_iter=30,
                        sample_weight=sample_weight,
                        comparison=comparison,
                        tolerance=tolerance,
                    )
                case "independent" | "minimize" | "unique_scan" | "gradient":
                    # Route legacy binary methods (minimize, unique_scan,
                    # gradient) to independent per-class optimization.
                    return optimize_ovr_independent(
                        true_labels,
                        pred_proba,
                        metric=metric,
                        method="auto",
                        sample_weight=sample_weight,
                        comparison=comparison,
                        tolerance=tolerance,
                    )
                case _:
                    raise ValueError(f"Unknown method for macro averaging: {method}")
        case _:
            raise ValueError(f"Unknown average: {average}. Use 'macro' or 'micro'")


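# --- Illustrative sketch (not part of the library) --------------------------
# The demo below is an assumption added for documentation: it shows how the
# routing in `optimize_multiclass` surfaces in the returned result for
# synthetic data (assuming `metric` is exposed as an attribute of
# OptimizationResult, like `thresholds` and `scores` above).
def _demo_routing() -> None:
    rng = np.random.default_rng(0)
    y_prob = rng.dirichlet(np.ones(3), size=50)   # valid probability rows
    y_true = rng.integers(0, 3, size=50)          # labels in {0, 1, 2}
    macro = optimize_multiclass(y_true, y_prob, average="macro")         # margin rule
    micro = optimize_multiclass(y_true, y_prob, average="micro")         # single threshold
    indep = optimize_multiclass(y_true, y_prob, method="independent")    # per-class OvR
    print(macro.metric, micro.metric, indep.metric)
    # Expected tags: macro_f1_margin_rule, micro_f1, macro_f1_ovr_independent

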
__all__ = [
    "optimize_ovr_independent",
    "optimize_ovr_margin",
    "optimize_micro_multiclass",
    "optimize_multiclass",
]