Source code for stable_cart.unified_less_greedy_tree

"""
LessGreedyHybridTree: Enhanced with cross-method learning.

Now inherits from BaseStableTree and incorporates lessons from:
- RobustPrefixHonestTree: Winsorization, consensus for ambiguous splits, stratified sampling
- BootstrapVariancePenalizedTree: Explicit variance tracking
"""

from typing import Any, Literal

from .base_stable_tree import BaseStableTree


class LessGreedyHybridTree(BaseStableTree):
    """
    LessGreedyHybridTree with unified stability primitives.

    Enhanced with cross-method learning:

    - Winsorization (from RobustPrefix)
    - Bootstrap consensus for ambiguous splits (from RobustPrefix)
    - Stratified sampling (from RobustPrefix)
    - Explicit variance tracking (from Bootstrap)

    Core Features:

    - Honest data partitioning with lookahead beam search
    - Optional oblique root splits using regularized linear models
    - Leaf smoothing (shrinkage for regression, m-estimate for classification)
    - Advanced split selection with multiple strategies

    Parameters
    ----------
    task
        Prediction task type.
    max_depth
        Maximum tree depth.
    min_samples_split
        Minimum samples to split a node.
    min_samples_leaf
        Minimum samples per leaf.
    split_frac
        Fraction of data for structure building.
    val_frac
        Fraction of data for validation.
    est_frac
        Fraction of data for estimation.
    enable_stratified_sampling
        Enable stratified sampling in data partitioning.
    enable_oblique_splits
        Enable oblique split capability.
    oblique_strategy
        Strategy for oblique splits.
    oblique_regularization
        Regularization type for oblique splits.
    enable_correlation_gating
        Enable correlation-based feature gating.
    min_correlation_threshold
        Minimum correlation for feature selection.
    enable_lookahead
        Enable lookahead search.
    lookahead_depth
        Depth for lookahead search.
    beam_width
        Width of beam search.
    enable_ambiguity_gating
        Enable ambiguity-based gating.
    ambiguity_threshold
        Threshold for ambiguity detection.
    min_samples_for_lookahead
        Minimum samples required for lookahead.
    enable_robust_consensus_for_ambiguous
        Enable robust consensus for ambiguous splits. The same value is also
        forwarded to the base class as ``enable_prefix_consensus``.
    consensus_samples
        Number of samples for consensus.
    consensus_threshold
        Threshold for consensus decisions.
    enable_threshold_binning
        Enable threshold binning to reduce micro-jitter.
    max_threshold_bins
        Maximum number of threshold bins.
    enable_winsorization
        Enable feature winsorization.
    winsor_quantiles
        Quantile bounds for winsorization.
    enable_bootstrap_variance_tracking
        Enable bootstrap variance tracking.
    variance_tracking_samples
        Number of samples for variance tracking.
    enable_explicit_variance_penalty
        Enable explicit variance penalty.
    variance_penalty_weight
        Weight for variance penalty.
    leaf_smoothing
        Smoothing parameter for leaf estimates.
    leaf_smoothing_strategy
        Strategy for leaf smoothing.
    enable_gain_margin_logic
        Enable gain margin logic. Forwarded to the base class as
        ``enable_margin_vetoes``.
    margin_threshold
        Threshold for margin-based decisions.
    classification_criterion
        Criterion for classification splits.
    random_state
        Random state for reproducibility.

    Attributes
    ----------
    inner_k : int
        Fixed to ``1`` in the unified version.
    oblique_cv : int
        Fixed to ``5`` in the unified version.

    Notes
    -----
    The constructor always passes ``enable_honest_estimation=True``,
    ``enable_validation_checking=True``,
    ``validation_metric="variance_penalized"`` and
    ``algorithm_focus="accuracy"`` to the base class; these are not
    configurable through this class.
    """

    def __init__(
        self,
        task: Literal["regression", "classification"] = "regression",
        # === CORE TREE PARAMETERS ===
        max_depth: int = 5,
        min_samples_split: int = 40,
        min_samples_leaf: int = 20,
        # === HONEST PARTITIONING ===
        split_frac: float = 0.6,
        val_frac: float = 0.2,
        est_frac: float = 0.2,
        enable_stratified_sampling: bool = True,  # ENHANCED: from RobustPrefix
        # === OBLIQUE SPLITS ===
        enable_oblique_splits: bool = True,  # Signature feature
        oblique_strategy: Literal["root_only", "all_levels", "adaptive"] = "root_only",
        oblique_regularization: Literal["lasso", "ridge", "elastic_net"] = "lasso",
        enable_correlation_gating: bool = True,
        min_correlation_threshold: float = 0.3,
        # === LOOKAHEAD WITH BEAM SEARCH ===
        enable_lookahead: bool = True,  # Signature feature
        lookahead_depth: int = 2,
        beam_width: int = 12,
        enable_ambiguity_gating: bool = True,
        ambiguity_threshold: float = 0.05,
        min_samples_for_lookahead: int = 600,
        # === ENHANCED: CONSENSUS FOR AMBIGUOUS SPLITS (from RobustPrefix) ===
        enable_robust_consensus_for_ambiguous: bool = True,  # NEW
        consensus_samples: int = 12,
        consensus_threshold: float = 0.5,
        enable_threshold_binning: bool = True,  # NEW: reduce micro-jitter
        max_threshold_bins: int = 24,
        # === ENHANCED: OUTLIER ROBUSTNESS (from RobustPrefix) ===
        enable_winsorization: bool = True,  # NEW: robust preprocessing
        winsor_quantiles: tuple[float, float] = (0.01, 0.99),
        # === ENHANCED: VARIANCE TRACKING (from Bootstrap) ===
        enable_bootstrap_variance_tracking: bool = True,  # NEW: diagnostic
        variance_tracking_samples: int = 10,
        enable_explicit_variance_penalty: bool = False,  # Optional enhancement
        variance_penalty_weight: float = 0.1,
        # === LEAF STABILIZATION ===
        leaf_smoothing: float = 0.0,  # Conservative default for LessGreedy
        leaf_smoothing_strategy: Literal[
            "m_estimate", "shrink_to_parent"
        ] = "shrink_to_parent",
        # === MARGIN-BASED LOGIC ===
        enable_gain_margin_logic: bool = True,  # Signature feature
        margin_threshold: float = 0.03,
        # === CLASSIFICATION ===
        classification_criterion: Literal["gini", "entropy"] = "gini",
        random_state: int | None = None,
    ):
        # Configure defaults that reflect LessGreedy's personality
        super().__init__(
            task=task,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            # Honest partitioning - core feature
            enable_honest_estimation=True,
            split_frac=split_frac,
            val_frac=val_frac,
            est_frac=est_frac,
            enable_stratified_sampling=enable_stratified_sampling,
            # Validation checking - always enabled
            enable_validation_checking=True,
            validation_metric="variance_penalized",
            # ENHANCED: Outlier robustness (from RobustPrefix)
            enable_winsorization=enable_winsorization,
            winsor_quantiles=winsor_quantiles,
            # Oblique splits - signature feature
            enable_oblique_splits=enable_oblique_splits,
            oblique_strategy=oblique_strategy,
            oblique_regularization=oblique_regularization,
            enable_correlation_gating=enable_correlation_gating,
            min_correlation_threshold=min_correlation_threshold,
            # Lookahead - signature feature
            enable_lookahead=enable_lookahead,
            lookahead_depth=lookahead_depth,
            beam_width=beam_width,
            enable_ambiguity_gating=enable_ambiguity_gating,
            ambiguity_threshold=ambiguity_threshold,
            min_samples_for_lookahead=min_samples_for_lookahead,
            # ENHANCED: Consensus for ambiguous splits (from RobustPrefix)
            enable_robust_consensus_for_ambiguous=enable_robust_consensus_for_ambiguous,
            enable_prefix_consensus=enable_robust_consensus_for_ambiguous,  # For ambiguous splits
            consensus_samples=consensus_samples,
            consensus_threshold=consensus_threshold,
            enable_threshold_binning=enable_threshold_binning,
            max_threshold_bins=max_threshold_bins,
            # Margin logic - signature feature (base-class name differs)
            enable_margin_vetoes=enable_gain_margin_logic,
            margin_threshold=margin_threshold,
            # ENHANCED: Variance tracking (from Bootstrap)
            enable_bootstrap_variance_tracking=enable_bootstrap_variance_tracking,
            variance_tracking_samples=variance_tracking_samples,
            enable_explicit_variance_penalty=enable_explicit_variance_penalty,
            variance_penalty_weight=variance_penalty_weight,
            # Leaf stabilization - conservative for accuracy
            leaf_smoothing=leaf_smoothing,
            leaf_smoothing_strategy=leaf_smoothing_strategy,
            # Classification
            classification_criterion=classification_criterion,
            # Focus on balanced accuracy + stability
            algorithm_focus="accuracy",
            random_state=random_state,
        )

        # Store LessGreedy-specific parameters
        self.inner_k = 1  # Simplified for unified version
        self.oblique_cv = 5  # Fixed for unified version

        # Cross-method enhancement flags
        self.enable_robust_consensus_for_ambiguous = (
            enable_robust_consensus_for_ambiguous
        )
        self.enable_bootstrap_variance_tracking = enable_bootstrap_variance_tracking
        self.enable_explicit_variance_penalty = enable_explicit_variance_penalty

        # This argument reaches the base class only as ``enable_margin_vetoes``.
        # Keep it under its constructor name as well so that signature-based
        # parameter introspection (the sklearn get_params/clone convention)
        # can find and round-trip it.
        self.enable_gain_margin_logic = enable_gain_margin_logic

    def get_params(self, deep: bool = True) -> dict[str, Any]:
        """
        Get parameters for sklearn compatibility.

        Delegates unchanged to the parent class implementation.

        Parameters
        ----------
        deep
            Whether to return deep parameter copy.

        Returns
        -------
        dict[str, Any]
            Parameter dictionary.
        """
        # Use the parent method which gets constructor parameters
        return super().get_params(deep=deep)

    def set_params(self, **params: Any) -> "LessGreedyHybridTree":
        """
        Set parameters for sklearn compatibility.

        Delegates unchanged to the parent class implementation.

        Parameters
        ----------
        **params
            Parameter values to set.

        Returns
        -------
        LessGreedyHybridTree
            Self with updated parameters.
        """
        return super().set_params(**params)