# Source code for stable_cart.stability_utils

"""
Core stability utility functions implementing the 7 stability primitives.

These are the fundamental "atoms" of tree stability that can be composed
across different methods.
"""

import numpy as np
from typing import Tuple, List, Optional, Literal, Union
from dataclasses import dataclass
from sklearn.linear_model import LassoCV, RidgeCV, ElasticNetCV, LogisticRegressionCV
from sklearn.model_selection import train_test_split


@dataclass
class SplitCandidate:
    """Represents a potential split with all relevant information."""

    # Column index of the splitting feature; the oblique generator in this
    # module stores -1 here as a marker.
    feature_idx: int
    # Samples whose value (or projection) is <= threshold go to the left child.
    threshold: float
    # Impurity reduction of the split on the data it was scored on.
    gain: float
    # Row positions routed to the left and to the right child.
    left_indices: np.ndarray
    right_indices: np.ndarray
    # True when the split thresholds a linear projection of the features.
    is_oblique: bool = False
    # Projection weights applied as ``X @ oblique_weights`` for oblique splits.
    oblique_weights: Optional[np.ndarray] = None
    # Held-out score, filled in by validation-checked selection.
    validation_score: Optional[float] = None
    # Optional variance estimate for the split.
    # NOTE(review): no code visible in this module populates it - confirm.
    variance_estimate: Optional[float] = None
    # Fraction of bootstrap resamples that voted for this split.
    consensus_support: Optional[float] = None
@dataclass
class StabilityMetrics:
    """Container for stability diagnostic information."""

    # NOTE(review): nothing visible in this module fills these fields; the
    # notes below are inferred from the field names only - confirm with callers.
    # Presumably one consensus score per accepted split.
    prefix_consensus_scores: List[float]
    # Presumably a single train/validation agreement figure.
    validation_consistency: float
    # Presumably one variance estimate per leaf.
    leaf_variance_estimates: List[float]
    # Presumably the best-vs-runner-up margin of each split.
    split_margins: List[float]
    # Optional bootstrap variance figure; left unset by default.
    bootstrap_variance: Optional[float] = None
# ============================================================================
# 1. PREFIX STABILITY
# ============================================================================


def bootstrap_consensus_split(
    X: np.ndarray,
    y: np.ndarray,
    n_samples: int = 12,
    max_candidates: int = 20,
    threshold: float = 0.5,
    enable_quantile_binning: bool = True,
    max_bins: int = 24,
    random_state: Optional[int] = None,
) -> Tuple[Optional[SplitCandidate], List[SplitCandidate]]:
    """
    Find consensus split using bootstrap voting with quantile-binned thresholds.

    Parameters
    ----------
    X, y : np.ndarray
        Feature matrix and targets of the node being split.
    n_samples : int
        Number of bootstrap resamples that vote.
    max_candidates : int
        Candidate budget handed to the per-resample split search.
    threshold : float
        Minimum fraction of resamples that must vote for a split.
    enable_quantile_binning : bool
        Snap candidate thresholds to a quantile grid of the full-data feature
        so that near-identical thresholds pool their votes.
    max_bins : int
        Number of quantile bins of that grid.
    random_state : int, optional
        Seed of the bootstrap resampling.

    Returns
    -------
    best_split : SplitCandidate or None
        Highest-gain split among those reaching ``threshold`` support.
    all_candidates : List[SplitCandidate]
        Every split that reached ``threshold`` support, re-scored on the full
        data and carrying its consensus score.
    """
    if len(X) < 10:  # Too few samples for meaningful consensus
        return None, []

    rng = np.random.RandomState(random_state)
    n_samples_bootstrap = max(len(X) // 2, 10)

    candidate_votes = {}  # (feature, binned_threshold) -> number of resamples
    grid_cache = {}  # feature -> quantile grid; depends only on the full data

    for _ in range(n_samples):
        bootstrap_idx = rng.choice(len(X), size=n_samples_bootstrap, replace=True)
        X_boot, y_boot = X[bootstrap_idx], y[bootstrap_idx]

        # A resample gets one vote per key: after binning, several of its
        # candidates may land on the same key, and counting each of them
        # would let the consensus score exceed 1.
        voted_this_sample = set()
        for candidate in _find_candidate_splits(X_boot, y_boot, max_candidates):
            feature_idx = candidate.feature_idx
            if enable_quantile_binning:
                if feature_idx not in grid_cache:
                    grid_cache[feature_idx] = _quantile_grid(X[:, feature_idx], max_bins)
                binned_threshold = _snap_to_grid(candidate.threshold, grid_cache[feature_idx])
            else:
                binned_threshold = candidate.threshold

            key = (feature_idx, binned_threshold)
            if key in voted_this_sample:
                continue
            voted_this_sample.add(key)
            candidate_votes[key] = candidate_votes.get(key, 0) + 1

    if not candidate_votes:
        return None, []

    # The split value gets its own name: reusing ``threshold`` here would
    # shadow the consensus cut-off and compare the vote fraction against the
    # feature value instead.
    consensus_candidates = []
    for (feature_idx, split_threshold), votes in candidate_votes.items():
        consensus_score = votes / n_samples
        if consensus_score < threshold:
            continue

        # Evaluate this consensus candidate on the full data
        left_mask = X[:, feature_idx] <= split_threshold
        if np.sum(left_mask) > 0 and np.sum(~left_mask) > 0:
            consensus_candidates.append(
                SplitCandidate(
                    feature_idx=feature_idx,
                    threshold=split_threshold,
                    gain=_evaluate_split_gain(y, left_mask),
                    left_indices=np.where(left_mask)[0],
                    right_indices=np.where(~left_mask)[0],
                    consensus_support=consensus_score,
                )
            )

    if not consensus_candidates:
        return None, []

    best_candidate = max(consensus_candidates, key=lambda c: c.gain)
    return best_candidate, consensus_candidates


def _quantile_grid(feature_values: np.ndarray, max_bins: int) -> Optional[np.ndarray]:
    """Return the de-duplicated quantile edges, or None when the feature has
    at most ``max_bins`` distinct values and needs no binning."""
    if len(np.unique(feature_values)) <= max_bins:
        return None
    edges = np.quantile(feature_values, np.linspace(0, 1, max_bins + 1))
    return np.unique(edges)  # Remove duplicates


def _snap_to_grid(threshold: float, grid: Optional[np.ndarray]) -> float:
    """Return the grid edge closest to ``threshold`` (unchanged if no grid)."""
    if grid is None:
        return threshold
    return float(grid[np.argmin(np.abs(grid - threshold))])


def _bin_threshold(threshold: float, feature_values: np.ndarray, max_bins: int) -> float:
    """Bin threshold to quantile grid to reduce micro-jitter.

    Features with at most ``max_bins`` distinct values are left untouched.
    """
    return _snap_to_grid(threshold, _quantile_grid(feature_values, max_bins))


def enable_deterministic_tiebreaking(candidates: List[SplitCandidate]) -> List[SplitCandidate]:
    """Sort candidates deterministically to break ties consistently.

    Order: highest gain first, then lowest feature index, then lowest
    threshold. Returns a new list; the input is not modified.
    """
    return sorted(candidates, key=lambda c: (-c.gain, c.feature_idx, c.threshold))


def apply_margin_veto(
    candidates: List[SplitCandidate], margin_threshold: float = 0.03
) -> List[SplitCandidate]:
    """Veto splits where the margin between best candidates is too small.

    Returns the input unchanged when there are fewer than two candidates, an
    empty list when the relative gap between the two best gains is below
    ``margin_threshold`` (signalling that a more careful evaluation is
    needed), and otherwise a one-element list holding the clear winner.
    """
    if len(candidates) < 2:
        return candidates

    ranked = sorted(candidates, key=lambda c: c.gain, reverse=True)
    best_gain = ranked[0].gain
    margin = best_gain - ranked[1].gain
    # The small constant keeps the ratio defined when the best gain is zero.
    relative_margin = margin / (best_gain + 1e-10)

    if relative_margin < margin_threshold:
        return []
    return [ranked[0]]


#
# ============================================================================
# 2. VALIDATION-CHECKED SPLIT SELECTION
# ============================================================================


def validation_checked_split_selection(
    X_split: np.ndarray,
    y_split: np.ndarray,
    X_val: np.ndarray,
    y_val: np.ndarray,
    candidates: List[SplitCandidate],
    metric: Literal["median", "one_se", "variance_penalized"] = "variance_penalized",
    consistency_weight: float = 1.0,
    task: str = "regression",
) -> Optional[SplitCandidate]:
    """
    Evaluate split candidates on validation data and select based on consistency.

    Parameters
    ----------
    X_split, y_split : np.ndarray
        Data the candidates were found on. Currently unused; kept for
        interface compatibility.
    X_val, y_val : np.ndarray
        Held-out data every candidate is scored on.
    candidates : List[SplitCandidate]
        Candidates to score. Each one has ``validation_score`` set in place.
    metric : {"median", "one_se", "variance_penalized"}
        Selection rule applied to the validation scores.
    consistency_weight : float
        Penalty weight used by the "variance_penalized" rule.
    task : str
        "regression" scores by variance reduction; anything else by Gini
        reduction.

    Returns
    -------
    SplitCandidate or None
        The selected candidate, or None when there is nothing to select.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported rules.
    """
    if not candidates:
        return None

    scored_candidates = []
    for candidate in candidates:
        # Route the validation rows through the candidate split
        if candidate.is_oblique and candidate.oblique_weights is not None:
            projections = X_val @ candidate.oblique_weights
            left_mask_val = projections <= candidate.threshold
        else:
            left_mask_val = X_val[:, candidate.feature_idx] <= candidate.threshold

        candidate.validation_score = _evaluate_split_performance(y_val, left_mask_val, task)
        scored_candidates.append(candidate)

    if metric == "median":
        return _select_by_median_score(scored_candidates)
    elif metric == "one_se":
        return _select_by_one_se_rule(scored_candidates)
    elif metric == "variance_penalized":
        return _select_by_variance_penalty(scored_candidates, consistency_weight)
    else:
        raise ValueError(f"Unknown validation metric: {metric}")


def _evaluate_split_performance(y: np.ndarray, left_mask: np.ndarray, task: str) -> float:
    """Evaluate split performance on validation data.

    Returns the impurity reduction (variance for regression, Gini otherwise)
    obtained by the split, or 0.0 when one side of the split is empty.
    """
    n_left = np.sum(left_mask)
    n_right = np.sum(~left_mask)
    if n_left == 0 or n_right == 0:
        return 0.0

    y_left, y_right = y[left_mask], y[~left_mask]

    if task == "regression":
        # Reduction in MSE; a single-sample side contributes zero variance
        total_var = np.var(y) if len(y) > 1 else 0
        left_var = np.var(y_left) if n_left > 1 else 0
        right_var = np.var(y_right) if n_right > 1 else 0
        weighted_var = (n_left * left_var + n_right * right_var) / len(y)
        return total_var - weighted_var

    # Reduction in Gini impurity
    total_gini = _gini_impurity(y)
    weighted_gini = (
        n_left * _gini_impurity(y_left) + n_right * _gini_impurity(y_right)
    ) / len(y)
    return total_gini - weighted_gini


def _select_by_variance_penalty(
    candidates: List[SplitCandidate], penalty_weight: float
) -> Optional[SplitCandidate]:
    """Select split using validation score minus variance penalty.

    Candidates without a validation score are skipped; the first candidate
    with the highest penalized score wins.
    """
    if not candidates:
        return None

    # Placeholder variance estimate (would use bootstrap in practice). The
    # penalty is the same for every candidate, so for now this rule picks
    # the highest validation score.
    variance_penalty = penalty_weight * 0.1

    best_candidate = None
    best_score = -np.inf
    for candidate in candidates:
        if candidate.validation_score is None:
            continue
        penalized_score = candidate.validation_score - variance_penalty
        if penalized_score > best_score:
            best_score = penalized_score
            best_candidate = candidate

    return best_candidate


# ============================================================================
# 3. HONESTY (SPLIT vs EST)
# ============================================================================


def honest_data_partition(
    X: np.ndarray,
    y: np.ndarray,
    split_frac: float = 0.6,
    val_frac: float = 0.2,
    est_frac: float = 0.2,
    enable_stratification: bool = True,
    task: str = "regression",
    random_state: Optional[int] = None,
) -> Tuple[
    Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]
]:
    """
    Partition data into SPLIT/VAL/EST subsets with optional stratification.

    Regression targets are stratified on quantile bins of ``y``,
    classification targets on ``y`` itself; any other ``task`` disables
    stratification. The same ``random_state`` seeds both underlying splits.

    Returns
    -------
    (X_split, y_split), (X_val, y_val), (X_est, y_est)

    Raises
    ------
    AssertionError
        If the three fractions do not sum to 1.
    """
    assert abs(split_frac + val_frac + est_frac - 1.0) < 1e-6, "Fractions must sum to 1"

    def _strata(targets: np.ndarray) -> Optional[np.ndarray]:
        # Labels used to stratify a split of ``targets``; None disables it.
        if enable_stratification and task == "regression":
            return _create_target_bins(targets, n_bins=5)
        if enable_stratification and task == "classification":
            return targets
        return None

    # First split: SPLIT vs (VAL + EST)
    X_split, X_temp, y_split, y_temp = train_test_split(
        X, y, test_size=val_frac + est_frac, stratify=_strata(y), random_state=random_state
    )

    # Second split: VAL vs EST, with EST expressed as a share of the remainder
    est_size_relative = est_frac / (val_frac + est_frac)
    X_val, X_est, y_val, y_est = train_test_split(
        X_temp,
        y_temp,
        test_size=est_size_relative,
        stratify=_strata(y_temp),
        random_state=random_state,
    )

    return (X_split, y_split), (X_val, y_val), (X_est, y_est)


def _create_target_bins(y: np.ndarray, n_bins: int = 5) -> np.ndarray:
    """Create stratification bins for regression targets using quantiles.

    With at most ``n_bins`` distinct targets every distinct value becomes its
    own stratum. The labels are the ranks of the distinct values rather than
    ``y`` truncated to int, which would merge strata such as 0.1 and 0.9.
    """
    distinct, codes = np.unique(y, return_inverse=True)
    if len(distinct) <= n_bins:
        return codes.reshape(np.shape(y))

    edges = np.quantile(y, np.linspace(0, 1, n_bins + 1))
    return np.digitize(y, edges[1:-1])  # Exclude first/last bins to avoid edge effects


# ============================================================================
# 4. LEAF STABILIZATION
# ============================================================================


def stabilize_leaf_estimate(
    y_est: np.ndarray,
    y_parent: np.ndarray,
    strategy: Literal["m_estimate", "shrink_to_parent", "beta_smoothing"] = "m_estimate",
    smoothing: float = 1.0,
    task: str = "regression",
    min_samples: int = 5,
) -> Union[float, np.ndarray]:
    """
    Stabilize leaf estimates using various smoothing strategies.

    Parameters
    ----------
    y_est : np.ndarray
        Targets of the estimation samples that fell into the leaf.
    y_parent : np.ndarray
        Targets of the parent node, used as the prior.
    strategy : {"m_estimate", "shrink_to_parent", "beta_smoothing"}
        Smoothing rule. Leaves with fewer than ``min_samples`` samples are
        always shrunk towards the parent.
    smoothing : float
        Strength of the prior.
    task : str
        "regression" returns a float; anything else returns class
        probabilities indexed by integer class label.
    min_samples : int
        Leaf size below which shrinkage is forced.

    Returns
    -------
    float or np.ndarray
        Smoothed leaf mean, or vector of class probabilities.
    """
    if len(y_est) == 0:
        # Nothing observed in the leaf: fall back to the parent
        if task == "regression":
            return float(np.mean(y_parent)) if len(y_parent) > 0 else 0.0
        # Uniform probabilities, sized like the non-empty path (max label + 1)
        # so every leaf of a tree yields a vector of the same length.
        n_classes = int(np.max(y_parent.astype(int))) + 1 if len(y_parent) > 0 else 2
        return np.ones(n_classes) / n_classes

    if len(y_est) < min_samples and strategy != "shrink_to_parent":
        # Force shrinkage for very small leaves
        strategy = "shrink_to_parent"

    if task == "regression":
        return _stabilize_regression_leaf(y_est, y_parent, strategy, smoothing)
    return _stabilize_classification_leaf(y_est, y_parent, strategy, smoothing)


def _stabilize_regression_leaf(
    y_est: np.ndarray, y_parent: np.ndarray, strategy: str, smoothing: float
) -> float:
    """Stabilize regression leaf estimate.

    "shrink_to_parent" mixes leaf and parent means with a fixed weight
    ``smoothing / (1 + smoothing)``; every other strategy (including
    "beta_smoothing") uses the m-estimate, whose prior weight fades as the
    leaf grows.
    """
    leaf_mean = np.mean(y_est)
    parent_mean = np.mean(y_parent) if len(y_parent) > 0 else leaf_mean

    if strategy == "shrink_to_parent":
        shrinkage_factor = smoothing / (1 + smoothing)
        return float((1 - shrinkage_factor) * leaf_mean + shrinkage_factor * parent_mean)

    # M-estimate: the parent mean counts as ``smoothing`` pseudo-observations
    n = len(y_est)
    return float((n * leaf_mean + smoothing * parent_mean) / (n + smoothing))


def _stabilize_classification_leaf(
    y_est: np.ndarray, y_parent: np.ndarray, strategy: str, smoothing: float
) -> np.ndarray:
    """Stabilize classification leaf probabilities.

    Labels are read as non-negative integers; entry ``k`` of the result is
    the probability of class ``k``.
    """
    leaf_labels = y_est.astype(int)
    parent_labels = y_parent.astype(int)

    # Size both count vectors by the largest label seen anywhere. Sizing by
    # the number of distinct labels gives leaf and parent vectors of
    # different lengths as soon as a label value is skipped.
    all_labels = np.concatenate([leaf_labels, parent_labels])
    n_classes = int(all_labels.max()) + 1 if all_labels.size else 0

    leaf_counts = np.bincount(leaf_labels, minlength=n_classes)
    parent_counts = (
        np.bincount(parent_labels, minlength=n_classes) if parent_labels.size else leaf_counts
    )
    n_leaf = np.sum(leaf_counts)
    n_parent = np.sum(parent_counts)
    parent_probs = parent_counts / n_parent if n_parent > 0 else np.ones(n_classes) / n_classes

    if strategy == "m_estimate":
        # (n_c + m * p_c) / (n + m): the prior counts as ``smoothing``
        # pseudo-observations, matching the regression m-estimate. Scaling
        # the prior by the leaf size would make this a fixed-weight mix,
        # i.e. the same as "shrink_to_parent".
        smoothed_counts = leaf_counts + smoothing * parent_probs
        return smoothed_counts / np.sum(smoothed_counts)

    if strategy == "beta_smoothing":
        # Symmetric prior with ``smoothing`` pseudo-counts per class. The
        # class count is floored at two so a node that has only ever seen
        # one class still reserves mass for a second one; with more classes
        # the probabilities now sum to 1.
        return (leaf_counts + smoothing) / (n_leaf + smoothing * max(n_classes, 2))

    # shrink_to_parent
    leaf_probs = leaf_counts / n_leaf if n_leaf > 0 else parent_probs
    shrinkage_factor = smoothing / (1 + smoothing)
    return (1 - shrinkage_factor) * leaf_probs + shrinkage_factor * parent_probs


# ============================================================================
# 5. DATA REGULARIZATION
# ============================================================================


def winsorize_features(
    X: np.ndarray,
    quantiles: Tuple[float, float] = (0.01, 0.99),
    fitted_bounds: Optional[Tuple[np.ndarray, np.ndarray]] = None,
) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """
    Winsorize features to reduce outlier influence.

    Parameters
    ----------
    X : np.ndarray
        Feature matrix; clipping is applied per column.
    quantiles : Tuple[float, float]
        Lower and upper quantile used to derive the bounds from ``X``.
    fitted_bounds : Tuple[np.ndarray, np.ndarray], optional
        Previously returned bounds. When given they are applied as-is and
        ``quantiles`` is ignored.

    Returns
    -------
    X_winsorized : np.ndarray
        Winsorized features
    bounds : Tuple[np.ndarray, np.ndarray]
        (lower_bounds, upper_bounds) for future application
    """
    if fitted_bounds is None:
        lower_q, upper_q = quantiles
        fitted_bounds = (np.quantile(X, lower_q, axis=0), np.quantile(X, upper_q, axis=0))

    lower_bounds, upper_bounds = fitted_bounds
    return np.clip(X, lower_bounds, upper_bounds), (lower_bounds, upper_bounds)


# ============================================================================
# 6.
# CANDIDATE DIVERSITY WITH DETERMINISTIC RESOLUTION
# ============================================================================


def generate_oblique_candidates(
    X: np.ndarray,
    y: np.ndarray,
    strategy: Literal["lasso", "ridge", "elastic_net"] = "lasso",
    enable_correlation_gating: bool = True,
    min_correlation: float = 0.3,
    task: str = "regression",
    random_state: Optional[int] = None,
) -> List[SplitCandidate]:
    """
    Generate oblique split candidates using linear projections.

    A cross-validated linear model is fitted and its coefficient vector is
    used as projection direction; the 25th, 50th and 75th percentile of the
    projected data are offered as thresholds.

    Parameters
    ----------
    X, y : np.ndarray
        Feature matrix and targets of the node.
    strategy : {"lasso", "ridge", "elastic_net"}
        Regression model used to obtain the weights. Classification always
        uses an L1-penalized logistic regression.
    enable_correlation_gating : bool
        Skip oblique splits unless some pair of features is correlated.
    min_correlation : float
        Smallest absolute pairwise correlation that passes the gate.
    task : str
        "regression" or classification (any other value).
    random_state : int, optional
        Seed forwarded to the models that accept one.

    Returns
    -------
    List[SplitCandidate]
        Up to three oblique candidates (``feature_idx=-1``); empty when the
        gate rejects the data, fewer than two weights are non-zero, or the
        fit fails.
    """
    n_features = X.shape[1]
    if n_features < 2:
        return []  # Need at least 2 features for oblique splits

    if enable_correlation_gating:
        # A constant column has an undefined (NaN) correlation. NaN compares
        # False against any bound, which used to wave the data through the
        # gate, so only the defined off-diagonal entries are considered.
        with np.errstate(divide="ignore", invalid="ignore"):
            corr_matrix = np.corrcoef(X.T)
        off_diagonal = np.abs(corr_matrix - np.eye(n_features))
        defined = off_diagonal[np.isfinite(off_diagonal)]
        max_corr = defined.max() if defined.size else 0.0
        if max_corr < min_correlation:
            return []  # Features not correlated enough

    try:
        if task == "regression":
            if strategy == "lasso":
                model = LassoCV(cv=3, random_state=random_state)
            elif strategy == "ridge":
                model = RidgeCV(cv=3)
            else:  # elastic_net
                model = ElasticNetCV(cv=3, random_state=random_state)
            model.fit(X, y)
            weights = model.coef_
        else:  # classification
            model = LogisticRegressionCV(
                cv=3, random_state=random_state, penalty="l1", solver="liblinear"
            )
            model.fit(X, y)
            # With a 2-D coefficient matrix only the first row is used
            weights = model.coef_[0] if model.coef_.ndim > 1 else model.coef_

        # Only proceed if we got non-trivial weights
        if np.sum(np.abs(weights) > 1e-6) < 2:
            return []

        projections = X @ weights

        candidates = []
        for percentile in [25, 50, 75]:
            threshold = np.percentile(projections, percentile)
            left_mask = projections <= threshold
            if np.sum(left_mask) > 0 and np.sum(~left_mask) > 0:
                candidates.append(
                    SplitCandidate(
                        feature_idx=-1,  # Special marker for oblique
                        threshold=threshold,
                        gain=_evaluate_split_gain(y, left_mask),
                        left_indices=np.where(left_mask)[0],
                        right_indices=np.where(~left_mask)[0],
                        is_oblique=True,
                        oblique_weights=weights,
                    )
                )
        return candidates

    except Exception:
        # Deliberate best effort: oblique candidates are optional, so any
        # failure while fitting or scoring simply yields none.
        return []


def beam_search_splits(
    X: np.ndarray,
    y: np.ndarray,
    depth: int = 2,
    beam_width: int = 12,
    enable_ambiguity_gating: bool = True,
    ambiguity_threshold: float = 0.05,
    task: str = "regression",
) -> List[SplitCandidate]:
    """
    Use beam search to find splits with lookahead.

    Parameters
    ----------
    X, y : np.ndarray
        Feature matrix and targets of the node.
    depth : int
        Lookahead depth handed to the beam search.
    beam_width : int
        Number of candidates kept in the beam.
    enable_ambiguity_gating : bool
        Skip the beam search when the best candidate clearly leads.
    ambiguity_threshold : float
        Relative gain gap above which the leader counts as a clear winner.
    task : str
        Forwarded to the beam search.

    Returns
    -------
    List[SplitCandidate]
        The plain candidates for nodes with fewer than 20 samples, the single
        clear winner when gating applies, otherwise the beam search result.
    """
    if len(X) < 20:  # Too small for meaningful beam search
        return _find_candidate_splits(X, y, max_candidates=beam_width)

    candidates = _find_candidate_splits(X, y, max_candidates=beam_width * 2)

    if enable_ambiguity_gating and len(candidates) >= 2:
        ranked = sorted(candidates, key=lambda c: c.gain, reverse=True)
        top_gain = ranked[0].gain
        if top_gain > 0:
            relative_gap = (top_gain - ranked[1].gain) / top_gain
            if relative_gap > ambiguity_threshold:
                # Clear winner - no need for expensive beam search
                return [ranked[0]]

    return _perform_beam_search(X, y, candidates[:beam_width], depth, task)


def _perform_beam_search(
    X: np.ndarray, y: np.ndarray, initial_candidates: List[SplitCandidate], depth: int, task: str
) -> List[SplitCandidate]:
    """Simplified beam search implementation.

    With ``depth <= 1`` the candidates are returned untouched. Otherwise each
    candidate's ``gain`` is increased in place by a placeholder lookahead
    bonus of 10% and the candidates are returned sorted by gain, best first.
    A full beam search would recursively evaluate subsequent splits.
    """
    if depth <= 1:
        return initial_candidates

    for candidate in initial_candidates:
        candidate.gain += 0.1 * candidate.gain  # Placeholder for actual lookahead

    return sorted(initial_candidates, key=lambda c: c.gain, reverse=True)


# ============================================================================
# 7.
# VARIANCE-AWARE STOPPING
# ============================================================================


def should_stop_splitting(
    current_gain: float,
    variance_estimate: float,
    variance_weight: float = 1.0,
    strategy: Literal["one_se", "variance_penalty", "both"] = "variance_penalty",
) -> bool:
    """
    Determine if splitting should stop based on variance-aware criteria.

    Parameters
    ----------
    current_gain : float
        Gain of the split under consideration.
    variance_estimate : float
        Estimated variance of that gain.
    variance_weight : float
        Multiplier of the variance in the "variance_penalty" rule.
    strategy : {"one_se", "variance_penalty", "both"}
        "variance_penalty" stops when the gain is below the weighted
        variance, "one_se" when it is below the unweighted variance (a
        simplified 1-SE rule), and any other value when either rule fires.

    Returns
    -------
    bool
        True when splitting should stop.
    """
    penalty_stop = current_gain < variance_weight * variance_estimate
    se_stop = current_gain < variance_estimate  # Would use SE in practice

    if strategy == "variance_penalty":
        return penalty_stop
    if strategy == "one_se":
        return se_stop
    return penalty_stop or se_stop  # both


def estimate_split_variance(
    X: np.ndarray,
    y: np.ndarray,
    split_candidate: SplitCandidate,
    n_bootstrap: int = 10,
    task: str = "regression",
    random_state: Optional[int] = None,
) -> float:
    """
    Estimate variance that would be introduced by this split.

    The fixed split is re-scored on ``n_bootstrap`` bootstrap resamples;
    resamples in which one side is empty are skipped.

    Parameters
    ----------
    X, y : np.ndarray
        Feature matrix and targets of the node.
    split_candidate : SplitCandidate
        Split whose gain variability is measured.
    n_bootstrap : int
        Number of bootstrap resamples.
    task : str
        Currently unused.
    random_state : int, optional
        Seed of the resampling.

    Returns
    -------
    float
        Sample variance (``ddof=1``) of the bootstrap gains, or 0.0 when
        fewer than two resamples could be scored.
    """
    rng = np.random.RandomState(random_state)
    n_samples = len(X)
    use_projection = split_candidate.is_oblique and split_candidate.oblique_weights is not None

    bootstrap_scores = []
    for _ in range(n_bootstrap):
        bootstrap_idx = rng.choice(n_samples, size=n_samples, replace=True)
        X_boot = X[bootstrap_idx]
        y_boot = y[bootstrap_idx]

        # Apply the split to the bootstrap sample
        if use_projection:
            split_values = X_boot @ split_candidate.oblique_weights
        else:
            split_values = X_boot[:, split_candidate.feature_idx]
        left_mask = split_values <= split_candidate.threshold

        if np.sum(left_mask) > 0 and np.sum(~left_mask) > 0:
            bootstrap_scores.append(_evaluate_split_gain(y_boot, left_mask))

    if len(bootstrap_scores) < 2:
        return 0.0
    return float(np.var(bootstrap_scores, ddof=1))


# ============================================================================
# HELPER FUNCTIONS
# ============================================================================


def _select_thresholds(unique_values: np.ndarray, budget: int) -> np.ndarray:
    """Pick at most ``budget`` (at least one) midpoints between consecutive
    sorted unique values, spread evenly over the whole value range."""
    midpoints = (unique_values[:-1] + unique_values[1:]) / 2
    budget = max(1, budget)
    if len(midpoints) <= budget:
        return midpoints
    # Centre of each of ``budget`` equal-width strata of the midpoint index
    picks = ((np.arange(budget) + 0.5) * len(midpoints) / budget).astype(int)
    return midpoints[picks]


def _find_candidate_splits(
    X: np.ndarray, y: np.ndarray, max_candidates: int = 20
) -> List[SplitCandidate]:
    """Find basic axis-aligned split candidates.

    Every feature receives an equal share of ``max_candidates`` thresholds,
    and at least one, so wide data still produces candidates. The thresholds
    are spread over the feature's whole range instead of being taken from its
    smallest values only.

    Returns
    -------
    List[SplitCandidate]
        At most ``max_candidates`` candidates, best gain first.
    """
    candidates = []
    n_features = X.shape[1]
    per_feature_budget = max(1, max_candidates // max(n_features, 1))

    for feature_idx in range(n_features):
        feature_values = X[:, feature_idx]
        unique_values = np.unique(feature_values)
        if len(unique_values) < 2:
            continue  # A constant feature cannot split anything

        for threshold in _select_thresholds(unique_values, per_feature_budget):
            left_mask = feature_values <= threshold
            if np.sum(left_mask) > 0 and np.sum(~left_mask) > 0:
                candidates.append(
                    SplitCandidate(
                        feature_idx=feature_idx,
                        threshold=threshold,
                        gain=_evaluate_split_gain(y, left_mask),
                        left_indices=np.where(left_mask)[0],
                        right_indices=np.where(~left_mask)[0],
                    )
                )

    candidates.sort(key=lambda c: c.gain, reverse=True)
    return candidates[:max_candidates]


def _evaluate_split_gain(y: np.ndarray, left_mask: np.ndarray) -> float:
    """Evaluate information gain from a split.

    Floating-point targets, or targets with more than 10 distinct values,
    are scored by variance reduction; everything else by Gini reduction.
    Returns 0.0 for empty data or when one side of the split is empty.
    """
    n_total = len(y)
    n_left = np.sum(left_mask)
    n_right = np.sum(~left_mask)
    if n_total == 0 or n_left == 0 or n_right == 0:
        return 0.0

    y_left, y_right = y[left_mask], y[~left_mask]

    # The cheap dtype test runs first and covers every float width
    looks_like_regression = np.issubdtype(y.dtype, np.floating) or len(np.unique(y)) > 10

    if looks_like_regression:
        # A single-sample side contributes zero variance
        total_var = np.var(y) if n_total > 1 else 0
        left_var = np.var(y_left) if n_left > 1 else 0
        right_var = np.var(y_right) if n_right > 1 else 0
        weighted_var = (n_left * left_var + n_right * right_var) / n_total
        return total_var - weighted_var

    total_gini = _gini_impurity(y)
    weighted_gini = (
        n_left * _gini_impurity(y_left) + n_right * _gini_impurity(y_right)
    ) / n_total
    return total_gini - weighted_gini
def _gini_impurity(y: np.ndarray) -> float:
    """Calculate Gini impurity.

    Returns 0.0 for an empty array, otherwise one minus the sum of squared
    class frequencies.
    """
    if len(y) == 0:
        return 0.0

    _, counts = np.unique(y, return_counts=True)
    probabilities = counts / len(y)
    return 1.0 - np.sum(probabilities**2)


def _select_by_median_score(candidates: List[SplitCandidate]) -> Optional[SplitCandidate]:
    """Select candidate with best median validation score.

    Each candidate carries a single validation score, so this returns the
    first candidate with the highest score. Candidates without a score are
    ignored; None is returned when nothing is scored.
    """
    scored = [c for c in candidates if c.validation_score is not None]
    if not scored:
        return None
    return max(scored, key=lambda c: c.validation_score)


def _select_by_one_se_rule(candidates: List[SplitCandidate]) -> Optional[SplitCandidate]:
    """Select using one-standard-error rule.

    Candidates scoring within one spread of the best score are viable, where
    the spread is the standard deviation of all candidate scores. Among the
    viable ones the simplest wins: axis-aligned before oblique, then the
    lower feature index. Remaining ties go to the higher validation score,
    so the result does not depend on the input order.
    """
    scored = [c for c in candidates if c.validation_score is not None]
    if not scored:
        return None

    scores = [c.validation_score for c in scored]
    score_std = np.std(scores) if len(scores) > 1 else 0
    cutoff = max(scores) - score_std

    # The best-scoring candidate always passes, so ``viable`` is never empty
    viable = [c for c in scored if c.validation_score >= cutoff]
    return min(viable, key=lambda c: (c.is_oblique, c.feature_idx, -c.validation_score))