onlinerake.OnlineRakingMWU
- class onlinerake.OnlineRakingMWU(targets, learning_rate: float = 1.0, min_weight: float = 0.001, max_weight: float = 100.0, n_steps: int = 3, verbose: bool = False, track_convergence: bool = True, convergence_window: int = 20, compute_weight_stats: bool | int = False)
Bases: OnlineRakingSGD

Online raking via multiplicative weights updates.
- Parameters:
  - targets (Targets) – Target population proportions for each feature.
  - learning_rate (float, optional) – Step size used in the exponent of the multiplicative update. A typical default is learning_rate=1.0. The algorithm automatically clips extreme exponents based on the weights dtype to prevent numerical overflow/underflow, making it robust even with very large learning rates.
  - min_weight (float, optional) – Lower bound applied to the weights after each update. This prevents weights from collapsing to zero. Must be positive.
  - max_weight (float, optional) – Upper bound applied to the weights after each update. This prevents runaway weights. Must exceed min_weight.
  - n_steps (int, optional) – Number of multiplicative updates applied each time a new observation arrives.
  - compute_weight_stats (bool or int, optional) – Controls computation of weight distribution statistics for performance. If True, compute on every call. If False, never compute. If an integer k, compute every k observations. Default is False.
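The update rule itself is not spelled out on this page. The following is a minimal sketch of a clipped multiplicative-weights step for one incoming observation, tying together learning_rate, min_weight, max_weight, and n_steps; the gradient expression and the function name mwu_step are illustrative assumptions, not the library's actual internals.

```python
import numpy as np

def mwu_step(weights, indicators, targets, learning_rate=1.0,
             min_weight=1e-3, max_weight=100.0, n_steps=3):
    """Apply n_steps multiplicative updates to unit weights (illustrative sketch).

    weights    : (n,) current unit weights
    indicators : (n, k) binary feature matrix for units seen so far
    targets    : (k,) target population proportions
    """
    for _ in range(n_steps):
        # Weighted margins: current weighted proportion of each feature.
        margins = indicators.T @ weights / weights.sum()
        # Gradient of a squared-error margin loss w.r.t. each unit's weight (sketch).
        grad = indicators @ (margins - targets)
        # Clip the exponent to a safe range so np.exp cannot overflow/underflow,
        # mirroring the dtype-based clipping described above.
        exponent = np.clip(-learning_rate * grad, -50.0, 50.0)
        weights = weights * np.exp(exponent)
        # Enforce the documented floor and ceiling on weights.
        weights = np.clip(weights, min_weight, max_weight)
    return weights
```

Units whose over-represented features push the margin above target get their weights shrunk multiplicatively, and vice versa, with the clip bounds preventing collapse or blow-up.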
- __init__(targets, learning_rate: float = 1.0, min_weight: float = 0.001, max_weight: float = 100.0, n_steps: int = 3, verbose: bool = False, track_convergence: bool = True, convergence_window: int = 20, compute_weight_stats: bool | int = False) → None
Methods
- __init__(targets[, learning_rate, ...])
- check_convergence([tolerance]) – Check if algorithm has converged based on loss stability.
- detect_oscillation([threshold]) – Detect if loss is oscillating rather than converging.
- fit_one(obs) – Consume a single observation and update weights multiplicatively.
- partial_fit(obs) – Consume a single observation and update weights multiplicatively.
- partial_fit_batch(observations) – Process multiple observations in batch.
Attributes
- converged – Return True if the algorithm has detected convergence.
- convergence_step – Get step number where convergence was detected.
- effective_sample_size – Return the effective sample size (ESS).
- gradient_norm_history – Get history of gradient norms.
- loss – Get current squared-error loss.
- loss_moving_average – Return moving average of loss over convergence window.
- margins – Get current weighted margins.
- raw_margins – Get unweighted (raw) margins.
- weight_distribution_stats – Return comprehensive weight distribution statistics.
- weights – Get copy of current weight vector.
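The effective_sample_size attribute presumably uses the standard Kish formula, which is the usual definition for weighted samples; computed directly from a weight vector it looks like this (the standalone function is a sketch, not the library's property):

```python
import numpy as np

def effective_sample_size(weights):
    """Kish effective sample size: (sum w)^2 / sum(w^2).

    Equals n for uniform weights and shrinks toward 1 as the
    weight distribution becomes more dispersed.
    """
    w = np.asarray(weights, dtype=float)
    return w.sum() ** 2 / (w ** 2).sum()
```

A falling ESS after many updates signals that a few units are carrying most of the weight, which is what min_weight/max_weight are there to limit.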
- fit_one(obs: dict[str, Any] | Any) → None
Consume a single observation and update weights multiplicatively.
- Parameters:
obs (dict or object) – An observation containing feature indicators. For dict input, keys should match feature names in targets. For object input, features are accessed as attributes. Values should be binary (0/1 or False/True).
- Returns:
The internal state is updated in place.
- Return type:
None
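The obs parameter accepts either a dict keyed by feature name or an object exposing features as attributes. A small sketch of how such dual input can be normalized to binary indicators; the helper name normalize_obs is hypothetical and not part of the library:

```python
def normalize_obs(obs, feature_names):
    """Return {feature: 0/1} for dict or attribute-style input (hypothetical helper)."""
    if isinstance(obs, dict):
        # Dict input: keys must match the feature names in targets.
        return {name: int(bool(obs[name])) for name in feature_names}
    # Object input: features are read as attributes.
    return {name: int(bool(getattr(obs, name))) for name in feature_names}
```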
- partial_fit(obs: dict[str, Any] | Any) → None
Consume a single observation and update weights multiplicatively.
- Parameters:
obs (dict or object) – An observation containing feature indicators. For dict input, keys should match feature names in targets. For object input, features are accessed as attributes. Values should be binary (0/1 or False/True).
- Returns:
The internal state is updated in place.
- Return type:
None
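partial_fit_batch is listed in the methods table above but not detailed on this page. A plausible equivalent, assuming it simply streams each observation through partial_fit (the real method may instead vectorize the updates):

```python
def partial_fit_batch(model, observations):
    """Feed observations to the model one at a time (sketch of assumed behavior)."""
    for obs in observations:
        model.partial_fit(obs)
    return model
```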