Source code for stable_cart.base_stable_tree

"""
BaseStableTree: Unified implementation of all stability primitives.

This base class implements the 7 core stability "atoms" that can be composed
across different tree methods. Each method can inherit from this class and
configure different defaults to maintain its distinct personality.
"""

from typing import Any, Literal

import numpy as np
from numpy.typing import NDArray
from sklearn.base import BaseEstimator  # type: ignore[import-untyped]
from sklearn.metrics import accuracy_score, r2_score  # type: ignore[import-untyped]
from sklearn.utils.validation import (  # type: ignore[import-untyped]
    check_array,
    check_X_y,
)

from .split_strategies import HybridStrategy, SplitStrategy, create_split_strategy
from .stability_utils import (
    honest_data_partition,
    stabilize_leaf_estimate,
    winsorize_features,
)


class BaseStableTree(BaseEstimator):
    """
    Unified base class implementing all 7 stability primitives.

    The 7 stability primitives are:

    1. Prefix stability (robust consensus on early splits)
    2. Validation-checked split selection
    3. Honesty (separate data for structure vs estimation)
    4. Leaf stabilization (shrinkage/smoothing)
    5. Data regularization (winsorization, etc.)
    6. Candidate diversity with deterministic resolution
    7. Variance-aware stopping

    All tree methods inherit from this class and configure different defaults to
    maintain their distinct personalities while sharing the unified stability
    infrastructure.

    Parameters
    ----------
    task
        The prediction task type.
    max_depth
        Maximum tree depth.
    min_samples_split
        Minimum samples required to split an internal node.
    min_samples_leaf
        Minimum samples required in a leaf node.
    enable_honest_estimation
        Enable honest estimation (separate data for structure vs estimation).
    split_frac
        Fraction of data used for building tree structure.
    val_frac
        Fraction of data used for validation.
    est_frac
        Fraction of data used for estimation.
    enable_stratified_sampling
        Use stratified sampling for data partitioning.
    enable_validation_checking
        Enable validation-checked split selection.
    validation_metric
        Metric for validation-based split selection.
    validation_consistency_weight
        Weight for validation consistency in split selection.
    enable_prefix_consensus
        Enable prefix stability through consensus on early splits.
    prefix_levels
        Number of tree levels to apply prefix consensus.
    consensus_samples
        Number of bootstrap samples for consensus building.
    consensus_threshold
        Minimum agreement threshold for consensus splits.
    enable_quantile_grid_thresholds
        Use quantile-based threshold grids.
    max_threshold_bins
        Maximum number of threshold bins per feature.
    leaf_smoothing
        Smoothing parameter for leaf value stabilization.
    leaf_smoothing_strategy
        Strategy for leaf value stabilization.
    enable_calibrated_smoothing
        Use calibrated smoothing based on sample size.
    min_leaf_samples_for_stability
        Minimum samples required for stable leaf estimation.
    enable_winsorization
        Enable feature winsorization for robustness.
    winsor_quantiles
        Quantiles for winsorization bounds.
    enable_feature_standardization
        Standardize features before splitting.
    enable_oblique_splits
        Enable oblique (linear combination) splits.
    oblique_strategy
        Where to apply oblique splits in the tree.
    oblique_regularization
        Regularization for oblique split learning.
    enable_correlation_gating
        Gate splits based on feature correlations.
    min_correlation_threshold
        Minimum correlation for correlation gating.
    enable_lookahead
        Enable lookahead for better split selection.
    lookahead_depth
        Depth of lookahead search.
    beam_width
        Beam width for lookahead search.
    enable_ambiguity_gating
        Gate splits in ambiguous regions.
    ambiguity_threshold
        Threshold for ambiguity detection.
    min_samples_for_lookahead
        Minimum samples required for lookahead.
    enable_deterministic_preprocessing
        Use deterministic preprocessing for reproducibility.
    enable_deterministic_tiebreaks
        Use deterministic tiebreaking in split selection.
    enable_margin_vetoes
        Enable margin-based split vetoing.
    margin_threshold
        Threshold for margin-based vetoing.
    enable_variance_aware_stopping
        Enable variance-aware stopping criteria.
    variance_stopping_weight
        Weight for variance in stopping decisions.
    variance_stopping_strategy
        Strategy for variance-aware stopping.
    enable_bootstrap_variance_tracking
        Track split variance using bootstrap sampling.
    variance_tracking_samples
        Number of bootstrap samples for variance tracking.
    enable_explicit_variance_penalty
        Apply explicit variance penalty to splits.
    variance_penalty_weight
        Weight for variance penalty.
    split_strategy
        Explicit split strategy specification.
    algorithm_focus
        Algorithm focus for automatic strategy selection.
    classification_criterion
        Splitting criterion for classification.
    random_state
        Random state for reproducibility.
    enable_threshold_binning
        Enable threshold binning for continuous features.
    enable_gain_margin_logic
        Apply margin logic to information gain.
    enable_beam_search_for_consensus
        Use beam search for consensus building.
    enable_robust_consensus_for_ambiguous
        Use robust consensus in ambiguous regions.

    Raises
    ------
    ValueError
        If split_frac + val_frac + est_frac does not sum to 1.0.
    """

    def __init__(
        self,
        # === TASK AND CORE PARAMETERS ===
        task: str = "regression",
        max_depth: int = 5,
        min_samples_split: int = 40,
        min_samples_leaf: int = 20,
        # === 3. HONESTY - Data Partitioning ===
        enable_honest_estimation: bool = True,
        split_frac: float = 0.6,
        val_frac: float = 0.2,
        est_frac: float = 0.2,
        enable_stratified_sampling: bool = True,
        # === 2. VALIDATION-CHECKED SPLIT SELECTION ===
        enable_validation_checking: bool = True,
        validation_metric: Literal[
            "median", "one_se", "variance_penalized"
        ] = "variance_penalized",
        validation_consistency_weight: float = 1.0,
        # === 1. PREFIX STABILITY ===
        enable_prefix_consensus: bool = False,
        prefix_levels: int = 2,
        consensus_samples: int = 12,
        consensus_threshold: float = 0.5,
        enable_quantile_grid_thresholds: bool = False,
        max_threshold_bins: int = 24,
        # === 4. LEAF STABILIZATION ===
        leaf_smoothing: float = 0.0,
        leaf_smoothing_strategy: Literal[
            "m_estimate", "shrink_to_parent", "beta_smoothing"
        ] = "m_estimate",
        enable_calibrated_smoothing: bool = False,
        min_leaf_samples_for_stability: int = 5,
        # === 5. DATA REGULARIZATION ===
        enable_winsorization: bool = False,
        winsor_quantiles: tuple[float, float] = (0.01, 0.99),
        enable_feature_standardization: bool = False,
        # === 6. CANDIDATE DIVERSITY ===
        enable_oblique_splits: bool = False,
        oblique_strategy: Literal["root_only", "all_levels", "adaptive"] = "root_only",
        oblique_regularization: Literal["lasso", "ridge", "elastic_net"] = "lasso",
        enable_correlation_gating: bool = True,
        min_correlation_threshold: float = 0.3,
        enable_lookahead: bool = False,
        lookahead_depth: int = 1,
        beam_width: int = 8,
        enable_ambiguity_gating: bool = True,
        ambiguity_threshold: float = 0.05,
        min_samples_for_lookahead: int = 100,
        enable_deterministic_preprocessing: bool = False,
        enable_deterministic_tiebreaks: bool = True,
        enable_margin_vetoes: bool = False,
        margin_threshold: float = 0.03,
        # === 7. VARIANCE-AWARE STOPPING ===
        enable_variance_aware_stopping: bool = False,
        variance_stopping_weight: float = 1.0,
        variance_stopping_strategy: Literal[
            "one_se", "variance_penalty", "both"
        ] = "variance_penalty",
        enable_bootstrap_variance_tracking: bool = False,
        variance_tracking_samples: int = 10,
        enable_explicit_variance_penalty: bool = False,
        variance_penalty_weight: float = 0.1,
        # === ADVANCED CONFIGURATION ===
        split_strategy: str | None = None,
        algorithm_focus: Literal["speed", "stability", "accuracy"] = "stability",
        # === CLASSIFICATION ===
        classification_criterion: Literal["gini", "entropy"] = "gini",
        # === OTHER ===
        random_state: int | None = None,
        # === ADDITIONAL PARAMETERS FOR CROSS-METHOD LEARNING ===
        enable_threshold_binning: bool = False,
        enable_gain_margin_logic: bool = False,
        enable_beam_search_for_consensus: bool = False,
        enable_robust_consensus_for_ambiguous: bool = False,
    ):
        # Validate fractions sum to 1
        if abs(split_frac + val_frac + est_frac - 1.0) > 1e-6:
            raise ValueError("split_frac + val_frac + est_frac must sum to 1.0")

        # === CORE PARAMETERS ===
        self.task = task
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        # === 3. HONESTY ===
        self.enable_honest_estimation = enable_honest_estimation
        self.split_frac = split_frac
        self.val_frac = val_frac
        self.est_frac = est_frac
        self.enable_stratified_sampling = enable_stratified_sampling
        # === 2. VALIDATION ===
        self.enable_validation_checking = enable_validation_checking
        self.validation_metric = validation_metric
        self.validation_consistency_weight = validation_consistency_weight
        # === 1. PREFIX STABILITY ===
        self.enable_prefix_consensus = enable_prefix_consensus
        self.prefix_levels = prefix_levels
        self.consensus_samples = consensus_samples
        self.consensus_threshold = consensus_threshold
        self.enable_quantile_grid_thresholds = enable_quantile_grid_thresholds
        self.max_threshold_bins = max_threshold_bins
        # === 4. LEAF STABILIZATION ===
        self.leaf_smoothing = leaf_smoothing
        self.leaf_smoothing_strategy: Literal[
            "m_estimate", "shrink_to_parent", "beta_smoothing"
        ] = leaf_smoothing_strategy
        self.enable_calibrated_smoothing = enable_calibrated_smoothing
        self.min_leaf_samples_for_stability = min_leaf_samples_for_stability
        # === 5. DATA REGULARIZATION ===
        self.enable_winsorization = enable_winsorization
        self.winsor_quantiles = winsor_quantiles
        self.enable_feature_standardization = enable_feature_standardization
        # === 6. CANDIDATE DIVERSITY ===
        self.enable_oblique_splits = enable_oblique_splits
        self.oblique_strategy = oblique_strategy
        self.oblique_regularization = oblique_regularization
        self.enable_correlation_gating = enable_correlation_gating
        self.min_correlation_threshold = min_correlation_threshold
        self.enable_lookahead = enable_lookahead
        self.lookahead_depth = lookahead_depth
        self.beam_width = beam_width
        self.enable_ambiguity_gating = enable_ambiguity_gating
        self.ambiguity_threshold = ambiguity_threshold
        self.min_samples_for_lookahead = min_samples_for_lookahead
        self.enable_deterministic_preprocessing = enable_deterministic_preprocessing
        self.enable_deterministic_tiebreaks = enable_deterministic_tiebreaks
        self.enable_margin_vetoes = enable_margin_vetoes
        self.margin_threshold = margin_threshold
        # === 7. VARIANCE-AWARE STOPPING ===
        self.enable_variance_aware_stopping = enable_variance_aware_stopping
        self.variance_stopping_weight = variance_stopping_weight
        self.variance_stopping_strategy = variance_stopping_strategy
        self.enable_bootstrap_variance_tracking = enable_bootstrap_variance_tracking
        self.variance_tracking_samples = variance_tracking_samples
        self.enable_explicit_variance_penalty = enable_explicit_variance_penalty
        self.variance_penalty_weight = variance_penalty_weight
        # === ADVANCED ===
        self.split_strategy = split_strategy
        self.algorithm_focus: Literal["speed", "stability", "accuracy"] = (
            algorithm_focus
        )
        # === CLASSIFICATION ===
        self.classification_criterion = classification_criterion
        # === OTHER ===
        self.random_state = random_state
        # === CROSS-METHOD LEARNING ===
        self.enable_threshold_binning = enable_threshold_binning
        self.enable_gain_margin_logic = enable_gain_margin_logic
        self.enable_beam_search_for_consensus = enable_beam_search_for_consensus
        self.enable_robust_consensus_for_ambiguous = (
            enable_robust_consensus_for_ambiguous
        )

        # Initialize fitted attributes with proper type annotations
        self.tree_: dict[str, Any] | None = None
        self.classes_: np.ndarray | None = None
        self.n_classes_: int | None = None
        self._split_strategy_: SplitStrategy | None = None
        self._winsor_bounds_: tuple[np.ndarray, np.ndarray] | None = None
        self._global_prior_: float | None = None
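
    # NOTE (illustrative, not part of the original source): because ``__init__``
    # stores each constructor argument under an attribute of the same name, the
    # ``BaseEstimator`` machinery (``get_params``/``set_params`` and
    # ``sklearn.base.clone``) should work unchanged. A minimal sketch:
    #
    #     tree = BaseStableTree(max_depth=3, leaf_smoothing=5.0)
    #     tree.get_params()["leaf_smoothing"]   # -> 5.0
    #     clone(tree)                           # unfitted copy with the same params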

    def fit(self, X: NDArray[np.floating], y: NDArray[Any]) -> "BaseStableTree":
        """
        Fit the stable tree to the training data.

        Parameters
        ----------
        X
            Training feature matrix of shape (n_samples, n_features).
        y
            Training target values of shape (n_samples,).

        Returns
        -------
        BaseStableTree
            Fitted estimator.

        Raises
        ------
        ValueError
            If multi-class classification is attempted (not yet supported).
        """
        # Validate inputs
        X, y = check_X_y(X, y, accept_sparse=False)

        # === 1. TASK SETUP ===
        if self.task == "classification":
            self.classes_ = np.unique(y)
            self.n_classes_ = len(self.classes_)
            if self.n_classes_ > 2:
                raise ValueError(
                    "Multi-class classification not yet supported. "
                    "Use binary classification or regression."
                )
            # Convert to 0/1 for binary classification
            y = (y == self.classes_[1]).astype(int)
            self._global_prior_ = float(np.mean(y))
        else:
            self.classes_ = None
            self.n_classes_ = None
            self._global_prior_ = float(np.mean(y)) if len(y) > 0 else 0.0

        # === 5. DATA REGULARIZATION ===
        X_processed = self._preprocess_features(X)

        # === 3. HONESTY - Data Partitioning ===
        data_splits = self._partition_data(X_processed, y)
        (X_split, y_split), (X_val, y_val), (X_est, y_est) = data_splits

        # === Configure Split Strategy ===
        self._split_strategy_ = self._create_split_strategy()

        # === Build Tree Structure ===
        self.tree_ = self._build_tree(
            X_split, y_split, X_val, y_val, X_est, y_est, depth=0
        )

        # Record timing and diagnostics
        return self
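
    # Illustrative sketch (comments only, assuming the default fractions): with
    # split_frac=0.6, val_frac=0.2, est_frac=0.2 and n_samples=1000, the
    # ``_partition_data`` call above is expected to hand roughly 600 rows to
    # structure search, 200 rows to validation-checked split selection, and 200
    # rows to honest leaf estimation. Exact sizes depend on
    # ``honest_data_partition`` in ``stability_utils`` (stratification, rounding).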

    def predict(self, X: NDArray[np.floating]) -> NDArray[Any]:
        """
        Predict targets for samples in X.

        Parameters
        ----------
        X
            Feature matrix of shape (n_samples, n_features).

        Returns
        -------
        NDArray[Any]
            Predicted values of shape (n_samples,).

        Raises
        ------
        ValueError
            If the tree has not been fitted.
        """
        # Validate and convert inputs
        X = check_array(X, accept_sparse=False)
        if self.tree_ is None:
            raise ValueError("Tree not fitted yet")

        # Apply same preprocessing as training
        X_processed = self._preprocess_features(X, fitted=True)

        predictions = np.array(
            [self._predict_sample(x, self.tree_) for x in X_processed]
        )

        if self.task == "classification":
            # Convert back to original class labels
            assert self.classes_ is not None, (
                "Classes must be defined for classification"
            )
            return np.where(predictions > 0.5, self.classes_[1], self.classes_[0])
        else:
            return predictions

    def predict_proba(self, X: NDArray[np.floating]) -> NDArray[np.floating]:
        """
        Predict class probabilities for classification tasks.

        Parameters
        ----------
        X
            Feature matrix of shape (n_samples, n_features).

        Returns
        -------
        NDArray[np.floating]
            Class probabilities of shape (n_samples, n_classes).

        Raises
        ------
        ValueError
            If called on a regression task or the tree is not fitted.
        """
        if self.task != "classification":
            raise ValueError(
                "predict_proba is only available for classification tasks"
            )

        # Validate and convert inputs
        X = check_array(X, accept_sparse=False)
        if self.tree_ is None:
            raise ValueError("Tree not fitted yet")

        # Apply same preprocessing as training
        X_processed = self._preprocess_features(X, fitted=True)

        # Get probability of positive class
        proba_positive = np.array(
            [self._predict_sample(x, self.tree_) for x in X_processed]
        )

        # Return as [P(class=0), P(class=1)]
        proba_negative = 1 - proba_positive
        return np.column_stack([proba_negative, proba_positive])
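
    # Note: for binary classification the columns returned above are ``[1 - p, p]``
    # where ``p`` is the stabilized leaf probability of ``classes_[1]``;
    # ``predict`` thresholds the same ``p`` at 0.5, so the two methods are
    # consistent by construction.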

    def score(self, X: NDArray[np.floating], y: NDArray[Any]) -> float:
        """
        Return the mean accuracy (classification) or R² (regression).

        Parameters
        ----------
        X
            Feature matrix for evaluation.
        y
            True target values.

        Returns
        -------
        float
            Accuracy for classification, R² for regression.
        """
        y_pred = self.predict(X)
        if self.task == "regression":
            return r2_score(y, y_pred)
        else:
            return accuracy_score(y, y_pred)

    # ========================================================================
    # INTERNAL METHODS - STABILITY PRIMITIVES
    # ========================================================================

    def _preprocess_features(
        self, X: NDArray[np.floating], fitted: bool = False
    ) -> NDArray[np.floating]:
        """
        Apply data regularization preprocessing.

        Parameters
        ----------
        X
            Feature matrix to preprocess.
        fitted
            Whether to use fitted preprocessing parameters.

        Returns
        -------
        NDArray[np.floating]
            Preprocessed feature matrix.
        """
        X_processed = X.copy()

        # === 5. DATA REGULARIZATION ===
        if self.enable_winsorization:
            if fitted and self._winsor_bounds_ is not None:
                X_processed, _ = winsorize_features(
                    X_processed, fitted_bounds=self._winsor_bounds_
                )
            else:
                X_processed, self._winsor_bounds_ = winsorize_features(
                    X_processed, self.winsor_quantiles
                )

        # Feature standardization (rarely needed for trees)
        if self.enable_feature_standardization:
            # Would implement standardization here
            pass

        return X_processed

    def _partition_data(
        self, X: NDArray[np.floating], y: NDArray[Any]
    ) -> tuple[
        tuple[NDArray[np.floating], NDArray[Any]],
        tuple[NDArray[np.floating], NDArray[Any]],
        tuple[NDArray[np.floating], NDArray[Any]],
    ]:
        """
        Partition data using honest splitting.

        Parameters
        ----------
        X
            Feature matrix to partition.
        y
            Target values to partition.

        Returns
        -------
        tuple
            Tuple of (split_data, val_data, est_data) where each is (X, y).
        """
        if not self.enable_honest_estimation:
            # Use all data for both structure and estimation
            return (X, y), (X, y), (X, y)

        return honest_data_partition(
            X,
            y,
            split_frac=self.split_frac,
            val_frac=self.val_frac,
            est_frac=self.est_frac,
            enable_stratification=self.enable_stratified_sampling,
            task=self.task,
            random_state=self.random_state,
        )

    def _create_split_strategy(self) -> SplitStrategy:
        """
        Create the split strategy based on enabled features.

        Returns
        -------
        SplitStrategy
            Configured split strategy instance.
        """
        if self.split_strategy is not None:
            # Explicit strategy specified
            return create_split_strategy(
                self.split_strategy,
                task=self.task,
                random_state=self.random_state,
                # Pass relevant parameters
                oblique_regularization=self.oblique_regularization,
                enable_correlation_gating=self.enable_correlation_gating,
                min_correlation=self.min_correlation_threshold,
                consensus_samples=self.consensus_samples,
                consensus_threshold=self.consensus_threshold,
                lookahead_depth=self.lookahead_depth,
                beam_width=self.beam_width,
                variance_penalty_weight=self.variance_penalty_weight,
            )
        else:
            # Auto-select based on enabled features and algorithm focus
            return HybridStrategy(
                focus=self.algorithm_focus,
                task=self.task,
                random_state=self.random_state,
            )
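
    # The fitted ``tree_`` produced by ``_build_tree`` below is a nested dict.
    # Schematically (field values are made up for illustration):
    #
    #     {
    #         "type": "split",            # or "split_oblique"
    #         "feature_idx": 2,
    #         "threshold": 0.73,
    #         "gain": 0.012,
    #         "depth": 0,
    #         "n_samples_split": 600, "n_samples_val": 200, "n_samples_est": 200,
    #         "oblique_weights": None,    # weight vector for oblique splits
    #         "consensus_support": None,  # populated by consensus strategies
    #         "variance_estimate": None,  # populated by variance tracking
    #         "left":  {"type": "leaf", "value": 1.3, "depth": 1, ...},
    #         "right": {...},             # another split or leaf node
    #     }
    #
    # Leaf nodes carry ``"value"`` (regression) or ``"proba"`` (classification)
    # instead of the split fields.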
""" n_samples = len(X_split) # Ensure split strategy is initialized assert self._split_strategy_ is not None, ( "Split strategy must be initialized before building tree" ) # Base stopping conditions if ( depth >= self.max_depth or n_samples < self.min_samples_split or len(np.unique(y_split)) <= 1 ): return self._make_leaf(y_est, y_split, depth) # Find best split using configured strategy best_split = self._split_strategy_.find_best_split( X_split, y_split, X_val if self.enable_validation_checking else None, y_val if self.enable_validation_checking else None, depth=depth, max_depth=self.max_depth, min_samples_split=self.min_samples_split, min_samples_leaf=self.min_samples_leaf, ) if best_split is None: return self._make_leaf(y_est, y_split, depth) # === 7. VARIANCE-AWARE STOPPING === if ( self.enable_variance_aware_stopping and best_split.variance_estimate is not None ): should_stop = self._split_strategy_.should_stop( X_split, y_split, best_split.gain, depth, variance_estimate=best_split.variance_estimate, max_depth=self.max_depth, min_samples_split=self.min_samples_split, ) if should_stop: return self._make_leaf(y_est, y_split, depth) # Apply split to all data partitions left_indices_split, right_indices_split = self._apply_split_to_data( X_split, best_split ) left_indices_val, right_indices_val = self._apply_split_to_data( X_val, best_split ) left_indices_est, right_indices_est = self._apply_split_to_data( X_est, best_split ) # Check minimum leaf size if ( len(left_indices_split) < self.min_samples_leaf or len(right_indices_split) < self.min_samples_leaf ): return self._make_leaf(y_est, y_split, depth) # Recursively build children left_child = self._build_tree( X_split[left_indices_split], y_split[left_indices_split], X_val[left_indices_val], y_val[left_indices_val], X_est[left_indices_est], y_est[left_indices_est], depth + 1, ) right_child = self._build_tree( X_split[right_indices_split], y_split[right_indices_split], X_val[right_indices_val], y_val[right_indices_val], X_est[right_indices_est], y_est[right_indices_est], depth + 1, ) # Create internal node return { "type": "split_oblique" if best_split.is_oblique else "split", "feature_idx": best_split.feature_idx, "threshold": best_split.threshold, "gain": best_split.gain, "depth": depth, "n_samples_split": len(X_split), "n_samples_val": len(X_val), "n_samples_est": len(X_est), "oblique_weights": best_split.oblique_weights if best_split.is_oblique else None, "consensus_support": getattr(best_split, "consensus_support", None), "variance_estimate": getattr(best_split, "variance_estimate", None), "left": left_child, "right": right_child, } def _apply_split_to_data( self, X: NDArray[np.floating], split_candidate: Any ) -> tuple[NDArray[np.int_], NDArray[np.int_]]: """ Apply a split to data and return left/right indices. Parameters ---------- X Feature array to split. split_candidate Split candidate containing split information. Returns ------- tuple[NDArray[np.int_], NDArray[np.int_]] Tuple of (left_indices, right_indices). 
""" if split_candidate.is_oblique and split_candidate.oblique_weights is not None: projections = X @ split_candidate.oblique_weights left_mask = projections <= split_candidate.threshold else: left_mask = X[:, split_candidate.feature_idx] <= split_candidate.threshold left_indices = np.where(left_mask)[0] right_indices = np.where(~left_mask)[0] return left_indices, right_indices def _make_leaf( self, y_est: NDArray[Any], y_split: NDArray[Any], depth: int ) -> dict[str, Any]: """ Create a leaf node with stabilized estimates. Parameters ---------- y_est Target values for estimation. y_split Target values from structure building. depth Current tree depth. Returns ------- dict[str, Any] Leaf node dictionary. """ # === 4. LEAF STABILIZATION === if len(y_est) == 0: y_est = y_split # Fallback to split data # Get parent data for shrinkage (use split data as proxy) stabilized_value = stabilize_leaf_estimate( y_est, y_split, strategy=self.leaf_smoothing_strategy, smoothing=self.leaf_smoothing, task=self.task, min_samples=self.min_leaf_samples_for_stability, ) if self.task == "regression": return { "type": "leaf", "value": stabilized_value, "depth": depth, "n_samples_split": len(y_split), "n_samples_est": len(y_est), } else: # For classification, stabilized_value is probability array or scalar if isinstance(stabilized_value, (float, int)): prob = stabilized_value else: # stabilized_value is an array of class probabilities if len(stabilized_value) >= 2: prob = stabilized_value[1] # P(class=1) for binary classification else: # Only one class present, assume class 0 prob = 0.0 return { "type": "leaf", "proba": float(prob), "depth": depth, "n_samples_split": len(y_split), "n_samples_est": len(y_est), } def _predict_sample(self, x: NDArray[np.floating], node: Any) -> float: """ Predict a single sample by traversing the tree. Parameters ---------- x Single sample feature vector. node Current tree node. Returns ------- float Predicted value or probability. """ if node["type"] == "leaf": if self.task == "regression": return node["value"] else: return node["proba"] # Apply split if node["type"] == "split_oblique" and node["oblique_weights"] is not None: projection = x @ node["oblique_weights"] go_left = projection <= node["threshold"] else: go_left = x[node["feature_idx"]] <= node["threshold"] # Recurse if go_left: return self._predict_sample(x, node["left"]) else: return self._predict_sample(x, node["right"])