"""Metric registry, confusion matrix utilities, and built-in metrics."""
from collections.abc import Callable
import numpy as np
from .types import ArrayLike, ComparisonOperator, MetricFunc
from .validation import (
_validate_comparison_operator,
_validate_inputs,
_validate_threshold,
)

METRIC_REGISTRY: dict[str, MetricFunc] = {}
VECTORIZED_REGISTRY: dict[str, Callable] = {} # For vectorized metric functions
METRIC_PROPERTIES: dict[str, dict[str, bool | float]] = {}


def register_metric(
name: str | None = None,
func: MetricFunc | None = None,
vectorized_func: Callable | None = None,
is_piecewise: bool = True,
maximize: bool = True,
needs_proba: bool = False,
) -> MetricFunc | Callable[[MetricFunc], MetricFunc]:
"""Register a metric function with optional vectorized version.
Parameters
----------
name:
Optional key under which to store the metric. If not provided the
function's ``__name__`` is used.
func:
Metric callable accepting ``tp, tn, fp, fn`` scalars and returning a float.
        When supplied, the function is registered immediately. If omitted, the
returned decorator can be used to annotate a metric function.
vectorized_func:
Optional vectorized version of the metric that accepts ``tp, tn, fp, fn``
as arrays and returns an array of scores. Used for O(n log n) optimization.
is_piecewise:
Whether the metric is piecewise-constant with respect to threshold changes.
Piecewise metrics can be optimized using O(n log n) algorithms.
maximize:
Whether the metric should be maximized (True) or minimized (False).
needs_proba:
Whether the metric requires probability scores rather than just thresholds.
        Used for metrics like log-loss or Brier score.

    Returns
-------
MetricFunc | Callable[[MetricFunc], MetricFunc]
The registered function or decorator.
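
    Examples
    --------
    Registering a custom metric via the decorator form (a minimal sketch;
    ``specificity`` is an illustrative name, not one of the built-ins):

    >>> @register_metric("specificity")
    ... def specificity(tp, tn, fp, fn):
    ...     return tn / (tn + fp) if tn + fp > 0 else 0.0
    >>> METRIC_REGISTRY["specificity"](0, 9, 1, 0)
    0.9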
"""
if func is not None:
metric_name = name or func.__name__
METRIC_REGISTRY[metric_name] = func
if vectorized_func is not None:
VECTORIZED_REGISTRY[metric_name] = vectorized_func
METRIC_PROPERTIES[metric_name] = {
"is_piecewise": is_piecewise,
"maximize": maximize,
"needs_proba": needs_proba,
}
return func
def decorator(f: MetricFunc) -> MetricFunc:
metric_name = name or f.__name__
METRIC_REGISTRY[metric_name] = f
if vectorized_func is not None:
VECTORIZED_REGISTRY[metric_name] = vectorized_func
METRIC_PROPERTIES[metric_name] = {
"is_piecewise": is_piecewise,
"maximize": maximize,
"needs_proba": needs_proba,
}
return f
    return decorator


def register_metrics(
metrics: dict[str, MetricFunc],
is_piecewise: bool = True,
maximize: bool = True,
needs_proba: bool = False,
) -> None:
"""Register multiple metric functions.
Parameters
----------
metrics:
Mapping of metric names to callables that accept ``tp, tn, fp, fn``.
is_piecewise:
Whether the metrics are piecewise-constant with respect to threshold changes.
maximize:
Whether the metrics should be maximized (True) or minimized (False).
needs_proba:
        Whether the metrics require probability scores rather than just thresholds.

    Returns
-------
None
        This function mutates the global :data:`METRIC_REGISTRY` and
        :data:`METRIC_PROPERTIES` in place.
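
    Examples
    --------
    A minimal sketch; ``"tpr"`` is an illustrative name, not a built-in:

    >>> register_metrics(
    ...     {"tpr": lambda tp, tn, fp, fn: tp / (tp + fn) if tp + fn > 0 else 0.0}
    ... )
    >>> "tpr" in METRIC_REGISTRY
    True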
"""
METRIC_REGISTRY.update(metrics)
for name in metrics:
METRIC_PROPERTIES[name] = {
"is_piecewise": is_piecewise,
"maximize": maximize,
"needs_proba": needs_proba,
        }


def is_piecewise_metric(metric_name: str) -> bool:
"""Check if a metric is piecewise-constant.
Parameters
----------
metric_name:
        Name of the metric to check.

    Returns
-------
bool
True if the metric is piecewise-constant, False otherwise.
Defaults to True for unknown metrics.
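
    Examples
    --------
    For example, the built-in ``"f1"`` metric is piecewise-constant:

    >>> is_piecewise_metric("f1")
    True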
"""
    return METRIC_PROPERTIES.get(metric_name, {"is_piecewise": True})["is_piecewise"]


def should_maximize_metric(metric_name: str) -> bool:
"""Check if a metric should be maximized.
Parameters
----------
metric_name:
        Name of the metric to check.

    Returns
-------
bool
True if the metric should be maximized, False if minimized.
Defaults to True for unknown metrics.
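
    Examples
    --------
    For example, the built-in ``"f1"`` metric is registered as a maximized metric:

    >>> should_maximize_metric("f1")
    True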
"""
    return METRIC_PROPERTIES.get(metric_name, {"maximize": True})["maximize"]


def needs_probability_scores(metric_name: str) -> bool:
"""Check if a metric needs probability scores rather than just thresholds.
Parameters
----------
metric_name:
        Name of the metric to check.

    Returns
-------
bool
True if the metric needs probability scores, False otherwise.
Defaults to False for unknown metrics.
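
    Examples
    --------
    For example, the built-in ``"f1"`` metric works on thresholded counts:

    >>> needs_probability_scores("f1")
    False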
"""
    return METRIC_PROPERTIES.get(metric_name, {"needs_proba": False})["needs_proba"]


def has_vectorized_implementation(metric_name: str) -> bool:
"""Check if a metric has a vectorized implementation available.
Parameters
----------
metric_name:
        Name of the metric to check.

    Returns
-------
bool
True if the metric has a vectorized implementation, False otherwise.
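
    Examples
    --------
    The built-in ``"f1"`` metric ships with a vectorized implementation
    (``"not_registered"`` is an arbitrary unknown name):

    >>> has_vectorized_implementation("f1")
    True
    >>> has_vectorized_implementation("not_registered")
    False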
"""
    return metric_name in VECTORIZED_REGISTRY


def get_vectorized_metric(metric_name: str) -> Callable:
"""Get vectorized version of a metric function.
Parameters
----------
metric_name:
        Name of the metric.

    Returns
-------
Callable
        Vectorized metric function that accepts arrays.

    Raises
------
ValueError
        If the metric is not available in vectorized form.
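
    Examples
    --------
    A minimal sketch using the built-in vectorized ``"f1"`` implementation
    (exact array formatting depends on the NumPy version):

    >>> import numpy as np
    >>> f1_vec = get_vectorized_metric("f1")
    >>> f1_vec(np.array([1, 2]), np.array([2, 4]), np.array([1, 2]), np.array([1, 2]))
    array([0.5, 0.5])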
"""
if metric_name not in VECTORIZED_REGISTRY:
available = list(VECTORIZED_REGISTRY.keys())
raise ValueError(
f"Vectorized implementation not available for metric '{metric_name}'. "
f"Available: {available}"
)
    return VECTORIZED_REGISTRY[metric_name]


# Vectorized metric implementations for O(n log n) optimization
def _f1_vectorized(
tp: np.ndarray, tn: np.ndarray, fp: np.ndarray, fn: np.ndarray
) -> np.ndarray:
"""Vectorized F1 score computation."""
precision = np.where(tp + fp > 0, tp / (tp + fp), 0.0)
recall = np.where(tp + fn > 0, tp / (tp + fn), 0.0)
return np.where(
precision + recall > 0, 2 * precision * recall / (precision + recall), 0.0
    )


def _accuracy_vectorized(
tp: np.ndarray, tn: np.ndarray, fp: np.ndarray, fn: np.ndarray
) -> np.ndarray:
"""Vectorized accuracy computation."""
total = tp + tn + fp + fn
    return np.where(total > 0, (tp + tn) / total, 0.0)


def _precision_vectorized(
tp: np.ndarray, tn: np.ndarray, fp: np.ndarray, fn: np.ndarray
) -> np.ndarray:
"""Vectorized precision computation."""
    return np.where(tp + fp > 0, tp / (tp + fp), 0.0)


def _recall_vectorized(
tp: np.ndarray, tn: np.ndarray, fp: np.ndarray, fn: np.ndarray
) -> np.ndarray:
"""Vectorized recall computation."""
    return np.where(tp + fn > 0, tp / (tp + fn), 0.0)


@register_metric("f1", vectorized_func=_f1_vectorized)
def f1_score(
tp: int | float, tn: int | float, fp: int | float, fn: int | float
) -> float:
r"""Compute the F\ :sub:`1` score.
Parameters
----------
tp, tn, fp, fn:
        Elements of the confusion matrix.

    Returns
-------
float
The harmonic mean of precision and recall.
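
    Examples
    --------
    For example, with one true positive, one false positive, and one false
    negative, precision and recall are both 0.5:

    >>> f1_score(tp=1, tn=2, fp=1, fn=1)
    0.5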
"""
precision = tp / (tp + fp) if tp + fp > 0 else 0.0
recall = tp / (tp + fn) if tp + fn > 0 else 0.0
return (
2 * precision * recall / (precision + recall)
if (precision + recall) > 0
else 0.0
    )


@register_metric("accuracy", vectorized_func=_accuracy_vectorized)
def accuracy_score(
tp: int | float, tn: int | float, fp: int | float, fn: int | float
) -> float:
"""Compute classification accuracy.
Parameters
----------
tp, tn, fp, fn:
        Elements of the confusion matrix.

    Returns
-------
float
Ratio of correct predictions to total samples.
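
    Examples
    --------
    For example, eight correct predictions out of ten samples:

    >>> accuracy_score(tp=3, tn=5, fp=1, fn=1)
    0.8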
"""
total = tp + tn + fp + fn
    return (tp + tn) / total if total > 0 else 0.0


@register_metric("precision", vectorized_func=_precision_vectorized)
def precision_score(
tp: int | float, tn: int | float, fp: int | float, fn: int | float
) -> float:
"""Compute precision (positive predictive value).
Parameters
----------
tp, tn, fp, fn:
        Elements of the confusion matrix.

    Returns
-------
float
Ratio of true positives to predicted positives.
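
    Examples
    --------
    For example, three true positives out of four predicted positives:

    >>> precision_score(tp=3, tn=0, fp=1, fn=0)
    0.75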
"""
    return tp / (tp + fp) if tp + fp > 0 else 0.0


@register_metric("recall", vectorized_func=_recall_vectorized)
def recall_score(
tp: int | float, tn: int | float, fp: int | float, fn: int | float
) -> float:
"""Compute recall (sensitivity, true positive rate).
Parameters
----------
tp, tn, fp, fn:
        Elements of the confusion matrix.

    Returns
-------
float
Ratio of true positives to actual positives.
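
    Examples
    --------
    For example, one true positive out of four actual positives:

    >>> recall_score(tp=1, tn=0, fp=0, fn=3)
    0.25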
"""
    return tp / (tp + fn) if tp + fn > 0 else 0.0


def multiclass_metric(
confusion_matrices: list[tuple[int | float, int | float, int | float, int | float]],
metric_name: str,
average: str = "macro",
) -> float | np.ndarray:
"""Compute multiclass metrics from per-class confusion matrices.
Parameters
----------
confusion_matrices:
List of per-class confusion matrix tuples ``(tp, tn, fp, fn)``.
metric_name:
Name of the metric to compute (must be in METRIC_REGISTRY).
    average:
        Averaging strategy: "macro", "micro", "weighted", or "none".

        - "macro": unweighted mean of per-class metrics (treats all classes equally)
        - "micro": global metric computed on the pooled confusion matrix
          (treats all samples equally)
        - "weighted": mean of per-class metrics weighted by support
          (number of true instances per class)
        - "none": no averaging; returns an array of per-class metrics

    Returns
-------
float | np.ndarray
Aggregated metric score (float) or per-class scores (array) if average="none".
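
    Examples
    --------
    A minimal sketch with two One-vs-Rest confusion matrices ``(tp, tn, fp, fn)``
    (illustrative counts):

    >>> cms = [(1, 10, 1, 0), (6, 5, 2, 0)]
    >>> multiclass_metric(cms, "precision", average="macro")
    0.625
    >>> multiclass_metric(cms, "precision", average="micro")
    0.7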
"""
if metric_name not in METRIC_REGISTRY:
raise ValueError(f"Unknown metric: {metric_name}")
metric_func = METRIC_REGISTRY[metric_name]
if average == "macro":
# Unweighted mean of per-class scores
scores = [metric_func(*cm) for cm in confusion_matrices]
return float(np.mean(scores))
elif average == "micro":
# For micro averaging, sum only TP, FP, FN
# (not TN which is inflated in One-vs-Rest)
total_tp = sum(cm[0] for cm in confusion_matrices)
total_fp = sum(cm[2] for cm in confusion_matrices)
total_fn = sum(cm[3] for cm in confusion_matrices)
# Compute micro metrics directly
if metric_name == "precision":
return float(
total_tp / (total_tp + total_fp) if total_tp + total_fp > 0 else 0.0
)
elif metric_name == "recall":
return float(
total_tp / (total_tp + total_fn) if total_tp + total_fn > 0 else 0.0
)
elif metric_name == "f1":
precision = (
total_tp / (total_tp + total_fp) if total_tp + total_fp > 0 else 0.0
)
recall = (
total_tp / (total_tp + total_fn) if total_tp + total_fn > 0 else 0.0
)
return float(
2 * precision * recall / (precision + recall)
if (precision + recall) > 0
else 0.0
)
elif metric_name == "accuracy":
            # For accuracy in One-vs-Rest, compute correct predictions / total
            # predictions, i.e. total_tp / (total_tp + total_fp + total_fn).
total_predictions = total_tp + total_fp + total_fn
return float(total_tp / total_predictions if total_predictions > 0 else 0.0)
else:
# Fallback: try using the metric function with computed values
# Note: TN is not meaningful in One-vs-Rest micro averaging
return float(metric_func(total_tp, 0, total_fp, total_fn))
elif average == "weighted":
# Weighted by support (number of true instances for each class)
scores = []
supports = []
for cm in confusion_matrices:
tp, tn, fp, fn = cm
scores.append(metric_func(*cm))
supports.append(tp + fn) # actual positives for this class
total_support = sum(supports)
if total_support == 0:
return 0.0
weighted_score = (
sum(
score * support
for score, support in zip(scores, supports, strict=False)
)
/ total_support
)
return float(weighted_score)
elif average == "none":
# No averaging: return per-class scores
scores = [metric_func(*cm) for cm in confusion_matrices]
return np.array(scores)
else:
raise ValueError(
f"Unknown averaging method: {average}. "
f"Must be one of: 'macro', 'micro', 'weighted', 'none'."
        )


def get_confusion_matrix(
true_labs: ArrayLike,
pred_prob: ArrayLike,
prob: float,
sample_weight: ArrayLike | None = None,
comparison: ComparisonOperator = ">",
) -> tuple[int | float, int | float, int | float, int | float]:
"""Compute confusion-matrix counts for a given threshold.
Parameters
----------
true_labs:
Array of true binary labels in {0, 1}.
pred_prob:
Array of predicted probabilities in [0, 1].
prob:
Decision threshold applied to ``pred_prob``.
sample_weight:
Optional array of sample weights. If None, all samples have equal weight.
    comparison:
        Comparison operator for thresholding: ">" (exclusive) or ">=" (inclusive).

        - ">": ``pred_prob > threshold`` (default; excludes ties)
        - ">=": ``pred_prob >= threshold`` (includes ties)

    Returns
-------
tuple[int | float, int | float, int | float, int | float]
Counts ``(tp, tn, fp, fn)``. Returns int when sample_weight is None,
float when sample_weight is provided to preserve fractional weighted counts.
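
    Examples
    --------
    A minimal sketch (assuming plain lists pass input validation):

    >>> get_confusion_matrix([0, 1, 1, 0], [0.2, 0.8, 0.4, 0.9], 0.5)
    (1, 1, 1, 1)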
"""
# Validate inputs
true_labs, pred_prob, sample_weight = _validate_inputs(
true_labs,
pred_prob,
require_binary=True,
sample_weight=sample_weight,
allow_multiclass=False,
)
_validate_threshold(float(prob))
_validate_comparison_operator(comparison)
# Apply threshold with specified comparison operator
if comparison == ">":
pred_labs = pred_prob > prob
else: # ">="
pred_labs = pred_prob >= prob
if sample_weight is None:
tp = np.sum(np.logical_and(pred_labs == 1, true_labs == 1))
tn = np.sum(np.logical_and(pred_labs == 0, true_labs == 0))
fp = np.sum(np.logical_and(pred_labs == 1, true_labs == 0))
fn = np.sum(np.logical_and(pred_labs == 0, true_labs == 1))
return int(tp), int(tn), int(fp), int(fn)
else:
sample_weight = np.asarray(sample_weight)
if len(sample_weight) != len(true_labs):
raise ValueError(
f"Length mismatch: sample_weight ({len(sample_weight)}) "
f"vs true_labs ({len(true_labs)})"
)
tp = np.sum(sample_weight * np.logical_and(pred_labs == 1, true_labs == 1))
tn = np.sum(sample_weight * np.logical_and(pred_labs == 0, true_labs == 0))
fp = np.sum(sample_weight * np.logical_and(pred_labs == 1, true_labs == 0))
fn = np.sum(sample_weight * np.logical_and(pred_labs == 0, true_labs == 1))
# Return float values when using sample weights to preserve fractional counts
        return float(tp), float(tn), float(fp), float(fn)


def get_multiclass_confusion_matrix(
true_labs: ArrayLike,
pred_prob: ArrayLike,
thresholds: ArrayLike,
sample_weight: ArrayLike | None = None,
comparison: ComparisonOperator = ">",
) -> list[tuple[int | float, int | float, int | float, int | float]]:
"""Compute per-class confusion-matrix counts for multiclass classification
using One-vs-Rest.
Parameters
----------
true_labs:
Array of true class labels (0, 1, 2, ..., n_classes-1).
pred_prob:
Array of predicted probabilities with shape (n_samples, n_classes).
thresholds:
Array of decision thresholds, one per class.
sample_weight:
Optional array of sample weights. If None, all samples have equal weight.
comparison:
        Comparison operator for thresholding: ">" (exclusive) or ">=" (inclusive).

    Returns
-------
list[tuple[int | float, int | float, int | float, int | float]]
List of per-class counts ``(tp, tn, fp, fn)`` for each class.
Returns int when sample_weight is None, float when sample_weight is provided.
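
    Examples
    --------
    A minimal sketch with three classes (illustrative probabilities, assuming
    the inputs pass validation):

    >>> import numpy as np
    >>> y = np.array([0, 1, 2, 1])
    >>> p = np.array([
    ...     [0.8, 0.1, 0.1],
    ...     [0.2, 0.6, 0.2],
    ...     [0.1, 0.2, 0.7],
    ...     [0.3, 0.5, 0.2],
    ... ])
    >>> get_multiclass_confusion_matrix(y, p, thresholds=[0.5, 0.5, 0.5])
    [(1, 3, 0, 0), (1, 2, 0, 1), (1, 3, 0, 0)]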
"""
# Validate inputs
true_labs, pred_prob, sample_weight = _validate_inputs(
true_labs, pred_prob, sample_weight=sample_weight
)
_validate_comparison_operator(comparison)
if pred_prob.ndim == 1:
# Binary case - backward compatibility
thresholds = np.asarray(thresholds)
_validate_threshold(thresholds[0])
return [
get_confusion_matrix(
true_labs, pred_prob, thresholds[0], sample_weight, comparison
)
]
# Multiclass case
n_classes = pred_prob.shape[1]
thresholds = np.asarray(thresholds)
_validate_threshold(thresholds, n_classes)
confusion_matrices = []
for class_idx in range(n_classes):
# One-vs-Rest: current class vs all others
true_binary = (true_labs == class_idx).astype(int)
pred_binary_prob = pred_prob[:, class_idx]
threshold = thresholds[class_idx]
cm = get_confusion_matrix(
true_binary, pred_binary_prob, threshold, sample_weight, comparison
)
confusion_matrices.append(cm)
return confusion_matrices