# Source code for stable_cart.unified_robust_prefix_tree

"""
RobustPrefixHonestTree: Enhanced with cross-method learning.

Now inherits from BaseStableTree and incorporates lessons from:
- LessGreedyHybridTree: Oblique splits, lookahead, ambiguity/correlation gating
- BootstrapVariancePenalizedTree: Explicit variance tracking
"""

from typing import Any, Literal

import numpy as np

from .base_stable_tree import BaseStableTree


class RobustPrefixHonestTree(BaseStableTree):
    """
    Honest tree with a consensus-locked prefix, configured on top of BaseStableTree.

    This class is a preset: its constructor translates a RobustPrefix-flavoured
    set of options into keyword arguments for ``BaseStableTree`` and pins a few
    switches that are not user-configurable here (honest estimation, validation
    checking with the ``"median"`` metric, prefix consensus, deterministic
    preprocessing/tie-breaks, calibrated smoothing, ``algorithm_focus="stability"``).

    Cross-method enhancements:

    - Oblique splits (from LessGreedy): Lasso-based oblique splits in the locked prefix
    - Lookahead with beam search (from LessGreedy): k-step search instead of depth-1 stumps
    - Ambiguity gating (from LessGreedy): apply consensus only when splits are ambiguous
    - Correlation gating (from LessGreedy): check feature correlation before oblique splits
    - Explicit variance tracking (from Bootstrap): prediction variance as a diagnostic

    Core features:

    - Robust consensus-based prefix splits with honest leaf estimation
    - Winsorization for outlier robustness
    - Stratified honest data partitioning
    - Consensus mechanisms with threshold binning

    Parameters
    ----------
    task
        Prediction task type.
    max_depth
        Maximum tree depth.
    min_samples_leaf
        Minimum samples per leaf. ``min_samples_split`` is passed to the base
        class as twice this value.
    top_levels
        Number of prefix levels to lock using robust consensus (forwarded as
        ``prefix_levels``).
    consensus_samples
        Number of bootstrap samples for consensus.
    consensus_threshold
        Threshold for consensus decisions.
    consensus_subsample_frac
        Subsample fraction per bootstrap. Stored on the instance only; it is
        not forwarded to the base constructor.
    val_frac
        Fraction of data for validation.
    est_frac
        Fraction of data for estimation. The split fraction handed to the base
        class is ``1.0 - val_frac - est_frac``.
    enable_stratified_sampling
        Enable stratified sampling in data partitioning.
    enable_winsorization
        Enable feature winsorization.
    winsor_quantiles
        Quantile bounds for winsorization.
    enable_threshold_binning
        Enable threshold binning to reduce micro-jitter. The same flag is also
        forwarded as ``enable_quantile_grid_thresholds``.
    max_threshold_bins
        Maximum number of threshold bins.
    enable_oblique_splits
        Enable oblique split capability.
    oblique_strategy
        Strategy for oblique splits.
    oblique_regularization
        Regularization type for oblique splits.
    enable_correlation_gating
        Enable correlation-based feature gating.
    min_correlation_threshold
        Minimum correlation for feature selection.
    enable_lookahead
        Enable lookahead search.
    lookahead_depth
        Depth for lookahead search.
    beam_width
        Width of beam search.
    enable_beam_search_for_consensus
        Enable beam search for consensus.
    enable_ambiguity_gating
        Enable ambiguity-based gating.
    ambiguity_threshold
        Threshold for ambiguity detection.
    enable_gain_margin_logic
        Enable gain margin logic. The same flag is also forwarded as
        ``enable_margin_vetoes``.
    margin_threshold
        Threshold for margin-based decisions.
    enable_bootstrap_variance_tracking
        Enable bootstrap variance tracking.
    variance_tracking_samples
        Number of samples for variance tracking.
    enable_explicit_variance_penalty
        Enable explicit variance penalty.
    variance_penalty_weight
        Weight for variance penalty.
    smoothing
        Smoothing parameter for leaf estimates (forwarded as ``leaf_smoothing``).
    leaf_smoothing_strategy
        Strategy for leaf smoothing.
    classification_criterion
        Criterion for classification splits.
    random_state
        Random state for reproducibility.
    """

    def __init__(
        self,
        task: Literal["regression", "classification"] = "regression",
        # --- core tree shape ---
        max_depth: int = 6,
        min_samples_leaf: int = 2,
        # --- robust prefix consensus ---
        top_levels: int = 2,
        consensus_samples: int = 12,
        consensus_threshold: float = 0.5,
        consensus_subsample_frac: float = 0.8,
        # --- honest partitioning ---
        val_frac: float = 0.2,
        est_frac: float = 0.4,
        enable_stratified_sampling: bool = True,
        # --- outlier robustness ---
        enable_winsorization: bool = True,
        winsor_quantiles: tuple[float, float] = (0.01, 0.99),
        # --- threshold discretization ---
        enable_threshold_binning: bool = True,
        max_threshold_bins: int = 24,
        # --- oblique splits (from LessGreedy) ---
        enable_oblique_splits: bool = True,
        oblique_strategy: Literal["root_only", "all_levels", "adaptive"] = "root_only",
        oblique_regularization: Literal["lasso", "ridge", "elastic_net"] = "lasso",
        enable_correlation_gating: bool = True,
        min_correlation_threshold: float = 0.3,
        # --- lookahead with beam search (from LessGreedy) ---
        enable_lookahead: bool = True,
        lookahead_depth: int = 2,
        beam_width: int = 12,
        enable_beam_search_for_consensus: bool = True,
        # --- ambiguity gating (from LessGreedy) ---
        enable_ambiguity_gating: bool = True,
        ambiguity_threshold: float = 0.05,
        enable_gain_margin_logic: bool = True,
        margin_threshold: float = 0.03,
        # --- variance tracking (from Bootstrap) ---
        enable_bootstrap_variance_tracking: bool = True,
        variance_tracking_samples: int = 10,
        enable_explicit_variance_penalty: bool = False,
        variance_penalty_weight: float = 0.1,
        # --- leaf stabilization ---
        smoothing: float = 1.0,
        leaf_smoothing_strategy: Literal["m_estimate", "shrink_to_parent"] = "m_estimate",
        # --- classification ---
        classification_criterion: Literal["gini", "entropy"] = "gini",
        random_state: int | None = None,
    ):
        # Values the base class needs but this preset derives rather than exposes.
        structure_frac = 1.0 - val_frac - est_frac  # what is left after val/est
        derived_min_split = min_samples_leaf * 2

        super().__init__(
            task=task,
            max_depth=max_depth,
            min_samples_split=derived_min_split,
            min_samples_leaf=min_samples_leaf,
            # Honest partitioning is always on for this preset.
            enable_honest_estimation=True,
            split_frac=structure_frac,
            val_frac=val_frac,
            est_frac=est_frac,
            enable_stratified_sampling=enable_stratified_sampling,
            # Validation checking is always on, using the median metric.
            enable_validation_checking=True,
            validation_metric="median",
            # Prefix consensus is always on; `top_levels` maps to `prefix_levels`.
            enable_prefix_consensus=True,
            prefix_levels=top_levels,
            consensus_samples=consensus_samples,
            consensus_threshold=consensus_threshold,
            # Outlier robustness.
            enable_winsorization=enable_winsorization,
            winsor_quantiles=winsor_quantiles,
            # One user flag drives both binning switches on the base class.
            enable_threshold_binning=enable_threshold_binning,
            enable_quantile_grid_thresholds=enable_threshold_binning,
            max_threshold_bins=max_threshold_bins,
            # Oblique splits.
            enable_oblique_splits=enable_oblique_splits,
            oblique_strategy=oblique_strategy,
            oblique_regularization=oblique_regularization,
            enable_correlation_gating=enable_correlation_gating,
            min_correlation_threshold=min_correlation_threshold,
            # Lookahead / beam search.
            enable_lookahead=enable_lookahead,
            lookahead_depth=lookahead_depth,
            beam_width=beam_width,
            enable_beam_search_for_consensus=enable_beam_search_for_consensus,
            # Ambiguity gating; the gain-margin flag also enables margin vetoes.
            enable_ambiguity_gating=enable_ambiguity_gating,
            ambiguity_threshold=ambiguity_threshold,
            enable_gain_margin_logic=enable_gain_margin_logic,
            enable_margin_vetoes=enable_gain_margin_logic,
            margin_threshold=margin_threshold,
            # Variance tracking.
            enable_bootstrap_variance_tracking=enable_bootstrap_variance_tracking,
            variance_tracking_samples=variance_tracking_samples,
            enable_explicit_variance_penalty=enable_explicit_variance_penalty,
            variance_penalty_weight=variance_penalty_weight,
            # Deterministic processing is always on.
            enable_deterministic_preprocessing=True,
            enable_deterministic_tiebreaks=True,
            # Leaf stabilization; `smoothing` maps to `leaf_smoothing`.
            leaf_smoothing=smoothing,
            leaf_smoothing_strategy=leaf_smoothing_strategy,
            enable_calibrated_smoothing=True,
            classification_criterion=classification_criterion,
            algorithm_focus="stability",
            random_state=random_state,
        )

        # RobustPrefix-specific names kept on the instance after the base
        # constructor has run.
        # NOTE(review): whether the base class also stores every constructor
        # argument under its own name (e.g. `smoothing`) is not visible here.
        self.top_levels = top_levels
        self.consensus_B = consensus_samples
        self.consensus_subsample_frac = consensus_subsample_frac
        self.consensus_max_bins = max_threshold_bins

        # Cross-method enhancement flags.
        self.enable_beam_search_for_consensus = enable_beam_search_for_consensus
        self.enable_bootstrap_variance_tracking = enable_bootstrap_variance_tracking
        self.enable_explicit_variance_penalty = enable_explicit_variance_penalty

    def fit(self, X: np.ndarray, y: np.ndarray) -> "RobustPrefixHonestTree":
        """
        Fit the tree, rejecting multi-class targets up front.

        Parameters
        ----------
        X
            Training features.
        y
            Training targets.

        Returns
        -------
        RobustPrefixHonestTree
            This estimator, after the base-class fit has run.

        Raises
        ------
        ValueError
            If ``task`` is ``"classification"`` and ``y`` holds more than two
            distinct values.
        """
        # Only classification is restricted; the label scan is skipped otherwise.
        wants_classification = self.task == "classification"
        if wants_classification and np.unique(y).size > 2:
            raise ValueError(
                "Multi-class classification not yet supported. "
                "RobustPrefixHonestTree currently supports binary classification only."
            )

        super().fit(X, y)
        return self

    def get_params(self, deep: bool = True) -> dict[str, Any]:
        """
        Return the estimator parameters (delegates to the base class).

        Parameters
        ----------
        deep
            Passed through unchanged to the base-class ``get_params``.

        Returns
        -------
        dict[str, Any]
            Whatever the base-class ``get_params`` returns.
        """
        current = super().get_params(deep=deep)
        return current

    def set_params(self, **params: Any) -> "RobustPrefixHonestTree":
        """
        Update estimator parameters (delegates to the base class).

        Parameters
        ----------
        **params
            Parameter values, passed through unchanged to the base-class
            ``set_params``.

        Returns
        -------
        RobustPrefixHonestTree
            Whatever the base-class ``set_params`` returns.
        """
        updated = super().set_params(**params)
        return updated