"""
BootstrapVariancePenalizedTree: a decision tree with an explicit bootstrap-variance penalty.

Built on BaseStableTree and incorporating techniques from:
- RobustPrefixHonestTree: stratified bootstraps, winsorization, threshold binning, robust consensus
- LessGreedyHybridTree: oblique splits, lookahead, beam search
"""
from typing import Literal, Optional
from .base_stable_tree import BaseStableTree
class BootstrapVariancePenalizedTree(BaseStableTree):
    """
    Bootstrap variance penalized tree with unified stability primitives.

    Enhanced with cross-method learning:

    - Stratified bootstraps (from RobustPrefix)
    - Winsorization (from RobustPrefix)
    - Threshold binning/bucketing (from RobustPrefix)
    - Robust consensus mechanism (from RobustPrefix)
    - Oblique splits (from LessGreedy)
    - Lookahead (from LessGreedy)
    - Beam search (from LessGreedy)

    Core features:

    - Explicit bootstrap variance penalty during split selection
    - Honest data partitioning for unbiased estimation
    - Advanced split strategies with variance awareness

    This class is a configuration layer over ``BaseStableTree``: ``__init__``
    forwards its parameters (some under different names) to the base class
    and always enables honest estimation, validation checking, bootstrap
    variance tracking and the explicit variance penalty.
    """

    # Constructor parameters of this class that are forwarded to
    # ``BaseStableTree.__init__`` under other (or additional) names.
    # ``set_params`` uses this table to keep those base attributes in sync.
    _BASE_PARAM_ALIASES = {
        "variance_penalty": ("variance_stopping_weight", "variance_penalty_weight"),
        "n_bootstrap": ("variance_tracking_samples",),
        "enable_robust_consensus": ("enable_prefix_consensus",),
        "enable_threshold_binning": ("enable_quantile_grid_thresholds",),
        "enable_gain_margin_logic": ("enable_margin_vetoes",),
    }

    # Assumed number of candidate splits scored per internal node. Only used
    # for the heuristic ``bootstrap_evaluations_`` estimate; never measured.
    _CANDIDATES_PER_NODE_ESTIMATE = 10

    def __init__(
        self,
        task: Literal["regression", "classification"] = "regression",
        # === CORE TREE PARAMETERS ===
        max_depth: int = 5,
        min_samples_split: int = 40,
        min_samples_leaf: int = 20,
        # === BOOTSTRAP VARIANCE PENALTY ===
        variance_penalty: float = 1.0,  # Signature feature
        n_bootstrap: int = 10,
        bootstrap_max_depth: int = 2,  # Stored only; not forwarded to the base
        enable_variance_aware_stopping: bool = True,  # Signature feature
        # === HONEST PARTITIONING ===
        split_frac: float = 0.6,
        val_frac: float = 0.2,
        est_frac: float = 0.2,
        enable_stratified_sampling: bool = True,  # From RobustPrefix
        # === STRATIFIED BOOTSTRAPS (from RobustPrefix) ===
        enable_stratified_bootstraps: bool = True,  # Stored only; not forwarded
        bootstrap_stratification_bins: int = 5,  # Stored only; not forwarded
        # === WINSORIZATION (from RobustPrefix) ===
        enable_winsorization: bool = True,
        winsor_quantiles: tuple = (0.01, 0.99),
        # === THRESHOLD BINNING (from RobustPrefix) ===
        enable_threshold_binning: bool = True,  # Also drives quantile-grid thresholds
        max_threshold_bins: int = 24,
        # === ROBUST CONSENSUS (from RobustPrefix) ===
        enable_robust_consensus: bool = True,  # Forwarded as enable_prefix_consensus
        consensus_samples: int = 12,
        consensus_threshold: float = 0.5,
        # === OBLIQUE SPLITS (from LessGreedy) ===
        enable_oblique_splits: bool = True,
        oblique_strategy: Literal["root_only", "all_levels", "adaptive"] = "adaptive",
        oblique_regularization: Literal["lasso", "ridge", "elastic_net"] = "lasso",
        enable_correlation_gating: bool = True,
        min_correlation_threshold: float = 0.3,
        # === LOOKAHEAD (from LessGreedy) ===
        enable_lookahead: bool = True,
        lookahead_depth: int = 1,  # Conservative for the variance method
        beam_width: int = 8,  # Smaller beam for efficiency
        enable_ambiguity_gating: bool = True,
        ambiguity_threshold: float = 0.1,  # More conservative threshold
        min_samples_for_lookahead: int = 100,
        # === LEAF STABILIZATION ===
        leaf_smoothing: float = 0.0,  # Conservative default
        leaf_smoothing_strategy: Literal["m_estimate", "shrink_to_parent"] = "m_estimate",
        # === MARGIN-BASED LOGIC ===
        enable_gain_margin_logic: bool = True,  # Forwarded as enable_margin_vetoes
        margin_threshold: float = 0.03,
        # === CLASSIFICATION ===
        classification_criterion: Literal["gini", "entropy"] = "gini",
        random_state: Optional[int] = None,
    ):
        """Configure the tree and forward its settings to ``BaseStableTree``.

        Always enabled regardless of arguments: honest estimation, validation
        checking (with ``validation_metric="variance_penalized"``), bootstrap
        variance tracking and the explicit variance penalty.
        ``algorithm_focus`` is fixed to ``"stability"``.

        Parameters forwarded under a different base-class name:

        - ``variance_penalty`` -> ``variance_penalty_weight`` and
          ``variance_stopping_weight``
        - ``n_bootstrap`` -> ``variance_tracking_samples``
        - ``enable_robust_consensus`` -> ``enable_prefix_consensus``
        - ``enable_threshold_binning`` -> itself and
          ``enable_quantile_grid_thresholds``
        - ``enable_gain_margin_logic`` -> ``enable_margin_vetoes``

        ``bootstrap_max_depth``, ``enable_stratified_bootstraps`` and
        ``bootstrap_stratification_bins`` are not passed to
        ``BaseStableTree.__init__``; they are only stored on ``self``.
        NOTE(review): confirm whether BaseStableTree reads them from there.

        All other parameters are forwarded unchanged under the same name.
        """
        # Configure defaults that reflect the Bootstrap method's personality
        super().__init__(
            task=task,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            # Honest partitioning - core feature
            enable_honest_estimation=True,
            split_frac=split_frac,
            val_frac=val_frac,
            est_frac=est_frac,
            enable_stratified_sampling=enable_stratified_sampling,
            # Validation checking - always enabled
            enable_validation_checking=True,
            validation_metric="variance_penalized",  # Signature approach
            # Winsorization (from RobustPrefix)
            enable_winsorization=enable_winsorization,
            winsor_quantiles=winsor_quantiles,
            # Threshold binning (from RobustPrefix)
            enable_threshold_binning=enable_threshold_binning,
            max_threshold_bins=max_threshold_bins,
            # Robust consensus (from RobustPrefix)
            enable_prefix_consensus=enable_robust_consensus,
            consensus_samples=consensus_samples,
            consensus_threshold=consensus_threshold,
            enable_quantile_grid_thresholds=enable_threshold_binning,
            # Oblique splits (from LessGreedy)
            enable_oblique_splits=enable_oblique_splits,
            oblique_strategy=oblique_strategy,
            oblique_regularization=oblique_regularization,
            enable_correlation_gating=enable_correlation_gating,
            min_correlation_threshold=min_correlation_threshold,
            # Lookahead (from LessGreedy)
            enable_lookahead=enable_lookahead,
            lookahead_depth=lookahead_depth,
            beam_width=beam_width,
            enable_ambiguity_gating=enable_ambiguity_gating,
            ambiguity_threshold=ambiguity_threshold,
            min_samples_for_lookahead=min_samples_for_lookahead,
            # Variance awareness - signature feature
            enable_variance_aware_stopping=enable_variance_aware_stopping,
            variance_stopping_weight=variance_penalty,
            enable_bootstrap_variance_tracking=True,
            variance_tracking_samples=n_bootstrap,
            enable_explicit_variance_penalty=True,  # Core feature
            variance_penalty_weight=variance_penalty,
            # Margin logic
            enable_margin_vetoes=enable_gain_margin_logic,
            margin_threshold=margin_threshold,
            # Leaf stabilization
            leaf_smoothing=leaf_smoothing,
            leaf_smoothing_strategy=leaf_smoothing_strategy,
            # Classification
            classification_criterion=classification_criterion,
            # Focus on maximum stability
            algorithm_focus="stability",
            random_state=random_state,
        )
        # Store every constructor argument whose name BaseStableTree does not
        # receive. Signature-based introspection (scikit-learn style
        # get_params/clone) reads each __init__ parameter back from ``self``
        # by name.
        self.variance_penalty = variance_penalty
        self.n_bootstrap = n_bootstrap
        self.bootstrap_max_depth = bootstrap_max_depth
        self.enable_stratified_bootstraps = enable_stratified_bootstraps
        self.bootstrap_stratification_bins = bootstrap_stratification_bins
        self.enable_robust_consensus = enable_robust_consensus
        # Forwarded as ``enable_margin_vetoes``, so it must also be kept under
        # its own name for the same reason as ``enable_robust_consensus``.
        self.enable_gain_margin_logic = enable_gain_margin_logic
        # NOTE(review): scikit-learn convention is to create trailing-underscore
        # (fitted) attributes only in fit(); kept here for backwards compatibility.
        self.bootstrap_evaluations_ = 0

    def fit(self, X, y):
        """Fit via ``BaseStableTree.fit`` and then record ``bootstrap_evaluations_``.

        ``bootstrap_evaluations_`` is a heuristic estimate, computed as
        internal nodes x ``n_bootstrap`` x ``_CANDIDATES_PER_NODE_ESTIMATE``.
        It is not a measured count. It is set to 0 when
        ``enable_explicit_variance_penalty`` is false.

        Returns:
            Whatever ``BaseStableTree.fit`` returns.
        """
        result = super().fit(X, y)
        if self.enable_explicit_variance_penalty:
            self.bootstrap_evaluations_ = self._estimate_bootstrap_evaluations()
        else:
            self.bootstrap_evaluations_ = 0
        return result

    def _estimate_bootstrap_evaluations(self):
        """Return a rough estimate of bootstrap evaluations done during training.

        Returns 0 when no tree has been built (``tree_`` missing or ``None``).
        """
        # getattr: tolerate a missing ``tree_`` instead of raising AttributeError.
        tree = getattr(self, "tree_", None)
        if tree is None:
            return 0
        internal_nodes = self._count_internal_nodes(tree)
        return internal_nodes * self.n_bootstrap * self._CANDIDATES_PER_NODE_ESTIMATE

    def _count_internal_nodes(self, node):
        """Count non-leaf nodes in a dict-based tree, recursively.

        A node whose ``"type"`` is ``"leaf"`` counts 0. Any other node counts 1
        plus the counts of its ``"left"``/``"right"`` children, where present.
        """
        if node["type"] == "leaf":
            return 0
        return 1 + sum(
            self._count_internal_nodes(node[side])
            for side in ("left", "right")
            if side in node
        )

    def get_params(self, deep=True):
        """Get parameters for sklearn compatibility (delegates to the base class)."""
        return super().get_params(deep=deep)

    def set_params(self, **params):
        """Set parameters (sklearn compatible) and re-sync renamed base attributes.

        Delegates to ``BaseStableTree.set_params``. Some parameters are handed
        to the base class under other names (see ``_BASE_PARAM_ALIASES``). If
        the base only assigns attributes, as scikit-learn's
        ``BaseEstimator.set_params`` does, those base copies would otherwise
        keep their construction-time values. Example: ``variance_penalty``
        changed but ``variance_penalty_weight`` not.

        Returns:
            Whatever ``BaseStableTree.set_params`` returns.
        """
        result = super().set_params(**params)
        # NOTE(review): assumes BaseStableTree keeps these settings as
        # attributes named like its __init__ keywords; confirm in base class.
        for name, base_names in self._BASE_PARAM_ALIASES.items():
            if name in params:
                for base_name in base_names:
                    setattr(self, base_name, params[name])
        return result
# Backwards-compatible public names.
# The tree class already defaults to task="regression", so the regressor name
# stays a plain alias of it.
BootstrapVariancePenalizedRegressor = BootstrapVariancePenalizedTree


def _with_default_task(name, default_task):
    """Build a ``BootstrapVariancePenalizedTree`` subclass with a different ``task`` default.

    The generated ``__init__`` accepts exactly the parent's parameters,
    positionally or by keyword, and changes only the default of ``task``. It
    advertises that parameter list through ``__signature__``. Signature-based
    introspection (e.g. scikit-learn ``get_params`` / ``clone``) therefore
    still sees every constructor argument instead of ``*args, **kwargs``.

    Args:
        name: ``__name__``/``__qualname__`` of the new class. It should match
            the module-level name the class is bound to, so instances stay
            picklable.
        default_task: New default value for the ``task`` parameter.

    Returns:
        The new subclass.
    """
    import inspect

    parent = BootstrapVariancePenalizedTree
    parent_signature = inspect.signature(parent.__init__)
    init_signature = parent_signature.replace(
        parameters=[
            param.replace(default=default_task) if param.name == "task" else param
            for param in parent_signature.parameters.values()
        ]
    )

    def __init__(self, *args, **kwargs):
        # Bind against the advertised signature so arguments are validated and
        # defaulted exactly as for the parent, apart from ``task``.
        bound = init_signature.bind(self, *args, **kwargs)
        bound.apply_defaults()
        parent.__init__(*bound.args, **bound.kwargs)

    __init__.__signature__ = init_signature
    __init__.__qualname__ = f"{name}.__init__"

    namespace = {
        "__init__": __init__,
        "__module__": __name__,
        "__qualname__": name,
        "__doc__": (
            f"{parent.__name__} whose ``task`` parameter defaults to "
            f"{default_task!r}; all other parameters are identical."
        ),
    }
    # Instantiate through the parent's metaclass in case BaseStableTree has one.
    return type(parent)(name, (parent,), namespace)


# This used to be a bare alias, so BootstrapVariancePenalizedClassifier()
# silently built a *regression* tree unless task="classification" was passed.
# Passing ``task`` explicitly still works as before.
BootstrapVariancePenalizedClassifier = _with_default_task(
    "BootstrapVariancePenalizedClassifier", "classification"
)