Source code for rank_preserving_calibration.soft_calibration

# rank_preserving_calibration/soft_calibration.py
"""
Soft-constraint calibration with tunable trade-offs.

This module provides calibration methods that use soft penalties instead of
hard constraints, allowing users to trade off between:
- Fitting the original probabilities P
- Matching target marginals M
- Preserving rank orderings

The soft-constraint formulation gives a Pareto frontier of solutions rather than
a single "correct" answer, which is more practical when distribution shifts are large.
"""

from __future__ import annotations

import warnings
from dataclasses import dataclass

import numpy as np

from .calibration import (
    CalibrationError,
    _compute_rank_violation,
    _configure_logging,
    _isotonic_regression,
    _project_row_simplex,
)


@dataclass(slots=True)
class SoftCalibrationResult:
    """Result from soft-constraint calibration.

    Attributes:
        Q: Calibrated probability matrix.
        converged: Whether the algorithm converged.
        iterations: Number of iterations performed.
        objective_values: Objective function values over iterations.
        fit_term: Final ||Q - P||² term.
        marginal_term: Final ||col_sums(Q) - M||² term.
        rank_term: Final rank violation penalty term.
        max_row_error: Maximum row sum error.
        max_col_error: Maximum column sum error.
        max_rank_violation: Maximum rank violation.
        final_change: Final relative change between iterations.
    """

    Q: np.ndarray
    converged: bool
    iterations: int
    objective_values: list[float]
    fit_term: float
    marginal_term: float
    rank_term: float
    max_row_error: float
    max_col_error: float
    max_rank_violation: float
    final_change: float
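

# --- Illustrative sketch (not part of the module's public API) ---------------
# A minimal example of how a caller might consume a SoftCalibrationResult,
# assuming only the attributes documented above. The helper name
# _check_result_quality and the 1e-3 thresholds are arbitrary choices for
# illustration, not library defaults.
def _check_result_quality(result: SoftCalibrationResult) -> bool:
    """Return True if the calibrated matrix looks usable (illustrative only)."""
    if not result.converged:
        warnings.warn("soft calibration did not converge", UserWarning, stacklevel=2)
    rows_ok = result.max_row_error < 1e-3
    ranks_ok = result.max_rank_violation < 1e-3
    return rows_ok and ranks_ok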


def _compute_rank_penalty(Q: np.ndarray, column_orders: list[np.ndarray]) -> float:
    """Compute total rank violation penalty: sum of (q[i] - q[i+1])_+ for all columns."""
    penalty = 0.0
    J = Q.shape[1]
    for j in range(J):
        q_sorted = Q[column_orders[j], j]
        if q_sorted.size > 1:
            diffs = q_sorted[:-1] - q_sorted[1:]
            penalty += float(np.maximum(diffs, 0.0).sum())
    return penalty


def _soft_isotonic_projection(
    v: np.ndarray, lam_r: float, step_size: float = 1.0
) -> np.ndarray:
    """Soft isotonic projection via proximal gradient steps.

    When lam_r is large, this approaches strict isotonic regression.
    When lam_r is small, it barely modifies the input.

    Args:
        v: Input vector (should be sorted by original scores).
        lam_r: Rank penalty weight.
        step_size: Gradient descent step size.

    Returns:
        Softly isotonic vector.
    """
    if lam_r == 0.0 or v.size <= 1:
        return v.copy()

    # For large lambda, use exact isotonic regression.
    if lam_r >= 1e6:
        return _isotonic_regression(v, rtol=0.0, ties="stable")

    # For moderate lambda, use a soft approach. The gradient of the rank
    # penalty sum_k (v[k] - v[k+1])_+ with respect to v[i] is:
    #   +1 if v[i] > v[i+1]  (for i < n-1)
    #   -1 if v[i-1] > v[i]  (for i > 0)
    # We take proximal gradient steps: v_new = v - step * lam_r * grad.
    z = v.copy()
    for _ in range(10):  # A few iterations of proximal gradient.
        grad = np.zeros_like(z)
        for i in range(len(z) - 1):
            if z[i] > z[i + 1]:
                grad[i] += 1.0
                grad[i + 1] -= 1.0
        z = z - step_size * lam_r * grad / len(z)
    return z
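

# --- Illustrative sketch (not part of the module's public API) ---------------
# A quick sanity check of the limiting behaviour of _soft_isotonic_projection,
# assuming the helpers defined above. With a tiny weight the input should come
# back nearly unchanged; with a very large weight the call falls through to
# exact isotonic regression, which (as the rank penalty above implies) should
# yield a non-decreasing vector. The example array is arbitrary.
def _demo_soft_isotonic_projection() -> None:
    v = np.array([0.1, 0.4, 0.3, 0.8])
    nearly_unchanged = _soft_isotonic_projection(v, lam_r=1e-3)
    fully_isotonic = _soft_isotonic_projection(v, lam_r=1e9)
    assert np.allclose(nearly_unchanged, v, atol=1e-2)
    assert np.all(np.diff(fully_isotonic) >= -1e-9)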


def calibrate_soft(
    P: np.ndarray,
    M: np.ndarray,
    lam_m: float = 1.0,
    lam_r: float = 10.0,
    max_iters: int = 1000,
    tol: float = 1e-6,
    verbose: bool = False,
    step_size: float = 0.1,
) -> SoftCalibrationResult:
    """Calibrate with soft constraints on marginals and ranks.

    Solves the optimization problem:

        min_Q ||Q - P||² + lam_m·||col_sums(Q) - M||² + lam_r·rank_penalty(Q)
        s.t.  Q on row simplex (rows >= 0, sum to 1)

    This allows trading off between:
    - Staying close to original predictions (small lam_m, lam_r)
    - Matching marginals (large lam_m)
    - Preserving ranks (large lam_r)

    Args:
        P: Input probability matrix of shape (N, J).
        M: Target column sums of shape (J,). Should sum to approximately N.
        lam_m: Marginal penalty weight. Larger = closer to target marginals.
            Set to 0 to ignore marginals entirely.
        lam_r: Rank penalty weight. Larger = more isotonic. Set to 0 to allow
            arbitrary rank violations.
        max_iters: Maximum number of iterations.
        tol: Convergence tolerance for relative change.
        verbose: If True, enables debug logging.
        step_size: Step size for gradient descent updates.

    Returns:
        SoftCalibrationResult with calibrated matrix and diagnostics.

    Raises:
        CalibrationError: If inputs are invalid.

    Examples:
        >>> import numpy as np
        >>> from rank_preserving_calibration import calibrate_soft
        >>> P = np.array([[0.7, 0.2, 0.1], [0.3, 0.5, 0.2]])
        >>> M = np.array([1.0, 0.7, 0.3])
        >>> # Balanced trade-off
        >>> result = calibrate_soft(P, M, lam_m=1.0, lam_r=10.0)
        >>> # Prioritize marginal matching
        >>> result = calibrate_soft(P, M, lam_m=100.0, lam_r=1.0)
        >>> # Prioritize rank preservation
        >>> result = calibrate_soft(P, M, lam_m=0.1, lam_r=100.0)
    """
    _configure_logging(verbose)

    # Validate inputs
    if not isinstance(P, np.ndarray) or P.ndim != 2:
        raise CalibrationError("P must be a 2D numpy array")
    if P.size == 0:
        raise CalibrationError("P cannot be empty")
    if not np.isfinite(P).all():
        raise CalibrationError("P must not contain NaN or infinite values")
    if np.any(P < 0):
        raise CalibrationError("P must contain non-negative values")

    N, J = P.shape
    if J < 2:
        raise CalibrationError("P must have at least 2 columns (classes)")

    if not isinstance(M, np.ndarray) or M.ndim != 1 or M.size != J:
        raise CalibrationError(f"M must be a 1D array of length {J}")
    if not np.isfinite(M).all():
        raise CalibrationError("M must not contain NaN or infinite values")
    if np.any(M < 0):
        raise CalibrationError("M must contain non-negative values")

    if lam_m < 0 or lam_r < 0:
        raise CalibrationError("lam_m and lam_r must be non-negative")

    # Warn about feasibility
    M_sum = float(M.sum())
    if abs(M_sum - N) > 0.1 * N:
        warnings.warn(
            f"Sum of M ({M_sum:.3f}) differs significantly from N ({N}). "
            "Consider adjusting target marginals.",
            UserWarning,
            stacklevel=2,
        )

    P = np.asarray(P, dtype=np.float64)
    M = np.asarray(M, dtype=np.float64)

    # Precompute column orders
    column_orders = [np.argsort(P[:, j], kind="mergesort") for j in range(J)]

    # Initialize Q
    Q = P.copy()

    objective_values: list[float] = []
    converged = False
    final_change = float("inf")
    iteration = 0

    for iteration in range(1, max_iters + 1):
        Q_prev = Q.copy()

        # Gradient of ||Q - P||² is 2(Q - P)
        grad_fit = 2.0 * (Q - P)

        # Gradient of lam_m * ||col_sums(Q) - M||²:
        #   d/dQ[i, j] = 2 * lam_m * (col_sum_j - M_j)
        if lam_m > 0:
            col_sums = Q.sum(axis=0)
            marginal_error = col_sums - M
            grad_marginal = (
                2.0 * lam_m * np.ones((N, 1)) @ marginal_error.reshape(1, -1)
            )
        else:
            grad_marginal = 0.0

        # Total gradient (without rank term, handled separately)
        grad = grad_fit + grad_marginal

        # Gradient step
        Q = Q - step_size * grad

        # Soft isotonic projection for rank preservation
        if lam_r > 0:
            for j in range(J):
                idx = column_orders[j]
                v_sorted = Q[idx, j]
                z_sorted = _soft_isotonic_projection(
                    v_sorted, lam_r * step_size, step_size=1.0
                )
                Q[idx, j] = z_sorted

        # Project onto row simplex (hard constraint)
        Q = _project_row_simplex(Q)

        # Compute objective
        fit_term = float(np.sum((Q - P) ** 2))
        col_sums = Q.sum(axis=0)
        marginal_term = float(np.sum((col_sums - M) ** 2))
        rank_term = _compute_rank_penalty(Q, column_orders)
        obj_val = fit_term + lam_m * marginal_term + lam_r * rank_term
        objective_values.append(obj_val)

        # Convergence check
        change_norm = np.linalg.norm(Q - Q_prev)
        Q_norm = np.linalg.norm(Q_prev)
        final_change = float(change_norm / (Q_norm + 1e-15))
        if final_change < tol:
            converged = True
            break

    # Final diagnostics
    row_sums = Q.sum(axis=1)
    col_sums = Q.sum(axis=0)
    max_row_error = float(np.max(np.abs(row_sums - 1.0)))
    max_col_error = float(np.max(np.abs(col_sums - M)))
    max_rank_violation = _compute_rank_violation(Q, P)

    fit_term = float(np.sum((Q - P) ** 2))
    marginal_term = float(np.sum((col_sums - M) ** 2))
    rank_term = _compute_rank_penalty(Q, column_orders)

    return SoftCalibrationResult(
        Q=Q,
        converged=converged,
        iterations=iteration,
        objective_values=objective_values,
        fit_term=fit_term,
        marginal_term=marginal_term,
        rank_term=rank_term,
        max_row_error=max_row_error,
        max_col_error=max_col_error,
        max_rank_violation=max_rank_violation,
        final_change=final_change,
    )


def calibrate_soft_admm(
    P: np.ndarray,
    M: np.ndarray,
    lam_m: float = 1.0,
    lam_r: float = 10.0,
    rho: float = 1.0,
    max_iters: int = 1000,
    tol: float = 1e-6,
    verbose: bool = False,
) -> SoftCalibrationResult:
    """ADMM-based soft calibration with better convergence.

    Uses ADMM to solve the soft-constraint problem more reliably than simple
    gradient descent.

    Args:
        P: Input probability matrix of shape (N, J).
        M: Target column sums of shape (J,).
        lam_m: Marginal penalty weight.
        lam_r: Rank penalty weight.
        rho: ADMM penalty parameter.
        max_iters: Maximum iterations.
        tol: Convergence tolerance.
        verbose: Enable debug logging.

    Returns:
        SoftCalibrationResult with calibrated matrix.

    Raises:
        CalibrationError: If inputs are invalid.
    """
    _configure_logging(verbose)

    # Validate inputs
    if not isinstance(P, np.ndarray) or P.ndim != 2:
        raise CalibrationError("P must be a 2D numpy array")
    if P.size == 0:
        raise CalibrationError("P cannot be empty")

    N, J = P.shape
    if J < 2:
        raise CalibrationError("P must have at least 2 columns")

    if not isinstance(M, np.ndarray) or M.ndim != 1 or M.size != J:
        raise CalibrationError(f"M must be a 1D array of length {J}")
    if lam_m < 0 or lam_r < 0:
        raise CalibrationError("lam_m and lam_r must be non-negative")

    P = np.asarray(P, dtype=np.float64)
    M = np.asarray(M, dtype=np.float64)

    # Precompute column orders
    column_orders = [np.argsort(P[:, j], kind="mergesort") for j in range(J)]

    # ADMM variables:
    #   Q: primal variable (calibrated probabilities)
    #   Z: auxiliary copy for the simplex constraint
    #   Y: auxiliary copy for the isotonic constraint (per column)
    #   u_z, u_y: scaled dual variables
    Q = P.copy()
    Z = Q.copy()  # simplex copy
    Y = Q.copy()  # isotonic copy
    u_z = np.zeros_like(Q)
    u_y = np.zeros_like(Q)

    objective_values: list[float] = []
    converged = False
    final_change = float("inf")
    iteration = 0

    for iteration in range(1, max_iters + 1):
        Q_prev = Q.copy()

        # Q-update: minimize the quadratic sub-problem
        #   (1/2)||Q - P||² + marginal penalty
        #     + (rho/2)||Q - Z + u_z||² + (rho/2)||Q - Y + u_y||²
        # Without the marginal term the minimizer is element-wise:
        #   Q = (P + rho*(Z - u_z) + rho*(Y - u_y)) / (1 + 2*rho)
        # The marginal term couples all rows within a column, so when
        # lam_m > 0 we solve each column's closed form separately. In this
        # update the marginal weight is applied per row as lam_m / N.
        numerator = P + rho * (Z - u_z) + rho * (Y - u_y)

        if lam_m > 0:
            # For column j, the stationarity condition is
            #   (1 + 2*rho)*Q_j + (lam_m/N)*sum(Q_j)*1 = RHS_j,
            # where
            #   RHS_j = P_j + rho*(Z_j - u_z_j) + rho*(Y_j - u_y_j)
            #           + (lam_m/N)*M_j*1.
            # Let c = sum(Q_j). Then
            #   Q_j = (RHS_j - (lam_m/N)*c*1) / (1 + 2*rho),
            # and summing both sides gives
            #   c*(1 + 2*rho) = sum(RHS_j) - lam_m*c
            #   c = sum(RHS_j) / (1 + 2*rho + lam_m).
            Q_new = np.zeros_like(Q)
            for j in range(J):
                RHS_j = (
                    P[:, j]
                    + rho * (Z[:, j] - u_z[:, j])
                    + rho * (Y[:, j] - u_y[:, j])
                    + lam_m * M[j] / N * np.ones(N)
                )
                c = RHS_j.sum() / (1 + 2 * rho + lam_m)
                Q_new[:, j] = (RHS_j - lam_m * c / N * np.ones(N)) / (1 + 2 * rho)
            Q = Q_new
        else:
            Q = numerator / (1 + 2 * rho)

        # Z-update: project onto row simplex
        Z = _project_row_simplex(Q + u_z)

        # Y-update: soft isotonic projection per column
        for j in range(J):
            idx = column_orders[j]
            v_sorted = (Q + u_y)[idx, j]
            if lam_r >= 1e6:
                # Hard isotonic
                y_sorted = _isotonic_regression(v_sorted, rtol=0.0, ties="stable")
            elif lam_r > 0:
                # Soft isotonic via prox
                y_sorted = _soft_isotonic_projection(
                    v_sorted, lam_r / rho, step_size=1.0
                )
            else:
                y_sorted = v_sorted
            Y[idx, j] = y_sorted

        # Dual updates
        u_z = u_z + Q - Z
        u_y = u_y + Q - Y

        # Compute objective
        fit_term = float(np.sum((Q - P) ** 2))
        col_sums = Q.sum(axis=0)
        marginal_term = float(np.sum((col_sums - M) ** 2))
        rank_term = _compute_rank_penalty(Q, column_orders)
        obj_val = 0.5 * fit_term + 0.5 * lam_m * marginal_term + lam_r * rank_term
        objective_values.append(obj_val)

        # Convergence check
        change_norm = np.linalg.norm(Q - Q_prev)
        Q_norm = np.linalg.norm(Q_prev)
        final_change = float(change_norm / (Q_norm + 1e-15))
        if final_change < tol:
            converged = True
            break

    # Final Q from Z (the simplex projection)
    Q = Z

    # Final diagnostics
    row_sums = Q.sum(axis=1)
    col_sums = Q.sum(axis=0)
    max_row_error = float(np.max(np.abs(row_sums - 1.0)))
    max_col_error = float(np.max(np.abs(col_sums - M)))
    max_rank_violation = _compute_rank_violation(Q, P)

    fit_term = float(np.sum((Q - P) ** 2))
    marginal_term = float(np.sum((col_sums - M) ** 2))
    rank_term = _compute_rank_penalty(Q, column_orders)

    return SoftCalibrationResult(
        Q=Q,
        converged=converged,
        iterations=iteration,
        objective_values=objective_values,
        fit_term=fit_term,
        marginal_term=marginal_term,
        rank_term=rank_term,
        max_row_error=max_row_error,
        max_col_error=max_col_error,
        max_rank_violation=max_rank_violation,
        final_change=final_change,
    )
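

# --- Illustrative sketch (not part of the module's public API) ---------------
# A minimal side-by-side run of the two solvers on the toy problem from the
# calibrate_soft docstring, assuming both functions as defined above. It only
# prints diagnostics; which solver does better depends on the weights and data.
def _demo_compare_solvers() -> None:
    P = np.array([[0.7, 0.2, 0.1], [0.3, 0.5, 0.2]])
    M = np.array([1.0, 0.7, 0.3])
    for name, solver in (("gradient", calibrate_soft), ("admm", calibrate_soft_admm)):
        res = solver(P, M, lam_m=1.0, lam_r=10.0)
        print(
            f"{name}: converged={res.converged} iters={res.iterations} "
            f"max_col_error={res.max_col_error:.4f} "
            f"max_rank_violation={res.max_rank_violation:.4f}"
        )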