# rank_preserving_calibration/kl_calibration.py
"""
KL-divergence rank-preserving calibration.
This module provides rank-preserving calibration using KL divergence (relative entropy)
as the loss function instead of squared Euclidean distance. KL divergence is a natural
choice for probability calibration in label-shift scenarios.
Key innovations:
1. **Anchor-Reference Decoupling**: Separate A (ranking anchor) from R (reference for KL)
2. **Geometric Mean Pooling**: KL isotonic regression uses geometric (not arithmetic) mean
3. **Multiplicative Rescaling**: Sum constraints via multiplication (not addition)
4. **Pareto Frontier**: Report the whole λ-path, not a single tuned point
Exports:
- KLCalibrationResult: Result container for hard KL calibration
- KLParetoResult: Result container for Pareto frontier sweep
- calibrate_kl: Main hard KL solver with anchor-reference decoupling
- calibrate_kl_soft: Soft KL calibration with λ penalty
- calibrate_kl_pareto: Pareto frontier solver with warm-start sweep
"""
from __future__ import annotations
import warnings
from dataclasses import dataclass, field
from typing import Any
import numpy as np
from .calibration import (
CalibrationError,
_compute_rank_violation,
_configure_logging,
_validate_inputs,
)
type NDArrayFloat = np.ndarray[Any, np.dtype[np.floating[Any]]]
__all__ = [
"KLCalibrationResult",
"KLParetoResult",
"calibrate_kl",
"calibrate_kl_pareto",
"calibrate_kl_soft",
]
# ---------------------------------------------------------------------
# Data containers
# ---------------------------------------------------------------------
@dataclass(slots=True)
class KLCalibrationResult:
"""Result container for KL-divergence rank-preserving calibration.
Attributes:
Q: Calibrated probability matrix of shape (N, J) where rows sum to 1
and columns preserve rank ordering from anchor scores.
converged: True if algorithm converged within specified tolerance.
iterations: Number of iterations performed before termination.
kl_divergence: Final KL(Q||R) divergence value.
max_row_error: Maximum absolute error in row sum constraint.
max_col_error: Maximum absolute error in column sum constraint.
max_rank_violation: Maximum rank-order violation across all columns.
final_change: Final relative change in solution between iterations.
"""
Q: np.ndarray
converged: bool
iterations: int
kl_divergence: float
max_row_error: float
max_col_error: float
max_rank_violation: float
final_change: float
@dataclass(slots=True)
class KLParetoResult:
"""Result container for KL Pareto frontier computation.
Attributes:
solutions: List of (lambda, Q) pairs along the Pareto frontier.
kl_values: KL divergence values for each solution.
rank_violations: Total rank violation for each solution.
lambda_path: Lambda values used in the sweep.
"""
solutions: list[tuple[float, np.ndarray]] = field(default_factory=list)
kl_values: list[float] = field(default_factory=list)
rank_violations: list[float] = field(default_factory=list)
lambda_path: list[float] = field(default_factory=list)
# ---------------------------------------------------------------------
# KL divergence utilities
# ---------------------------------------------------------------------
def _safe_log(x: np.ndarray, eps: float = 1e-300) -> np.ndarray:
"""Safe log avoiding -inf for small values."""
return np.log(np.maximum(x, eps))
def _kl_div_matrix(Q: np.ndarray, R: np.ndarray, eps: float = 1e-300) -> float:
"""Compute KL(Q||R) = sum Q * log(Q/R) for probability matrices.
Uses convention 0 * log(0/x) = 0.
"""
Q = np.asarray(Q, dtype=np.float64)
R = np.asarray(R, dtype=np.float64)
# Handle zeros: 0 * log(0/x) = 0
mask = Q > eps
kl = np.zeros_like(Q)
kl[mask] = Q[mask] * (_safe_log(Q[mask], eps) - _safe_log(R[mask], eps))
return float(np.sum(kl))
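# A quick sanity check of _kl_div_matrix (illustrative sketch, not part of
# the public API): KL is zero between identical distributions, and the
# 0 * log(0/x) = 0 convention drops zero entries of Q.
#
#   >>> Q = np.array([[1.0, 0.0]])
#   >>> R = np.array([[0.5, 0.5]])
#   >>> round(_kl_div_matrix(Q, R), 4)   # 1 * log(1 / 0.5) = log 2
#   0.6931
#   >>> _kl_div_matrix(R, R)
#   0.0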
# ---------------------------------------------------------------------
# KL isotonic regression (generalized PAV with geometric mean)
# ---------------------------------------------------------------------
def _kl_isotonic_regression(
y: np.ndarray,
weights: np.ndarray | None = None,
eps: float = 1e-300,
) -> np.ndarray:
"""Isotonic regression minimizing weighted KL divergence.
    Unlike Euclidean PAV, which pools violating blocks with the arithmetic
    mean, KL isotonic regression pools with the weighted **geometric mean**:
        z_B = exp( (sum_{i in B} w_i * log(y_i)) / (sum_{i in B} w_i) )
    This is the exact projection onto the isotone cone for KL divergence.
    Args:
        y: Input sequence to make isotonic (non-decreasing).
        weights: Positive weights for each element. If None, unit weights are used.
eps: Small constant for numerical stability.
Returns:
Isotonic fit minimizing KL divergence.
"""
y = np.asarray(y, dtype=np.float64)
n = y.size
if n <= 1:
return np.maximum(y.copy(), eps)
# Ensure positive values for log
y = np.maximum(y, eps)
if weights is None:
w = np.ones(n, dtype=np.float64)
else:
w = np.asarray(weights, dtype=np.float64)
if w.shape != y.shape:
raise ValueError("weights must have same shape as y")
if np.any(w <= 0):
raise ValueError("weights must be positive")
# Work in log space for geometric mean
log_y = np.log(y)
# Block stacks: start index, weighted log sum, weight sum
start = np.empty(n, dtype=np.int64)
log_sum = np.empty(n, dtype=np.float64)
wsum = np.empty(n, dtype=np.float64)
top = -1
for i in range(n):
top += 1
start[top] = i
log_sum[top] = w[i] * log_y[i]
wsum[top] = w[i]
# Merge backward while violating monotonicity
# Geometric mean of block: exp(log_sum / wsum)
while top > 0:
left_mean = log_sum[top - 1] / wsum[top - 1]
right_mean = log_sum[top] / wsum[top]
if left_mean <= right_mean:
break
# Merge blocks
log_sum[top - 1] += log_sum[top]
wsum[top - 1] += wsum[top]
top -= 1
# Expand block means
z = np.empty(n, dtype=np.float64)
for j in range(top + 1):
s = start[j]
e = start[j + 1] if j < top else n
block_mean = log_sum[j] / wsum[j]
z[s:e] = np.exp(block_mean)
return z
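# Example (sketch): the pooling rule above is geometric, not arithmetic. A
# decreasing pair [4, 1] with unit weights pools to exp((log 4 + log 1)/2) = 2,
# whereas Euclidean PAV would pool to the arithmetic mean 2.5:
#
#   >>> _kl_isotonic_regression(np.array([4.0, 1.0]))
#   array([2., 2.])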
# ---------------------------------------------------------------------
# KL column projection (isotonic + multiplicative sum constraint)
# ---------------------------------------------------------------------
def _project_column_kl_isotonic_sum(
column: np.ndarray,
column_order: np.ndarray,
target_sum: float,
weights: np.ndarray | None = None,
eps: float = 1e-300,
) -> np.ndarray:
"""Project column onto KL-isotonic with sum constraint.
Key difference from Euclidean:
- Uses generalized PAV (geometric mean pooling)
- Uses **multiplicative** rescaling to hit sum target
Args:
column: Column values to project.
column_order: Indices that sort by anchor scores.
target_sum: Target sum for the column.
weights: Optional weights for isotonic regression.
eps: Small constant for numerical stability.
Returns:
Projected column satisfying isotonicity and sum constraint.
"""
if column.size == 0:
return column.copy()
# Sort by anchor order
y = np.maximum(column[column_order], eps)
# Apply KL isotonic regression
z_iso = _kl_isotonic_regression(y, weights=weights, eps=eps)
# Multiplicative rescaling to hit target sum
current_sum = z_iso.sum()
if current_sum > eps:
scale = target_sum / current_sum
z_scaled = z_iso * scale
else:
# Edge case: all near zero, distribute uniformly
z_scaled = np.full_like(z_iso, target_sum / z_iso.size)
# Restore original order
projected = np.empty_like(column, dtype=np.float64)
projected[column_order] = z_scaled
return projected
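# Example (sketch): the projection is isotone in the given order and hits the
# target sum via a single multiplicative rescale:
#
#   >>> col = np.array([0.1, 0.4, 0.2])
#   >>> z = _project_column_kl_isotonic_sum(col, np.arange(3), target_sum=1.0)
#   >>> bool(np.all(np.diff(z) >= 0)), bool(np.isclose(z.sum(), 1.0))
#   (True, True)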
# ---------------------------------------------------------------------
# KL row simplex projection (multiplicative normalization)
# ---------------------------------------------------------------------
def _project_row_kl_simplex(rows: np.ndarray, eps: float = 1e-300) -> np.ndarray:
"""Project rows onto probability simplex using KL projection.
For KL divergence, the projection onto the simplex is simply
**multiplicative normalization**: q_ij = p_ij / sum_k(p_ik).
This is different from the Euclidean simplex projection which
uses a sorting-based algorithm with additive adjustments.
Args:
rows: Matrix of shape (N, J) to project.
eps: Small constant for numerical stability.
Returns:
Projected matrix with rows summing to 1.
"""
rows = np.maximum(rows, eps)
row_sums = rows.sum(axis=1, keepdims=True)
row_sums = np.maximum(row_sums, eps)
return rows / row_sums
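# Example (sketch): the KL simplex projection is plain row normalization:
#
#   >>> _project_row_kl_simplex(np.array([[2.0, 1.0, 1.0]]))
#   array([[0.5 , 0.25, 0.25]])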
# ---------------------------------------------------------------------
# Main KL calibration (Dykstra-style alternating projections)
# ---------------------------------------------------------------------
def calibrate_kl(
P: np.ndarray,
M: np.ndarray,
R: np.ndarray | None = None,
A: np.ndarray | None = None,
max_iters: int = 3000,
tol: float = 1e-7,
feasibility_tol: float = 0.1,
verbose: bool = False,
eps: float = 1e-300,
) -> KLCalibrationResult:
"""Calibrate using KL divergence with Dykstra's alternating projections.
Projects multiclass probabilities onto the intersection of:
(A) row simplex: {rows ≥ 0, rows sum to 1}
(B) column-wise isotone-by-anchor + fixed column sums
Minimizes KL(Q||R) subject to constraints, where:
- R is the reference distribution for KL divergence (default: P)
- A is the anchor for rank ordering (default: P)
This anchor-reference decoupling allows separating "what we measure
divergence from" (R) from "what determines rank order" (A).
Args:
P: Input probability matrix of shape (N, J).
M: Target column sums of shape (J,).
R: Reference distribution for KL divergence. If None, uses P.
A: Anchor for rank ordering. If None, uses P.
max_iters: Maximum number of iterations.
tol: Convergence tolerance for relative change.
feasibility_tol: Tolerance for feasibility warnings.
verbose: If True, enables debug logging.
eps: Small constant for numerical stability.
Returns:
KLCalibrationResult with calibrated matrix and diagnostics.
Raises:
CalibrationError: If inputs are invalid or algorithm fails to converge.
Examples:
Basic KL calibration:
>>> import numpy as np
>>> from rank_preserving_calibration import calibrate_kl
>>> P = np.array([[0.7, 0.2, 0.1], [0.3, 0.5, 0.2]])
>>> M = np.array([1.0, 0.7, 0.3])
>>> result = calibrate_kl(P, M)
With anchor-reference decoupling:
>>> # Use P for KL reference, but A for rank ordering
>>> A = np.array([[0.6, 0.3, 0.1], [0.4, 0.4, 0.2]])
>>> result = calibrate_kl(P, M, R=P, A=A)
"""
_configure_logging(verbose)
_N, J = _validate_inputs(P, M, max_iters, tol, feasibility_tol)
P = np.asarray(P, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
# Set defaults for R and A
if R is None:
R = P.copy()
else:
R = np.asarray(R, dtype=np.float64)
if R.shape != P.shape:
raise CalibrationError(f"R must have shape {P.shape}, got {R.shape}")
if A is None:
A = P.copy()
else:
A = np.asarray(A, dtype=np.float64)
if A.shape != P.shape:
raise CalibrationError(f"A must have shape {P.shape}, got {A.shape}")
# Precompute column orders from anchor A
column_orders = [np.argsort(A[:, j], kind="mergesort") for j in range(J)]
# Initialize Q close to R (for KL minimization)
Q = np.maximum(R.copy(), eps)
# Dykstra memory terms (in multiplicative form for KL)
# For KL projections, we use multiplicative corrections
U = np.ones_like(P, dtype=np.float64) # row memory (multiplicative)
V = np.ones_like(P, dtype=np.float64) # col memory (multiplicative)
converged = False
final_change = float("inf")
iteration = 0
for iteration in range(1, max_iters + 1):
Q_prev = Q.copy()
# Row projection with Dykstra correction
Y = Q * U
Q = _project_row_kl_simplex(Y, eps=eps)
# Update multiplicative memory
U = Y / np.maximum(Q, eps)
        # Column projections with Dykstra correction; Y keeps the
        # pre-projection values so the multiplicative memory V can be
        # updated from them afterwards
        Y = Q * V
        Q_new = np.empty_like(Y)
        for j in range(J):
            Q_new[:, j] = _project_column_kl_isotonic_sum(
                Y[:, j],
                column_orders[j],
                float(M[j]),
                eps=eps,
            )
        V = Y / np.maximum(Q_new, eps)
        Q = Q_new
# Convergence check
change_abs = np.linalg.norm(Q - Q_prev)
norm_Q_prev = np.linalg.norm(Q_prev)
final_change = (
float(change_abs / norm_Q_prev) if norm_Q_prev > 0 else float(change_abs)
)
row_ok = np.allclose(Q.sum(axis=1), 1.0, atol=1e-10)
col_ok = np.allclose(Q.sum(axis=0), M, atol=1e-8)
if final_change < tol and row_ok and col_ok:
converged = True
break
# Final diagnostics
row_sums = Q.sum(axis=1)
col_sums = Q.sum(axis=0)
max_row_error = float(np.max(np.abs(row_sums - 1.0)))
max_col_error = float(np.max(np.abs(col_sums - M)))
max_rank_violation = _compute_rank_violation(Q, A)
kl_divergence = _kl_div_matrix(Q, R, eps=eps)
if not converged:
raise CalibrationError(
f"KL calibration failed to converge after {iteration} iterations. "
f"Final change: {final_change:.2e} (tolerance: {tol:.2e}). "
f"Max row error: {max_row_error:.2e}, max col error: {max_col_error:.2e}. "
f"Try: increasing max_iters or relaxing tol."
)
return KLCalibrationResult(
Q=Q,
converged=converged,
iterations=iteration,
kl_divergence=kl_divergence,
max_row_error=max_row_error,
max_col_error=max_col_error,
max_rank_violation=max_rank_violation,
final_change=final_change,
)
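# Usage sketch (hedged): after a successful solve, the diagnostics on the
# result confirm both constraint families. Feasibility requires M.sum() == N,
# since every row sums to 1.
#
#   >>> P = np.array([[0.7, 0.2, 0.1], [0.3, 0.5, 0.2]])
#   >>> M = np.array([1.0, 0.7, 0.3])    # sums to N = 2
#   >>> res = calibrate_kl(P, M)         # raises CalibrationError if it fails
#   >>> res.max_row_error < 1e-8 and res.max_col_error < 1e-6   # expected on success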
# ---------------------------------------------------------------------
# Soft KL calibration with λ penalty
# ---------------------------------------------------------------------
def _compute_rank_penalty_from_orders(
Q: np.ndarray, column_orders: list[np.ndarray]
) -> float:
"""Compute total rank violation penalty: sum of (q[i] - q[i+1])_+ for all columns."""
penalty = 0.0
J = Q.shape[1]
for j in range(J):
q_sorted = Q[column_orders[j], j]
if q_sorted.size > 1:
diffs = q_sorted[:-1] - q_sorted[1:]
penalty += float(np.maximum(diffs, 0.0).sum())
return penalty
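# Example (sketch): with the identity order, the penalty is the total
# positive "downhill" mass along each column:
#
#   >>> Q = np.array([[0.5], [0.2], [0.4]])
#   >>> round(_compute_rank_penalty_from_orders(Q, [np.arange(3)]), 6)  # (0.5-0.2)_+ + (0.2-0.4)_+
#   0.3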
def calibrate_kl_soft(
P: np.ndarray,
M: np.ndarray,
lam: float = 1.0,
R: np.ndarray | None = None,
    A: np.ndarray | None = None,
    Q0: np.ndarray | None = None,
max_iters: int = 1000,
tol: float = 1e-6,
verbose: bool = False,
eps: float = 1e-300,
) -> KLCalibrationResult:
"""Soft KL calibration with λ-weighted rank penalty.
Solves:
min KL(Q||R) + λ·V_A(Q) s.t. Q ∈ C(M)
Where:
- KL(Q||R) is the KL divergence from reference R
- V_A(Q) is the rank violation penalty using anchor A
- C(M) is row simplex + column sums = M
Args:
P: Input probability matrix of shape (N, J).
M: Target column sums of shape (J,).
lam: Rank penalty weight. Larger = more isotonic.
R: Reference for KL divergence. Default: P.
        A: Anchor for rank ordering. Default: P.
        Q0: Optional warm-start initializer of shape (N, J). Default: P.
max_iters: Maximum iterations.
tol: Convergence tolerance.
verbose: Enable debug logging.
eps: Numerical stability constant.
Returns:
KLCalibrationResult with calibrated matrix.
Examples:
>>> result = calibrate_kl_soft(P, M, lam=10.0) # Strong rank enforcement
>>> result = calibrate_kl_soft(P, M, lam=0.1) # Weak rank enforcement
"""
_configure_logging(verbose)
_, J = _validate_inputs(P, M, max_iters, tol, feasibility_tol=0.1)
P = np.asarray(P, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
if R is None:
R = P.copy()
else:
R = np.asarray(R, dtype=np.float64)
if A is None:
A = P.copy()
else:
A = np.asarray(A, dtype=np.float64)
column_orders = [np.argsort(A[:, j], kind="mergesort") for j in range(J)]
    # Initialize, optionally warm-started from Q0
    Q = P.copy() if Q0 is None else np.asarray(Q0, dtype=np.float64).copy()
    Q = np.maximum(Q, eps)
    Q = _project_row_kl_simplex(Q, eps=eps)
converged = False
final_change = float("inf")
for iteration in range(1, max_iters + 1):
Q_prev = Q.copy()
        # Gradient of KL(Q||R) w.r.t. Q: log(Q) - log(R) + 1.
        # Take an exponentiated-gradient (mirror-descent) step:
        #   Q_new = Q * exp(-step * grad)
        step = 0.1 / (1 + iteration * 0.01)  # slowly decaying step size
# KL gradient component
kl_grad = _safe_log(Q, eps) - _safe_log(R, eps) + 1.0
# Rank penalty gradient (subgradient)
rank_grad = np.zeros_like(Q)
for j in range(J):
idx = column_orders[j]
q_sorted = Q[idx, j]
for i in range(len(q_sorted) - 1):
if q_sorted[i] > q_sorted[i + 1]:
rank_grad[idx[i], j] += 1.0
rank_grad[idx[i + 1], j] -= 1.0
# Combined gradient
total_grad = kl_grad + lam * rank_grad
# Multiplicative update (exponentiated gradient)
Q = Q * np.exp(-step * total_grad)
Q = np.maximum(Q, eps)
# Project onto row simplex
Q = _project_row_kl_simplex(Q, eps=eps)
# Project columns onto sum constraints
for j in range(J):
current_sum = Q[:, j].sum()
if current_sum > eps:
Q[:, j] *= M[j] / current_sum
# Convergence check
change_norm = np.linalg.norm(Q - Q_prev)
Q_norm = np.linalg.norm(Q_prev)
final_change = float(change_norm / (Q_norm + 1e-15))
if final_change < tol:
converged = True
break
# Final diagnostics
row_sums = Q.sum(axis=1)
col_sums = Q.sum(axis=0)
max_row_error = float(np.max(np.abs(row_sums - 1.0)))
max_col_error = float(np.max(np.abs(col_sums - M)))
max_rank_violation = _compute_rank_violation(Q, A)
kl_divergence = _kl_div_matrix(Q, R, eps=eps)
if not converged:
warnings.warn(
f"KL soft calibration did not converge after {iteration} iterations. "
f"Final change: {final_change:.2e}",
UserWarning,
stacklevel=2,
)
return KLCalibrationResult(
Q=Q,
converged=converged,
iterations=iteration,
kl_divergence=kl_divergence,
max_row_error=max_row_error,
max_col_error=max_col_error,
max_rank_violation=max_rank_violation,
final_change=final_change,
)
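# Usage sketch (hedged): sweeping lam trades KL fit against isotonicity. The
# exponentiated-gradient loop is heuristic, so the trade-off is expected but
# not guaranteed to be monotone in lam. Given P and M as above:
#
#   >>> weak = calibrate_kl_soft(P, M, lam=0.1)
#   >>> strong = calibrate_kl_soft(P, M, lam=10.0)
#   >>> strong.max_rank_violation <= weak.max_rank_violation   # typically True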
# ---------------------------------------------------------------------
# Pareto frontier solver
# ---------------------------------------------------------------------
def calibrate_kl_pareto(
P: np.ndarray,
M: np.ndarray,
lambda_grid: list[float] | np.ndarray | None = None,
R: np.ndarray | None = None,
A: np.ndarray | None = None,
max_iters: int = 500,
tol: float = 1e-5,
verbose: bool = False,
eps: float = 1e-300,
) -> KLParetoResult:
"""Compute Pareto frontier of KL divergence vs rank violation.
Sweeps over λ values with warm starting, reporting the full
Pareto frontier rather than a single tuned point.
Args:
P: Input probability matrix of shape (N, J).
M: Target column sums of shape (J,).
lambda_grid: Grid of λ values to sweep. Default: geometric grid.
R: Reference for KL divergence. Default: P.
A: Anchor for rank ordering. Default: P.
max_iters: Max iterations per λ value.
tol: Convergence tolerance.
verbose: Enable debug logging.
eps: Numerical stability constant.
Returns:
KLParetoResult with solutions along the frontier.
Examples:
        >>> result = calibrate_kl_pareto(P, M)
        >>> for (lam, Q), kl in zip(result.solutions, result.kl_values):
        ...     print(f"λ={lam:.2f}: KL={kl:.4f}")
"""
_configure_logging(verbose)
P = np.asarray(P, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
if R is None:
R = P.copy()
else:
R = np.asarray(R, dtype=np.float64)
if A is None:
A = P.copy()
else:
A = np.asarray(A, dtype=np.float64)
if lambda_grid is None:
lambda_grid = np.logspace(-2, 3, 20)
else:
lambda_grid = np.asarray(lambda_grid)
column_orders = [np.argsort(A[:, j], kind="mergesort") for j in range(A.shape[1])]
result = KLParetoResult()
    # Warm-start state: begin from P projected onto the row simplex
    Q_warm = np.maximum(P.copy(), eps)
    Q_warm = _project_row_kl_simplex(Q_warm, eps=eps)
for lam in sorted(lambda_grid):
        # Warm-start each solve from the previous solution
        try:
            sol = calibrate_kl_soft(
                P,
                M,
                lam=float(lam),
                R=R,
                A=A,
                Q0=Q_warm,
                max_iters=max_iters,
                tol=tol,
                verbose=False,
                eps=eps,
            )
Q = sol.Q
except CalibrationError:
continue
kl_val = _kl_div_matrix(Q, R, eps=eps)
rank_viol = _compute_rank_penalty_from_orders(Q, column_orders)
result.solutions.append((float(lam), Q.copy()))
result.kl_values.append(kl_val)
result.rank_violations.append(rank_viol)
result.lambda_path.append(float(lam))
Q_warm = Q
return result
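# Usage sketch (hedged): the frontier can be post-processed to pick an
# operating point; the naive scalarization below is illustrative only, not a
# library-recommended rule:
#
#   >>> res = calibrate_kl_pareto(P, M)
#   >>> triples = zip(res.lambda_path, res.kl_values, res.rank_violations)
#   >>> lam_star, kl_star, viol_star = min(triples, key=lambda t: t[1] + t[2])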