Source code for rank_preserving_calibration.two_stage

# rank_preserving_calibration/two_stage.py
"""
Two-stage calibration approach using Iterative Proportional Fitting (IPF).

This module provides an alternative calibration strategy that may produce
less "flat" solutions when distribution shifts are large:

1. IPF (raking) to match target marginals while preserving relative structure
2. Optional isotonic projection to restore rank ordering

IPF tends to preserve more of the original probability structure compared
to direct Dykstra projection, making it useful when approximate marginal
matching is acceptable.
"""

from __future__ import annotations

import warnings
from dataclasses import dataclass

import numpy as np

from .calibration import (
    CalibrationError,
    _compute_rank_violation,
    _configure_logging,
    _isotonic_regression,
    _project_row_simplex,
)


[docs] @dataclass(slots=True) class TwoStageResult: """Result from two-stage IPF-based calibration. Attributes: Q: Calibrated probability matrix. converged: Whether the algorithm converged. ipf_iterations: Number of IPF iterations performed. projection_iterations: Number of isotonic projection iterations. max_row_error: Maximum row sum error. max_col_error: Maximum column sum error. max_rank_violation: Maximum rank violation. ipf_result: The intermediate IPF result before isotonic projection. """ Q: np.ndarray converged: bool ipf_iterations: int projection_iterations: int max_row_error: float max_col_error: float max_rank_violation: float ipf_result: np.ndarray
[docs] @dataclass(slots=True) class IPFResult: """Result from Iterative Proportional Fitting. Attributes: Q: Probability matrix after IPF. converged: Whether IPF converged. iterations: Number of iterations. max_row_error: Maximum row sum error. max_col_error: Maximum column sum error. final_change: Final relative change between iterations. """ Q: np.ndarray converged: bool iterations: int max_row_error: float max_col_error: float final_change: float
def calibrate_ipf(
    P: np.ndarray,
    M: np.ndarray,
    max_iters: int = 100,
    tol: float = 1e-8,
    verbose: bool = False,
) -> IPFResult:
    """Rake ``P`` toward the target column marginals ``M`` using IPF.

    Every iteration applies two multiplicative rescalings in sequence:

    - Row step: each row is divided by its sum, so rows sum to 1.
    - Column step: each column is multiplied by ``M[j] / column_sum[j]``,
      so the column sums match ``M``.

    Entries are floored at ``1e-15`` before iterating so the multiplicative
    updates never start from an exact zero. After the loop, rows are
    normalized once more; consequently the returned rows sum to 1 while the
    column targets may be met only approximately (see ``max_col_error``).

    Args:
        P: Input probability matrix of shape (N, J).
        M: Target column sums of shape (J,). Should sum to N.
        max_iters: Upper bound on the number of row/column scaling iterations.
        tol: Threshold on the relative change between successive iterates
            below which the loop stops and the run counts as converged.
        verbose: If True, enables debug logging.

    Returns:
        IPFResult holding the scaled matrix and diagnostics.

    Raises:
        CalibrationError: If ``P`` is not a non-empty, finite, non-negative
            2D numpy array with at least 2 columns, or if ``M`` is not a
            finite, non-negative 1D numpy array of length J.

    Warns:
        UserWarning: If ``M.sum()`` differs from N by more than 10% of N.

    Examples:
        >>> import numpy as np
        >>> from rank_preserving_calibration import calibrate_ipf
        >>> P = np.array([[0.7, 0.2, 0.1], [0.3, 0.5, 0.2]])
        >>> M = np.array([0.8, 0.8, 0.4])
        >>> result = calibrate_ipf(P, M)
        >>> print(f"Column sums: {result.Q.sum(axis=0)}")
    """
    _configure_logging(verbose)

    # Input checks. Their order decides which error is reported first.
    if not (isinstance(P, np.ndarray) and P.ndim == 2):
        raise CalibrationError("P must be a 2D numpy array")
    if not P.size:
        raise CalibrationError("P cannot be empty")
    if not np.all(np.isfinite(P)):
        raise CalibrationError("P must not contain NaN or infinite values")
    if (P < 0).any():
        raise CalibrationError("P must contain non-negative values")

    n_rows, n_classes = P.shape
    if n_classes < 2:
        raise CalibrationError("P must have at least 2 columns (classes)")

    if not (isinstance(M, np.ndarray) and M.ndim == 1 and M.size == n_classes):
        raise CalibrationError(f"M must be a 1D array of length {n_classes}")
    if not np.all(np.isfinite(M)):
        raise CalibrationError("M must not contain NaN or infinite values")
    if (M < 0).any():
        raise CalibrationError("M must contain non-negative values")

    # A large mismatch between the target total and N is only warned about.
    target_total = float(M.sum())
    if abs(target_total - n_rows) > 0.1 * n_rows:
        warnings.warn(
            f"Sum of M ({target_total:.3f}) differs significantly from N ({n_rows}). "
            "IPF may not converge well.",
            UserWarning,
            stacklevel=2,
        )

    P = np.asarray(P, dtype=np.float64)
    M = np.asarray(M, dtype=np.float64)

    # Floor used both for the entries and for every divisor below.
    tiny = 1e-15
    Q = np.maximum(P.copy(), tiny)

    has_converged = False
    rel_change = float("inf")
    n_done = 0
    for n_done in range(1, max_iters + 1):
        previous = Q.copy()

        # Row step: make every row sum to 1.
        Q = Q / np.maximum(Q.sum(axis=1, keepdims=True), tiny)

        # Column step: rescale every column to its target marginal.
        Q = Q * (M / np.maximum(Q.sum(axis=0), tiny))

        # Relative change between the iterates at the end of consecutive
        # iterations.
        step_norm = np.linalg.norm(Q - previous)
        base_norm = np.linalg.norm(previous)
        rel_change = float(step_norm / (base_norm + tiny))
        if rel_change < tol:
            has_converged = True
            break

    # The loop ends on a column step, so rows are normalized one last time.
    Q = Q / np.maximum(Q.sum(axis=1, keepdims=True), tiny)

    # Diagnostics are measured on the matrix that is actually returned.
    row_err = float(np.abs(Q.sum(axis=1) - 1.0).max())
    col_err = float(np.abs(Q.sum(axis=0) - M).max())

    return IPFResult(
        Q=Q,
        converged=has_converged,
        iterations=n_done,
        max_row_error=row_err,
        max_col_error=col_err,
        final_change=rel_change,
    )
def calibrate_two_stage(
    P: np.ndarray,
    M: np.ndarray,
    ipf_max_iters: int = 100,
    ipf_tol: float = 1e-8,
    proj_max_iters: int = 100,
    proj_tol: float = 1e-8,
    preserve_marginals: bool = False,
    verbose: bool = False,
) -> TwoStageResult:
    """Two-stage calibration: IPF followed by isotonic projection.

    Stage 1 runs ``calibrate_ipf`` to approximately match the target
    marginals. Stage 2 then repeatedly applies a per-column isotonic
    regression (columns visited in the rank order of the original ``P``)
    followed by a row-simplex projection, until the relative change between
    successive iterates drops below ``proj_tol`` or ``proj_max_iters`` is
    reached.

    The two-stage approach can produce less "flat" solutions compared to
    direct Dykstra projection when distribution shifts are large, because
    IPF preserves probability ratios rather than doing Euclidean projection.

    Args:
        P: Input probability matrix of shape (N, J).
        M: Target column sums of shape (J,). Should sum to N.
        ipf_max_iters: Maximum IPF iterations in stage 1.
        ipf_tol: Convergence tolerance for IPF.
        proj_max_iters: Maximum iterations for isotonic projection in
            stage 2. A value <= 0 skips stage 2 entirely.
        proj_tol: Convergence tolerance for projection.
        preserve_marginals: If True, re-apply column scaling (and then row
            normalization) after each isotonic projection to maintain
            marginals (may re-introduce rank violations).
        verbose: If True, enables debug logging.

    Returns:
        TwoStageResult with calibrated matrix and diagnostics. ``converged``
        is True only when IPF converged AND the projection loop met
        ``proj_tol`` (or was skipped because ``proj_max_iters <= 0``).

    Raises:
        CalibrationError: If inputs are invalid.

    Examples:
        >>> import numpy as np
        >>> from rank_preserving_calibration import calibrate_two_stage
        >>> P = np.array([[0.7, 0.2, 0.1], [0.3, 0.5, 0.2], [0.2, 0.3, 0.5]])
        >>> M = np.array([1.0, 1.2, 0.8])
        >>> result = calibrate_two_stage(P, M)
        >>> print(f"Converged: {result.converged}")
    """
    _configure_logging(verbose)

    # Structural validation only; value checks (finite, non-negative) are
    # performed by calibrate_ipf below.
    if not isinstance(P, np.ndarray) or P.ndim != 2:
        raise CalibrationError("P must be a 2D numpy array")
    if P.size == 0:
        raise CalibrationError("P cannot be empty")

    J = P.shape[1]
    if J < 2:
        raise CalibrationError("P must have at least 2 columns")

    if not isinstance(M, np.ndarray) or M.ndim != 1 or M.size != J:
        raise CalibrationError(f"M must be a 1D array of length {J}")

    P = np.asarray(P, dtype=np.float64)
    M = np.asarray(M, dtype=np.float64)

    # Stage 1: IPF to match marginals.
    ipf_result = calibrate_ipf(
        P, M, max_iters=ipf_max_iters, tol=ipf_tol, verbose=verbose
    )
    # Keep the IPF matrix untouched for the result; stage 2 works on a copy
    # because it assigns into Q in place.
    ipf_Q = ipf_result.Q
    Q = ipf_Q.copy()

    # Row orderings of each column of the ORIGINAL P. The stable sort keeps
    # tied entries in their original row order.
    column_orders = [np.argsort(P[:, j], kind="mergesort") for j in range(J)]

    # Stage 2: isotonic projection to restore rank ordering.
    # With no projection iterations requested there is nothing to converge,
    # so stage 2 counts as converged and the IPF flag alone decides.
    proj_converged = proj_max_iters <= 0
    proj_iterations = 0
    for proj_iterations in range(1, proj_max_iters + 1):
        Q_prev = Q.copy()

        # Isotonic regression per column, taken in the original rank order
        # and written back to the same rows.
        for j, idx in enumerate(column_orders):
            Q[idx, j] = _isotonic_regression(Q[idx, j], rtol=0.0, ties="stable")

        # Project onto row simplex.
        Q = _project_row_simplex(Q)

        # Optionally re-scale columns to preserve marginals.
        if preserve_marginals:
            col_sums = np.maximum(Q.sum(axis=0), 1e-15)
            Q = Q * (M / col_sums)

            # Column scaling breaks the row sums, so re-normalize rows.
            row_sums = np.maximum(Q.sum(axis=1, keepdims=True), 1e-15)
            Q = Q / row_sums

        # Relative change between successive iterates.
        change_norm = np.linalg.norm(Q - Q_prev)
        Q_norm = np.linalg.norm(Q_prev)
        final_change = float(change_norm / (Q_norm + 1e-15))
        if final_change < proj_tol:
            proj_converged = True
            break

    # Overall convergence requires both stages to have converged; reporting
    # only the IPF flag would hide a projection loop that ran out of
    # iterations.
    converged = bool(ipf_result.converged and proj_converged)

    # Diagnostics on the final matrix.
    row_sums = Q.sum(axis=1)
    col_sums = Q.sum(axis=0)
    max_row_error = float(np.max(np.abs(row_sums - 1.0)))
    max_col_error = float(np.max(np.abs(col_sums - M)))
    max_rank_violation = _compute_rank_violation(Q, P)

    return TwoStageResult(
        Q=Q,
        converged=converged,
        ipf_iterations=ipf_result.iterations,
        projection_iterations=proj_iterations,
        max_row_error=max_row_error,
        max_col_error=max_col_error,
        max_rank_violation=max_rank_violation,
        ipf_result=ipf_Q,
    )