"""
Robust rank-preserving multiclass probability calibration.
This module provides numerically stable implementations of rank-preserving
calibration algorithms including Dykstra's alternating projections and ADMM.
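
Quick start (illustrative; the import path mirrors the examples in the
function docstrings below):

    >>> import numpy as np
    >>> from rank_preserving_calibration import calibrate_dykstra
    >>> rng = np.random.default_rng(0)
    >>> P = rng.dirichlet(np.ones(3), size=10)
    >>> result = calibrate_dykstra(P, P.sum(axis=0))  # targets = current column sums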
"""
from __future__ import annotations
import logging
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
import numpy as np
from ._numba_utils import get_jit_functions
from .nearly import (
project_near_isotonic_euclidean, # epsilon-slack, with exact sum shift
    prox_near_isotonic,  # lambda-penalty (exact prox when the backing implementation provides it)
)
type NDArrayFloat = np.ndarray[Any, np.dtype[np.floating[Any]]]
type ColumnOrders = list[np.ndarray]
type CallbackFunction = Callable[[int, float, np.ndarray], bool] | None
_jit_funcs = get_jit_functions()
# Set up logging
logger = logging.getLogger(__name__)
def _configure_logging(verbose: bool) -> None:
"""Configure logging level based on verbosity setting."""
if verbose:
logger.setLevel(logging.DEBUG)
# Ensure handler exists and is configured
if not logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False
else:
logger.setLevel(logging.WARNING)
# Remove any existing handlers when verbose=False to prevent leakage
for handler in logger.handlers[:]:
logger.removeHandler(handler)
logger.propagate = True # Let parent loggers handle output
# ---------------------------------------------------------------------
# Data containers
# ---------------------------------------------------------------------
@dataclass(slots=True)
class CalibrationResult:
"""Result container for rank-preserving calibration algorithms.
Returned by calibrate_dykstra() containing the calibrated probability matrix
and diagnostic information about convergence and constraint satisfaction.
Attributes:
Q: Calibrated probability matrix of shape (N, J) where rows sum to 1
and columns preserve rank ordering from original scores.
converged: True if algorithm converged within specified tolerance.
iterations: Number of iterations performed before termination.
max_row_error: Maximum absolute error in row sum constraint (should be ≈0).
max_col_error: Maximum absolute error in column sum constraint.
max_rank_violation: Maximum rank-order violation across all columns.
final_change: Final relative change in solution between iterations.
Examples:
>>> result = calibrate_dykstra(P, M)
>>> if result.converged:
... print(f"Calibration successful in {result.iterations} iterations")
>>> print(f"Max rank violation: {result.max_rank_violation:.6f}")
"""
Q: np.ndarray
converged: bool
iterations: int
max_row_error: float
max_col_error: float
max_rank_violation: float
final_change: float
@dataclass(slots=True)
class ADMMResult:
"""Result from ADMM optimization.
    Attributes:
        Q: Calibrated probability matrix.
        converged: Whether ADMM converged.
        iterations: Number of iterations performed.
        objective_values: Objective function values over iterations.
        primal_residuals: Primal residual norms over iterations.
        dual_residuals: Dual residual norms over iterations.
        max_row_error: Maximum row sum error.
        max_col_error: Maximum column sum error.
        max_rank_violation: Maximum rank violation.
        final_change: Final relative change between iterations.
"""
Q: np.ndarray
converged: bool
iterations: int
objective_values: list[float]
primal_residuals: list[float]
dual_residuals: list[float]
max_row_error: float
max_col_error: float
max_rank_violation: float
final_change: float
class CalibrationError(Exception):
"""Raised when calibration fails due to invalid inputs or numerical issues."""
# ---------------------------------------------------------------------
# Validation
# ---------------------------------------------------------------------
def _validate_inputs(
P: np.ndarray, M: np.ndarray, max_iters: int, tol: float, feasibility_tol: float
) -> tuple[int, int]:
"""Validate all inputs to calibration functions.
Args:
P: Input probability matrix to validate.
M: Target column sums to validate.
max_iters: Maximum iterations parameter to validate.
tol: Tolerance parameter to validate.
feasibility_tol: Feasibility tolerance to validate.
Returns:
Tuple of (N, J) where N is number of instances and J is number of classes.
Raises:
CalibrationError: If any input validation fails.
"""
# Validate P using match statements
match P:
case x if not isinstance(x, np.ndarray):
raise CalibrationError("P must be a numpy array")
case x if x.ndim != 2:
raise CalibrationError("P must be a 2D array of shape (N, J)")
case x if x.size == 0:
raise CalibrationError("P cannot be empty")
case x if not np.isfinite(x).all():
raise CalibrationError("P must not contain NaN or infinite values")
case x if np.any(x < 0):
raise CalibrationError("P must contain non-negative values")
N, J = P.shape
match J:
case j if j < 2:
raise CalibrationError("P must have at least 2 columns (classes)")
# Validate M using match statements
match M:
case x if not isinstance(x, np.ndarray):
raise CalibrationError("M must be a numpy array")
case x if x.ndim != 1:
raise CalibrationError("M must be a 1D array")
case x if x.size != J:
raise CalibrationError(f"M must have length {J} to match P.shape[1]")
case x if not np.isfinite(x).all():
raise CalibrationError("M must not contain NaN or infinite values")
case x if np.any(x < 0):
raise CalibrationError("M must contain non-negative values")
# Check basic feasibility (soft warning)
M_sum = float(M.sum())
match abs(M_sum - N):
case diff if diff > feasibility_tol * N:
warnings.warn(
f"Sum of M ({M_sum:.3f}) differs from N ({N}) by "
f"{diff:.3f}. Problem may be infeasible.",
UserWarning,
stacklevel=2,
)
# Validate other parameters using match statements
match max_iters:
case x if not isinstance(x, int) or x <= 0:
raise CalibrationError("max_iters must be a positive integer")
match tol:
case x if not isinstance(x, int | float) or x <= 0:
raise CalibrationError("tol must be a positive number")
match feasibility_tol:
case x if not isinstance(x, int | float) or x < 0:
raise CalibrationError("feasibility_tol must be non-negative")
return N, J
# ---------------------------------------------------------------------
# Core projections
# ---------------------------------------------------------------------
def _project_row_simplex(
rows: np.ndarray, eps: float = 1e-15, use_jit: bool = True
) -> np.ndarray:
"""Project rows onto the probability simplex with numerical stability.
Projects each row of the matrix onto the probability simplex, ensuring
non-negative entries that sum to 1. Uses Euclidean projection algorithm
with numerical stability improvements.
Args:
rows: Matrix of shape (N, J) where each row will be projected.
eps: Small tolerance for numerical stability in computations.
use_jit: Whether to use JIT-compiled version if available.
Returns:
Projected matrix with same shape, where each row sums to 1 and is non-negative.
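
    Examples:
        Sanity check of the pure-Python fallback (illustrative):

        >>> rows = np.array([[0.2, 0.3, 0.9]])
        >>> round(float(_project_row_simplex(rows, use_jit=False).sum()), 12)
        1.0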
"""
# Use JIT version if available and requested
if (
use_jit
and _jit_funcs["available"]
and _jit_funcs["project_row_simplex"] is not None
):
return _jit_funcs["project_row_simplex"](rows, eps)
# Fallback to pure Python implementation
N, J = rows.shape
projected = np.empty_like(rows, dtype=np.float64)
for i in range(N):
v = rows[i]
u = np.sort(v)[::-1]
cssv = np.cumsum(u) - 1.0
ind = np.arange(1, J + 1, dtype=np.float64)
cond = u - cssv / ind > eps
rho = np.nonzero(cond)[0][-1] if np.any(cond) else (J - 1)
theta = cssv[rho] / (rho + 1)
w = np.maximum(v - theta, 0.0)
# Normalize defensively to exactly sum to 1
sum_w = w.sum()
if sum_w > eps:
w /= sum_w
else:
w[:] = 1.0 / J
projected[i] = w
return projected
# ---------- Isotonic (PAV) -------------------------------------------
def _isotonic_regression(
y: np.ndarray,
rtol: float = 0.0,
ties: str = "stable",
weights: np.ndarray | None = None,
) -> np.ndarray:
"""
Isotonic regression (nondecreasing) via stack-based Pool Adjacent Violators in O(n).
Strict by default (rtol=0.0) to avoid micro-violations in tests.
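
    Examples:
        PAV pools the violating pair (3, 1) to their mean 2, which does not
        violate against the final value, so the output is flat (illustrative):

        >>> _isotonic_regression(np.array([3.0, 1.0, 2.0])).tolist()
        [2.0, 2.0, 2.0]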
"""
def _tol(a: float, b: float) -> float:
return rtol * (abs(a) + abs(b) + 1.0)
y = np.asarray(y, dtype=np.float64)
n = y.size
if n <= 1:
return y.copy()
if rtol < 0:
raise ValueError("rtol must be nonnegative.")
if ties not in ("stable", "group"):
raise ValueError("ties must be 'stable' or 'group'.")
if weights is None:
w_init = np.ones(n, dtype=np.int64)
else:
w_init = np.asarray(weights)
if w_init.shape != y.shape:
raise ValueError("weights must have the same shape as y")
if np.any(w_init <= 0):
raise ValueError("weights must be positive")
if np.all(np.isclose(w_init, np.round(w_init))):
w_init = np.round(w_init).astype(np.int64)
else:
w_init = w_init.astype(np.float64)
    # Optional pre-pooling of *contiguous exact equals in y*
    if ties == "group":
        vals: list[float] = []
        wts: list[float] = []
        runs: list[int] = []
        i = 0
        while i < n:
            j = i + 1
            vi = y[i]
            total_w = float(w_init[i])
            while j < n and y[j] == vi:
                total_w += float(w_init[j])
                j += 1
            vals.append(float(vi))
            wts.append(total_w)
            runs.append(j - i)
            i = j
        a = np.asarray(vals, dtype=np.float64)
        w0 = np.asarray(wts, dtype=np.float64)
        c0 = np.asarray(runs, dtype=np.int64)  # elements per pooled group
    else:
        a = y
        w0 = np.asarray(w_init, dtype=np.float64)
        c0 = np.ones(n, dtype=np.int64)  # one element per block
    m = a.size
    # Block stacks
    start = np.empty(m, dtype=np.int64)  # start index in expanded output
    mean = np.empty(m, dtype=np.float64)  # block mean
    wsum = np.empty(m, dtype=np.float64)  # block weight
    top = -1
    idx = 0  # running start position in the expanded output
    for i in range(m):
        # push new block
        top += 1
        start[top] = idx
        mean[top] = a[i]
        wsum[top] = w0[i]
        # advance by element count (not weight), so non-unit weights do not
        # corrupt the expansion back to the original length
        idx += int(c0[i])
        # merge backward while violated beyond tolerance
        while top > 0 and mean[top - 1] > mean[top] + _tol(mean[top - 1], mean[top]):
            w1 = wsum[top - 1]
            w2 = wsum[top]
            mean[top - 1] = (w1 * mean[top - 1] + w2 * mean[top]) / (w1 + w2)
            wsum[top - 1] = w1 + w2
            top -= 1
    # expand pooled block means back to full length
    out = np.empty(n, dtype=np.float64)
    for j in range(top + 1):
        s = start[j]
        e = start[j + 1] if j < top else n
        out[s:e] = mean[j]
    return out
# ---------- Column projection (exact) --------------------------------
def _run_lengths_of_equals(x_sorted: np.ndarray) -> np.ndarray:
"""Run-lengths of contiguous exact equals in an already-sorted array."""
n = x_sorted.size
if n == 0:
return np.zeros(0, dtype=np.int64)
lens: list[int] = []
cnt = 1
for i in range(1, n):
if x_sorted[i] == x_sorted[i - 1]:
cnt += 1
else:
lens.append(cnt)
cnt = 1
lens.append(cnt)
return np.asarray(lens, dtype=np.int64)
def _project_column_isotonic_sum(
column: np.ndarray,
column_order: np.ndarray,
target_sum: float,
*,
rtol: float = 0.0,
nearly: dict | None = None,
ties: str = "stable",
score_sorted: np.ndarray | None = None,
) -> np.ndarray:
"""Project one column onto isotonic (by model-score order) with a fixed sum.
Exact Euclidean projection: PAV, then a *uniform additive shift* c so the
column sums to `target_sum`. If `ties=="group"` and `score_sorted` is given,
we pre-pool equal-score runs, run weighted PAV, add the uniform shift, then expand.
"""
if column.size == 0:
return column.copy()
y = column[column_order]
# Nearly-isotonic mode selection using match
match nearly:
case {"mode": "epsilon", **rest}:
eps = float(rest.get("eps", 1e-3))
iso_shifted = project_near_isotonic_euclidean(y, eps, sum_target=target_sum)
        case _:
if ties == "group" and score_sorted is not None:
lens = _run_lengths_of_equals(score_sorted)
k = lens.size
y_group = np.empty(k, dtype=np.float64)
pos = 0
for g in range(k):
L = int(lens[g])
y_group[g] = float(np.mean(y[pos : pos + L]))
pos += L
z_group = _isotonic_regression(
y_group, rtol=rtol, ties="stable", weights=lens.astype(np.float64)
)
total_n = int(lens.sum())
c = (float(target_sum) - float(np.dot(z_group, lens))) / float(total_n)
z_group_shift = z_group + c
iso_shifted = np.repeat(z_group_shift, lens)
else:
iso = _isotonic_regression(y, rtol=rtol, ties=ties)
c = (float(target_sum) - float(iso.sum())) / float(iso.size)
iso_shifted = iso + c
projected = np.empty_like(column, dtype=np.float64)
projected[column_order] = iso_shifted
return projected
# ---------------------------------------------------------------------
# Diagnostics & helpers
# ---------------------------------------------------------------------
def _compute_rank_violation(Q: np.ndarray, P: np.ndarray) -> float:
"""Compute maximum rank violation across all columns (w.r.t. original scores)."""
max_violation = 0.0
_, J = Q.shape
for j in range(J):
        idx = np.argsort(P[:, j], kind="mergesort")  # stable order by original model scores
q_sorted = Q[idx, j]
if q_sorted.size > 1:
diffs = np.diff(q_sorted)
violation = float(np.max(np.maximum(0.0, -diffs)))
max_violation = max(max_violation, violation)
return max_violation
def _detect_cycling(
Q_history: list[NDArrayFloat], Q: NDArrayFloat, cycle_tol: float = 1e-12
) -> bool:
"""Very conservative cycle detection (usually disabled)."""
matches = 0
for prev_Q in Q_history:
if np.allclose(Q, prev_Q, rtol=cycle_tol, atol=cycle_tol):
matches += 1
if matches >= 2:
return True
return False
def _polish_to_intersection(
Q: np.ndarray,
M: np.ndarray,
column_orders: ColumnOrders,
*,
rtol: float = 0.0,
ties: str = "stable",
score_sorted: list[np.ndarray | None] | None = None,
max_iters: int = 200,
row_atol: float = 1e-12,
col_atol: float = 1e-10,
) -> np.ndarray:
"""Small alternating-projection polish to hit constraints to machine tolerance."""
_, J = Q.shape
if score_sorted is None:
score_sorted = [None] * J
for _ in range(max_iters):
Q = _project_row_simplex(Q)
for j in range(J):
Q[:, j] = _project_column_isotonic_sum(
Q[:, j],
column_orders[j],
float(M[j]),
rtol=rtol,
ties=ties,
score_sorted=score_sorted[j] if score_sorted else None,
)
if np.allclose(Q.sum(axis=1), 1.0, atol=row_atol) and np.allclose(
Q.sum(axis=0), M, atol=col_atol
):
break
return Q
# ---------------------------------------------------------------------
# Dykstra calibration (exact projections onto each set)
# ---------------------------------------------------------------------
def calibrate_dykstra(
P: np.ndarray,
M: np.ndarray,
max_iters: int = 3000,
tol: float = 1e-7,
rtol: float = 0.0, # strict isotone by default
feasibility_tol: float = 0.1,
verbose: bool = False,
callback: CallbackFunction = None,
detect_cycles: bool = False, # default off for determinism
cycle_window: int = 10,
nearly: dict | None = None,
ties: str = "stable",
use_jit: bool = True,
) -> CalibrationResult:
"""Calibrate using Dykstra's alternating projections.
Projects multiclass probabilities onto the intersection of:
(A) row simplex: {rows ≥ 0, rows sum to 1} and
(B) column-wise isotone-by-score + fixed column sums: {nondecreasing in score order; column sum = M_j}.
This is the recommended default method for rank-preserving calibration. The algorithm
uses exact Euclidean projections via Pool Adjacent Violators (PAV) followed by uniform
shifts to satisfy sum constraints.
Args:
P: Input probability matrix of shape (N, J). Each row represents predicted class
probabilities for one instance. Rows need not sum to 1 initially.
M: Target column sums of shape (J,). Should sum to approximately N for feasibility.
max_iters: Maximum number of iterations. Default 3000 is usually sufficient.
tol: Convergence tolerance for relative change in solution. Default 1e-7.
rtol: Relative tolerance for isotonic violations in PAV. Default 0.0 (strict).
feasibility_tol: Tolerance for feasibility warnings when sum(M) differs from N.
verbose: If True, enables debug logging.
callback: Optional function called each iteration as callback(iter, change, Q).
Should return False to terminate early.
detect_cycles: If True, detects and breaks cycles in the solution sequence.
cycle_window: Number of iterations to look back for cycle detection.
nearly: Optional dict for nearly-isotonic constraints. Use {"mode": "epsilon", "eps": 0.01}
to allow small isotonicity violations.
ties: How to handle tied scores. "stable" preserves input order, "group" pools
equal-score instances.
use_jit: If True and numba is available, uses JIT-compiled functions for speed.
Returns:
CalibrationResult object containing:
- Q: Calibrated probability matrix of shape (N, J)
- converged: Always True (failures raise CalibrationError instead)
- iterations: Number of iterations performed
- max_row_error: Maximum absolute row sum error
- max_col_error: Maximum absolute column sum error
- max_rank_violation: Maximum rank order violation
- final_change: Final relative change in solution
Raises:
CalibrationError: If inputs are invalid, algorithm fails to converge, or other errors occur.
ValueError: If ties parameter is not "stable" or "group"
Examples:
Basic calibration:
>>> import numpy as np
>>> from rank_preserving_calibration import calibrate_dykstra
>>> P = np.array([[0.7, 0.2, 0.1], [0.3, 0.5, 0.2]])
>>> M = np.array([1.0, 0.7, 0.3]) # Target column sums
>>> result = calibrate_dykstra(P, M)
>>> print(f"Converged: {result.converged}")
>>> print(f"Row sums: {result.Q.sum(axis=1)}")
>>> print(f"Column sums: {result.Q.sum(axis=0)}")
With nearly-isotonic constraints:
>>> result = calibrate_dykstra(P, M, nearly={"mode": "epsilon", "eps": 0.05})
Notes:
- Converges to the exact intersection of the constraint sets
- Preserves ranking within each class (column) by original model scores
- Memory complexity is O(N*J) for the probability matrices
    - Time complexity per iteration is O(N*J*log(J)) from the row-simplex sorts;
      column score orders are precomputed once at O(J*N*log(N))
- For best performance, ensure sum(M) ≈ N and use numba if available
- Raises CalibrationError on convergence failure instead of returning unreliable results
"""
_configure_logging(verbose)
_, J = _validate_inputs(P, M, max_iters, tol, feasibility_tol)
if ties not in ("stable", "group"):
raise ValueError(f"ties must be 'stable' or 'group', got '{ties}'")
P = np.asarray(P, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
# Initialize Dykstra variables
Q = P.copy()
U = np.zeros_like(P, dtype=np.float64) # row memory
V = np.zeros_like(P, dtype=np.float64) # col memory
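    # Dykstra's corrections: before each projection, the residual (U or V)
    # remembered from the previous pass through that constraint set is added
    # back. Unlike plain alternating projections, this makes the iterates
    # converge to the Euclidean projection of P onto the intersection rather
    # than to an arbitrary feasible point.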
Q_prev = np.empty_like(Q)
# Precompute column orders once (stable)
column_orders = [np.argsort(P[:, j], kind="mergesort") for j in range(J)]
score_sorted: list[np.ndarray | None] = (
[P[ord_j, j] for j, ord_j in enumerate(column_orders)]
if ties == "group"
else [None] * J
)
Q_history: list[NDArrayFloat] | None = [] if detect_cycles else None
converged = False
final_change = float("inf")
for iteration in range(1, max_iters + 1):
np.copyto(Q_prev, Q)
# Project onto row simplex
Y = Q + U
Q = _project_row_simplex(Y, use_jit=use_jit)
U = Y - Q
# Project onto column constraints
Y = Q + V
for j in range(J):
Q[:, j] = _project_column_isotonic_sum(
Y[:, j],
column_orders[j],
float(M[j]),
rtol=rtol,
nearly=nearly,
ties=ties,
score_sorted=score_sorted[j],
)
V = Y - Q
# Convergence check (relative change + feasibility)
change_abs = np.linalg.norm(Q - Q_prev)
norm_Q_prev = np.linalg.norm(Q_prev)
final_change = (
float(change_abs / norm_Q_prev) if norm_Q_prev > 0 else float(change_abs)
)
row_ok = np.allclose(Q.sum(axis=1), 1.0, atol=1e-12)
col_ok = np.allclose(Q.sum(axis=0), M, atol=1e-10)
if final_change < tol and row_ok and col_ok:
converged = True
logger.info(f"Dykstra converged at iteration {iteration}")
break
# Cycle detection (optional)
if detect_cycles and iteration > cycle_window and Q_history is not None:
if _detect_cycling(Q_history, Q):
warnings.warn(
f"Cycling detected at iteration {iteration}",
UserWarning,
stacklevel=2,
)
break
Q_history.append(Q.copy())
if len(Q_history) > cycle_window:
Q_history.pop(0)
if iteration % 100 == 0 or iteration <= 10:
logger.debug(f"Dykstra iteration {iteration}: change = {final_change:.2e}")
if callback is not None and not callback(iteration, final_change, Q):
break
# If not strictly feasible, polish to the intersection
if not (
np.allclose(Q.sum(axis=1), 1.0, atol=1e-12)
and np.allclose(Q.sum(axis=0), M, atol=1e-10)
):
Q = _polish_to_intersection(
Q,
M,
column_orders,
rtol=rtol,
ties=ties,
score_sorted=score_sorted,
max_iters=100,
)
# If now feasible, count as converged for reporting
if np.allclose(Q.sum(axis=1), 1.0, atol=1e-12) and np.allclose(
Q.sum(axis=0), M, atol=1e-10
):
converged = True
# Diagnostics
row_sums = Q.sum(axis=1)
col_sums = Q.sum(axis=0)
max_row_error = float(np.max(np.abs(row_sums - 1.0)))
max_col_error = float(np.max(np.abs(col_sums - M)))
max_rank_violation = _compute_rank_violation(Q, P)
# Fail fast on non-convergence instead of returning unreliable results
if not converged:
raise CalibrationError(
f"Calibration failed to converge after {iteration} iterations. "
f"Final change: {final_change:.2e} (tolerance: {tol:.2e}). "
f"Max row error: {max_row_error:.2e}, max col error: {max_col_error:.2e}. "
f"Try: increasing max_iters, relaxing tol, using nearly-isotonic constraints "
f"(nearly={{'mode': 'epsilon', 'eps': 0.01}}), or consider temperature scaling."
)
return CalibrationResult(
Q=Q,
converged=converged,
iterations=iteration,
max_row_error=max_row_error,
max_col_error=max_col_error,
max_rank_violation=max_rank_violation,
final_change=final_change,
)
# ---------------------------------------------------------------------
# ADMM calibration (penalty-based, snaps to exact projection)
# ---------------------------------------------------------------------
def calibrate_admm(
P: np.ndarray,
M: np.ndarray,
rho: float = 1.0,
max_iters: int = 1000,
tol: float = 1e-6,
rtol: float = 0.0,
feasibility_tol: float = 0.1,
verbose: bool = False,
nearly: dict | None = None,
ties: str = "stable",
use_jit: bool = True,
) -> ADMMResult:
"""Calibrate using ADMM-style optimization with penalty methods.
An alternative to Dykstra's projections that handles row/column sum constraints via
Lagrange multipliers and rank-preservation through either strict isotonic regression
or lambda-penalty nearly-isotonic proximal operators.
The algorithm minimizes ||Q - P||² subject to constraint sets using an augmented
Lagrangian approach. For final optimality verification, the solution is snapped
to the exact intersection using a short Dykstra polish.
Args:
P: Input probability matrix of shape (N, J). Each row represents predicted class
probabilities for one instance. Rows need not sum to 1 initially.
M: Target column sums of shape (J,). Should sum to approximately N for feasibility.
rho: ADMM penalty parameter. Larger values enforce constraints more aggressively.
Default 1.0 works well for most problems.
max_iters: Maximum number of iterations. Default 1000 is usually sufficient.
tol: Convergence tolerance for primal/dual residuals. Default 1e-6.
rtol: Relative tolerance for isotonic violations in PAV. Default 0.0 (strict).
feasibility_tol: Tolerance for feasibility warnings when sum(M) differs from N.
verbose: If True, enables debug logging.
nearly: Optional dict for nearly-isotonic constraints. Use {"mode": "lambda", "lam": 1.0}
for lambda-penalty approach allowing soft isotonicity violations.
ties: How to handle tied scores. "stable" preserves input order, "group" pools
equal-score instances.
use_jit: If True and numba is available, uses JIT-compiled functions for speed.
Returns:
ADMMResult object containing:
- Q: Calibrated probability matrix of shape (N, J)
- converged: Always True (failures raise CalibrationError instead)
- iterations: Number of iterations performed
- max_row_error: Maximum absolute row sum error
- max_col_error: Maximum absolute column sum error
- max_rank_violation: Maximum rank order violation
- final_change: Final relative change in solution
- objective_values: List of objective function values per iteration
- primal_residuals: List of primal residual norms per iteration
- dual_residuals: List of dual residual norms per iteration
Raises:
CalibrationError: If inputs are invalid, algorithm fails to converge, or other errors occur.
ValueError: If ties parameter is not "stable" or "group"
Examples:
Basic ADMM calibration:
>>> import numpy as np
>>> from rank_preserving_calibration import calibrate_admm
>>> P = np.array([[0.7, 0.2, 0.1], [0.3, 0.5, 0.2]])
>>> M = np.array([1.0, 0.7, 0.3])
>>> result = calibrate_admm(P, M)
>>> print(f"Converged: {result.converged}")
>>> print(f"Objective values: {result.objective_values[-5:]}")
With lambda-penalty for soft isotonicity:
>>> result = calibrate_admm(P, M, nearly={"mode": "lambda", "lam": 2.0})
Adjusting penalty parameter:
>>> result = calibrate_admm(P, M, rho=5.0) # Stronger constraint enforcement
Notes:
- Often converges faster than Dykstra for well-conditioned problems
- Provides convergence diagnostics via objective and residual histories
- Lambda-penalty mode allows trading off isotonicity for fit quality
- Final solution is snapped to exact feasible set for optimality
- Experimental: may need parameter tuning for difficult problems
"""
_configure_logging(verbose)
N, J = _validate_inputs(P, M, max_iters, tol, feasibility_tol)
if ties not in ("stable", "group"):
raise ValueError(f"ties must be 'stable' or 'group', got '{ties}'")
P = np.asarray(P, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
# Precompute column orders once (stable)
column_orders = [np.argsort(P[:, j], kind="mergesort") for j in range(J)]
score_sorted: list[np.ndarray | None] = (
[P[ord_j, j] for j, ord_j in enumerate(column_orders)]
if ties == "group"
else [None] * J
)
# Initialize ADMM variables
Q = P.copy()
Z1 = np.ones(N) # row-sum auxiliaries
Z2 = M.copy() # col-sum auxiliaries
lambda1 = np.zeros(N) # row multipliers
lambda2 = np.zeros(J) # col multipliers
objective_values: list[float] = []
primal_residuals: list[float] = []
dual_residuals: list[float] = []
# Initialize lambda penalty using match
match nearly:
case {"mode": "lambda", **rest}:
lam_pen = float(rest.get("lam", 1.0))
case _:
lam_pen = None
converged = False
iteration = 0
for iteration in range(max_iters):
Q_prev = Q.copy()
# Q-update: quadratic + linear equality terms
row_correction = (Z1 - lambda1 / rho).reshape(-1, 1)
col_correction = (Z2 - lambda2 / rho).reshape(1, -1)
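        # Closed-form elementwise update: each entry q = Q_unconstrained[i, j]
        # minimizes the decoupled surrogate
        #   0.5*(q - p)^2
        #   + (rho/2)*(q - (Z1_i - lambda1_i/rho))^2
        #   + (rho/2)*(q - (Z2_j - lambda2_j/rho))^2,
        # whose stationary point is q = (p + rho*(row + col)) / (1 + 2*rho).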
Q_unconstrained = (P + rho * (row_correction + col_correction)) / (
1.0 + 2.0 * rho
)
# Rank-preserving + nonnegativity
if lam_pen is not None:
for j in range(J):
idx = column_orders[j]
v_sorted = Q_unconstrained[idx, j]
z = prox_near_isotonic(v_sorted, lam_pen)
if isinstance(z, tuple): # safety with return_info variants
z = z[0]
Q_unconstrained[idx, j] = z
        else:
            for j in range(J):
                idx = column_orders[j]
                # Advanced indexing returns a copy, so the projected values
                # must be written back into Q_unconstrained explicitly.
                v_sorted = Q_unconstrained[idx, j]
                score_j = score_sorted[j]
                if ties == "group" and score_j is not None:
                    lens = _run_lengths_of_equals(score_j)
                    y_group = np.empty(lens.size, dtype=np.float64)
                    pos = 0
                    for g, length in enumerate(lens):
                        y_group[g] = float(np.mean(v_sorted[pos : pos + length]))
                        pos += int(length)
                    z_group = _isotonic_regression(
                        y_group,
                        rtol=rtol,
                        ties="stable",
                        weights=lens.astype(np.float64),
                    )
                    Q_unconstrained[idx, j] = np.repeat(z_group, lens)
                else:
                    Q_unconstrained[idx, j] = _isotonic_regression(
                        v_sorted, rtol=rtol, ties="stable"
                    )
Q = np.maximum(Q_unconstrained, 0.0)
# Z-updates (hard equality constraints)
row_sums = Q.sum(axis=1)
col_sums = Q.sum(axis=0)
Z1_prev = Z1.copy()
Z2_prev = Z2.copy()
Z1 = np.ones(N) # enforce row sums = 1
Z2 = M.copy() # enforce col sums = M
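        # Z1/Z2 are hard equality constraints and thus constant, so the dual
        # residual below is identically zero; convergence is effectively
        # governed by the primal residual alone.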
# Multiplier updates
lambda1 += rho * (row_sums - Z1)
lambda2 += rho * (col_sums - Z2)
# Residuals & objective (include λ-penalty term if used)
primal_res = np.linalg.norm(np.concatenate([row_sums - Z1, col_sums - Z2]))
dual_res = rho * (np.linalg.norm(Z1 - Z1_prev) + np.linalg.norm(Z2 - Z2_prev))
obj_val = 0.5 * np.linalg.norm(Q - P) ** 2
if lam_pen is not None:
pen = 0.0
for j in range(J):
idx = column_orders[j]
qj = Q[idx, j]
pen += float(np.maximum(qj[:-1] - qj[1:], 0.0).sum())
obj_val += lam_pen * pen
objective_values.append(float(obj_val))
primal_residuals.append(float(primal_res))
dual_residuals.append(float(dual_res))
if iteration % 100 == 0:
logger.debug(
f"ADMM iter {iteration}: obj={obj_val:.3e}, primal={primal_res:.3e}, dual={dual_res:.3e}"
)
if primal_res < tol and dual_res < tol:
converged = True
break
if not converged and verbose:
warnings.warn(
f"ADMM failed to converge after {max_iters} iterations",
UserWarning,
stacklevel=2,
)
    # Final change (w.r.t. last iterate in the loop). max_iters is validated
    # to be a positive integer, so the loop ran at least once and Q_prev exists.
    final_change = float(np.linalg.norm(Q - Q_prev) / (1.0 + np.linalg.norm(Q_prev)))
# Snap to the exact projection (guarantees distance optimality over feasible set)
try:
snap = calibrate_dykstra(
P,
M,
max_iters=1500,
tol=1e-10,
rtol=0.0,
verbose=False,
detect_cycles=False,
ties="stable",
)
Q = snap.Q
except CalibrationError:
if verbose:
warnings.warn(
"Final snap-to-projection failed; using ADMM solution as-is",
UserWarning,
stacklevel=2,
)
# Diagnostics
row_sums = Q.sum(axis=1)
col_sums = Q.sum(axis=0)
max_row_error = float(np.max(np.abs(row_sums - 1.0)))
max_col_error = float(np.max(np.abs(col_sums - M)))
max_rank_violation = _compute_rank_violation(Q, P)
# Fail fast on non-convergence instead of returning unreliable results
if not converged:
raise CalibrationError(
f"ADMM calibration failed to converge after {iteration + 1} iterations. "
f"Final primal residual: {primal_residuals[-1]:.2e}, "
f"dual residual: {dual_residuals[-1]:.2e} (tolerance: {tol:.2e}). "
f"Try: increasing max_iters, adjusting rho parameter, relaxing tol, "
f"or consider Dykstra's method with nearly-isotonic constraints."
)
return ADMMResult(
Q=Q,
converged=converged,
iterations=iteration + 1,
objective_values=objective_values,
primal_residuals=primal_residuals,
dual_residuals=dual_residuals,
max_row_error=max_row_error,
max_col_error=max_col_error,
max_rank_violation=max_rank_violation,
final_change=final_change,
)