"""
BaseStableTree: Unified implementation of all stability primitives.
This base class implements the 7 core stability "atoms" that can be composed
across different tree methods. Each tree method inherits from this class and
configures different defaults to preserve its distinct personality.
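
A minimal configuration sketch (``StableRegressionTree`` is an illustrative
name, not a class shipped by this package)::

    class StableRegressionTree(BaseStableTree):
        def __init__(self, max_depth=4, leaf_smoothing=1.0, random_state=None):
            super().__init__(
                task="regression",
                max_depth=max_depth,
                leaf_smoothing=leaf_smoothing,
                enable_variance_aware_stopping=True,
                random_state=random_state,
            )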
"""
import time
from typing import Optional, Tuple, Literal
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.metrics import r2_score, accuracy_score
from sklearn.utils.validation import check_X_y, check_array
from .stability_utils import (
honest_data_partition,
winsorize_features,
stabilize_leaf_estimate,
)
from .split_strategies import HybridStrategy, create_split_strategy
class BaseStableTree(BaseEstimator):
"""
Unified base class implementing all 7 stability primitives.
The 7 stability primitives are:
1. Prefix stability (robust consensus on early splits)
2. Validation-checked split selection
3. Honesty (separate data for structure vs estimation)
4. Leaf stabilization (shrinkage/smoothing)
5. Data regularization (winsorization, etc.)
6. Candidate diversity with deterministic resolution
7. Variance-aware stopping
All tree methods inherit from this and configure different defaults
to maintain their distinct personalities while sharing the unified
stability infrastructure.
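
    Examples
    --------
    A minimal regression sketch on synthetic data (the shapes below are
    exact; the fitted structure depends on the configured split strategy)::

        import numpy as np

        rng = np.random.default_rng(0)
        X = rng.normal(size=(400, 3))
        y = X[:, 0] + 0.1 * rng.normal(size=400)

        tree = BaseStableTree(task="regression", max_depth=3, random_state=0)
        tree.fit(X, y)
        preds = tree.predict(X[:5])      # array of shape (5,)
        n_leaves = tree.count_leaves()   # >= 1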
"""
def __init__(
self,
# === TASK AND CORE PARAMETERS ===
task: Literal["regression", "classification"] = "regression",
max_depth: int = 5,
min_samples_split: int = 40,
min_samples_leaf: int = 20,
# === 3. HONESTY - Data Partitioning ===
enable_honest_estimation: bool = True,
split_frac: float = 0.6,
val_frac: float = 0.2,
est_frac: float = 0.2,
enable_stratified_sampling: bool = True,
# === 2. VALIDATION-CHECKED SPLIT SELECTION ===
enable_validation_checking: bool = True,
validation_metric: Literal["median", "one_se", "variance_penalized"] = "variance_penalized",
validation_consistency_weight: float = 1.0,
# === 1. PREFIX STABILITY ===
enable_prefix_consensus: bool = False,
prefix_levels: int = 2,
consensus_samples: int = 12,
consensus_threshold: float = 0.5,
enable_quantile_grid_thresholds: bool = False,
max_threshold_bins: int = 24,
# === 4. LEAF STABILIZATION ===
leaf_smoothing: float = 0.0,
leaf_smoothing_strategy: Literal[
"m_estimate", "shrink_to_parent", "beta_smoothing"
] = "m_estimate",
enable_calibrated_smoothing: bool = False,
min_leaf_samples_for_stability: int = 5,
# === 5. DATA REGULARIZATION ===
enable_winsorization: bool = False,
winsor_quantiles: Tuple[float, float] = (0.01, 0.99),
enable_feature_standardization: bool = False,
# === 6. CANDIDATE DIVERSITY ===
enable_oblique_splits: bool = False,
oblique_strategy: Literal["root_only", "all_levels", "adaptive"] = "root_only",
oblique_regularization: Literal["lasso", "ridge", "elastic_net"] = "lasso",
enable_correlation_gating: bool = True,
min_correlation_threshold: float = 0.3,
enable_lookahead: bool = False,
lookahead_depth: int = 1,
beam_width: int = 8,
enable_ambiguity_gating: bool = True,
ambiguity_threshold: float = 0.05,
min_samples_for_lookahead: int = 100,
enable_deterministic_preprocessing: bool = False,
enable_deterministic_tiebreaks: bool = True,
enable_margin_vetoes: bool = False,
margin_threshold: float = 0.03,
# === 7. VARIANCE-AWARE STOPPING ===
enable_variance_aware_stopping: bool = False,
variance_stopping_weight: float = 1.0,
variance_stopping_strategy: Literal[
"one_se", "variance_penalty", "both"
] = "variance_penalty",
enable_bootstrap_variance_tracking: bool = False,
variance_tracking_samples: int = 10,
enable_explicit_variance_penalty: bool = False,
variance_penalty_weight: float = 0.1,
# === ADVANCED CONFIGURATION ===
split_strategy: Optional[str] = None,
algorithm_focus: Literal["speed", "accuracy", "stability"] = "stability",
# === CLASSIFICATION ===
classification_criterion: Literal["gini", "entropy"] = "gini",
# === OTHER ===
random_state: Optional[int] = None,
# === ADDITIONAL PARAMETERS FOR CROSS-METHOD LEARNING ===
enable_threshold_binning: bool = False,
enable_gain_margin_logic: bool = False,
enable_beam_search_for_consensus: bool = False,
enable_robust_consensus_for_ambiguous: bool = False,
):
# Validate fractions sum to 1
if abs(split_frac + val_frac + est_frac - 1.0) > 1e-6:
raise ValueError("split_frac + val_frac + est_frac must sum to 1.0")
# === CORE PARAMETERS ===
self.task = task
self.max_depth = max_depth
self.min_samples_split = min_samples_split
self.min_samples_leaf = min_samples_leaf
# === 3. HONESTY ===
self.enable_honest_estimation = enable_honest_estimation
self.split_frac = split_frac
self.val_frac = val_frac
self.est_frac = est_frac
self.enable_stratified_sampling = enable_stratified_sampling
# === 2. VALIDATION ===
self.enable_validation_checking = enable_validation_checking
self.validation_metric = validation_metric
self.validation_consistency_weight = validation_consistency_weight
# === 1. PREFIX STABILITY ===
self.enable_prefix_consensus = enable_prefix_consensus
self.prefix_levels = prefix_levels
self.consensus_samples = consensus_samples
self.consensus_threshold = consensus_threshold
self.enable_quantile_grid_thresholds = enable_quantile_grid_thresholds
self.max_threshold_bins = max_threshold_bins
# === 4. LEAF STABILIZATION ===
self.leaf_smoothing = leaf_smoothing
self.leaf_smoothing_strategy = leaf_smoothing_strategy
self.enable_calibrated_smoothing = enable_calibrated_smoothing
self.min_leaf_samples_for_stability = min_leaf_samples_for_stability
# === 5. DATA REGULARIZATION ===
self.enable_winsorization = enable_winsorization
self.winsor_quantiles = winsor_quantiles
self.enable_feature_standardization = enable_feature_standardization
# === 6. CANDIDATE DIVERSITY ===
self.enable_oblique_splits = enable_oblique_splits
self.oblique_strategy = oblique_strategy
self.oblique_regularization = oblique_regularization
self.enable_correlation_gating = enable_correlation_gating
self.min_correlation_threshold = min_correlation_threshold
self.enable_lookahead = enable_lookahead
self.lookahead_depth = lookahead_depth
self.beam_width = beam_width
self.enable_ambiguity_gating = enable_ambiguity_gating
self.ambiguity_threshold = ambiguity_threshold
self.min_samples_for_lookahead = min_samples_for_lookahead
self.enable_deterministic_preprocessing = enable_deterministic_preprocessing
self.enable_deterministic_tiebreaks = enable_deterministic_tiebreaks
self.enable_margin_vetoes = enable_margin_vetoes
self.margin_threshold = margin_threshold
# === 7. VARIANCE-AWARE STOPPING ===
self.enable_variance_aware_stopping = enable_variance_aware_stopping
self.variance_stopping_weight = variance_stopping_weight
self.variance_stopping_strategy = variance_stopping_strategy
self.enable_bootstrap_variance_tracking = enable_bootstrap_variance_tracking
self.variance_tracking_samples = variance_tracking_samples
self.enable_explicit_variance_penalty = enable_explicit_variance_penalty
self.variance_penalty_weight = variance_penalty_weight
# === ADVANCED ===
self.split_strategy = split_strategy
self.algorithm_focus = algorithm_focus
# === CLASSIFICATION ===
self.classification_criterion = classification_criterion
# === OTHER ===
self.random_state = random_state
# === CROSS-METHOD LEARNING ===
self.enable_threshold_binning = enable_threshold_binning
self.enable_gain_margin_logic = enable_gain_margin_logic
self.enable_beam_search_for_consensus = enable_beam_search_for_consensus
self.enable_robust_consensus_for_ambiguous = enable_robust_consensus_for_ambiguous
# Initialize fitted attributes
self.tree_ = None
self.classes_ = None
self.n_classes_ = None
self.fit_time_sec_ = None
self._split_strategy_ = None
self._winsor_bounds_ = None
self._global_prior_ = None
def fit(self, X, y):
"""Fit the stable tree to the training data."""
start_time = time.time()
# Validate inputs
X, y = check_X_y(X, y, accept_sparse=False)
# === 1. TASK SETUP ===
if self.task == "classification":
self.classes_ = np.unique(y)
self.n_classes_ = len(self.classes_)
if self.n_classes_ > 2:
raise ValueError(
"Multi-class classification not yet supported. "
"Use binary classification or regression."
)
# Convert to 0/1 for binary classification
y = (y == self.classes_[1]).astype(int)
self._global_prior_ = np.mean(y)
else:
self.classes_ = None
self.n_classes_ = None
self._global_prior_ = np.mean(y) if len(y) > 0 else 0.0
# === 5. DATA REGULARIZATION ===
X_processed = self._preprocess_features(X)
# === 3. HONESTY - Data Partitioning ===
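        # With the default fractions (0.6 / 0.2 / 0.2), structure search sees
        # 60% of the rows, split validation 20%, and leaf estimation 20%.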
data_splits = self._partition_data(X_processed, y)
(X_split, y_split), (X_val, y_val), (X_est, y_est) = data_splits
# === Configure Split Strategy ===
self._split_strategy_ = self._create_split_strategy()
# === Build Tree Structure ===
self.tree_ = self._build_tree(X_split, y_split, X_val, y_val, X_est, y_est, depth=0)
        # Record fit time
self.fit_time_sec_ = time.time() - start_time
return self
def predict(self, X):
"""Predict targets for samples in X."""
        X = check_array(X, accept_sparse=False)
if self.tree_ is None:
raise ValueError("Tree not fitted yet")
# Apply same preprocessing as training
X_processed = self._preprocess_features(X, fitted=True)
predictions = np.array([self._predict_sample(x, self.tree_) for x in X_processed])
if self.task == "classification":
# Convert back to original class labels
return np.where(predictions > 0.5, self.classes_[1], self.classes_[0])
else:
return predictions
def predict_proba(self, X):
"""Predict class probabilities for classification tasks."""
if self.task != "classification":
raise ValueError("predict_proba is only available for classification tasks")
        X = check_array(X, accept_sparse=False)
if self.tree_ is None:
raise ValueError("Tree not fitted yet")
# Apply same preprocessing as training
X_processed = self._preprocess_features(X, fitted=True)
# Get probability of positive class
proba_positive = np.array([self._predict_sample(x, self.tree_) for x in X_processed])
# Return as [P(class=0), P(class=1)]
proba_negative = 1 - proba_positive
return np.column_stack([proba_negative, proba_positive])
def score(self, X, y):
"""Return the mean accuracy (classification) or R² (regression)."""
y_pred = self.predict(X)
if self.task == "regression":
return r2_score(y, y_pred)
else:
return accuracy_score(y, y_pred)
def count_leaves(self):
"""Count the number of leaf nodes in the tree."""
if self.tree_ is None:
return 0
return self._count_leaves_recursive(self.tree_)
def _count_leaves_recursive(self, node):
"""Recursively count leaves."""
if node["type"] == "leaf":
return 1
else:
left_count = self._count_leaves_recursive(node["left"]) if "left" in node else 0
right_count = self._count_leaves_recursive(node["right"]) if "right" in node else 0
return left_count + right_count
# ========================================================================
# INTERNAL METHODS - STABILITY PRIMITIVES
# ========================================================================
def _preprocess_features(self, X, fitted=False):
"""Apply data regularization preprocessing."""
X_processed = X.copy()
# === 5. DATA REGULARIZATION ===
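        # Winsorization clips each feature to its fitted [low, high] quantiles,
        # so a handful of extreme rows cannot drag candidate thresholds around.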
if self.enable_winsorization:
if fitted and self._winsor_bounds_ is not None:
X_processed, _ = winsorize_features(X_processed, fitted_bounds=self._winsor_bounds_)
else:
X_processed, self._winsor_bounds_ = winsorize_features(
X_processed, self.winsor_quantiles
)
        # Feature standardization: mainly relevant for oblique splits, since
        # axis-aligned thresholds are invariant to monotone feature scaling.
        if self.enable_feature_standardization:
            # Not implemented yet; kept as an explicit no-op placeholder.
            pass
return X_processed
def _partition_data(self, X, y):
"""Partition data using honest splitting."""
if not self.enable_honest_estimation:
            # Honesty disabled: reuse all data for structure, validation,
            # and estimation alike.
return (X, y), (X, y), (X, y)
return honest_data_partition(
X,
y,
split_frac=self.split_frac,
val_frac=self.val_frac,
est_frac=self.est_frac,
enable_stratification=self.enable_stratified_sampling,
task=self.task,
random_state=self.random_state,
)
def _create_split_strategy(self):
"""Create the split strategy based on enabled features."""
if self.split_strategy is not None:
# Explicit strategy specified
return create_split_strategy(
self.split_strategy,
task=self.task,
random_state=self.random_state,
# Pass relevant parameters
oblique_regularization=self.oblique_regularization,
enable_correlation_gating=self.enable_correlation_gating,
min_correlation=self.min_correlation_threshold,
consensus_samples=self.consensus_samples,
consensus_threshold=self.consensus_threshold,
lookahead_depth=self.lookahead_depth,
beam_width=self.beam_width,
variance_penalty_weight=self.variance_penalty_weight,
)
else:
# Auto-select based on enabled features and algorithm focus
return HybridStrategy(
focus=self.algorithm_focus, task=self.task, random_state=self.random_state
)
def _build_tree(self, X_split, y_split, X_val, y_val, X_est, y_est, depth=0):
"""Recursively build the tree structure."""
n_samples = len(X_split)
# Base stopping conditions
if (
depth >= self.max_depth
or n_samples < self.min_samples_split
or len(np.unique(y_split)) <= 1
):
return self._make_leaf(y_est, y_split, depth)
# Find best split using configured strategy
best_split = self._split_strategy_.find_best_split(
X_split,
y_split,
X_val if self.enable_validation_checking else None,
y_val if self.enable_validation_checking else None,
depth=depth,
max_depth=self.max_depth,
min_samples_split=self.min_samples_split,
min_samples_leaf=self.min_samples_leaf,
)
if best_split is None:
return self._make_leaf(y_est, y_split, depth)
# === 7. VARIANCE-AWARE STOPPING ===
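        # The concrete rule lives in the strategy's should_stop(); for example,
        # a "one_se"-style rule would reject splits whose validation gain does
        # not clear one standard error, while "variance_penalty" discounts the
        # gain by the tracked variance estimate (see split_strategies).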
if self.enable_variance_aware_stopping and best_split.variance_estimate is not None:
should_stop = self._split_strategy_.should_stop(
X_split,
y_split,
best_split.gain,
depth,
variance_estimate=best_split.variance_estimate,
max_depth=self.max_depth,
min_samples_split=self.min_samples_split,
)
if should_stop:
return self._make_leaf(y_est, y_split, depth)
# Apply split to all data partitions
left_indices_split, right_indices_split = self._apply_split_to_data(X_split, best_split)
left_indices_val, right_indices_val = self._apply_split_to_data(X_val, best_split)
left_indices_est, right_indices_est = self._apply_split_to_data(X_est, best_split)
# Check minimum leaf size
if (
len(left_indices_split) < self.min_samples_leaf
or len(right_indices_split) < self.min_samples_leaf
):
return self._make_leaf(y_est, y_split, depth)
# Recursively build children
left_child = self._build_tree(
X_split[left_indices_split],
y_split[left_indices_split],
X_val[left_indices_val],
y_val[left_indices_val],
X_est[left_indices_est],
y_est[left_indices_est],
depth + 1,
)
right_child = self._build_tree(
X_split[right_indices_split],
y_split[right_indices_split],
X_val[right_indices_val],
y_val[right_indices_val],
X_est[right_indices_est],
y_est[right_indices_est],
depth + 1,
)
# Create internal node
return {
"type": "split_oblique" if best_split.is_oblique else "split",
"feature_idx": best_split.feature_idx,
"threshold": best_split.threshold,
"gain": best_split.gain,
"depth": depth,
"n_samples_split": len(X_split),
"n_samples_val": len(X_val),
"n_samples_est": len(X_est),
"oblique_weights": best_split.oblique_weights if best_split.is_oblique else None,
"consensus_support": getattr(best_split, "consensus_support", None),
"variance_estimate": getattr(best_split, "variance_estimate", None),
"left": left_child,
"right": right_child,
}
def _apply_split_to_data(self, X, split_candidate):
"""Apply a split to data and return left/right indices."""
if split_candidate.is_oblique and split_candidate.oblique_weights is not None:
projections = X @ split_candidate.oblique_weights
left_mask = projections <= split_candidate.threshold
else:
left_mask = X[:, split_candidate.feature_idx] <= split_candidate.threshold
left_indices = np.where(left_mask)[0]
right_indices = np.where(~left_mask)[0]
return left_indices, right_indices
def _make_leaf(self, y_est, y_split, depth):
"""Create a leaf node with stabilized estimates."""
# === 4. LEAF STABILIZATION ===
if len(y_est) == 0:
y_est = y_split # Fallback to split data
# Get parent data for shrinkage (use split data as proxy)
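        # m-estimate sketch (assuming the usual form in stability_utils):
        #   value = (n_est * mean(y_est) + m * parent_mean) / (n_est + m),
        # with m == self.leaf_smoothing and parent_mean taken from y_split.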
stabilized_value = stabilize_leaf_estimate(
y_est,
y_split,
strategy=self.leaf_smoothing_strategy,
smoothing=self.leaf_smoothing,
task=self.task,
min_samples=self.min_leaf_samples_for_stability,
)
if self.task == "regression":
return {
"type": "leaf",
"value": stabilized_value,
"depth": depth,
"n_samples_split": len(y_split),
"n_samples_est": len(y_est),
}
else:
# For classification, stabilized_value is probability array or scalar
            if isinstance(stabilized_value, (float, int, np.floating)):
prob = stabilized_value
else:
# stabilized_value is an array of class probabilities
if len(stabilized_value) >= 2:
prob = stabilized_value[1] # P(class=1) for binary classification
else:
# Only one class present, assume class 0
prob = 0.0
return {
"type": "leaf",
"proba": float(prob),
"depth": depth,
"n_samples_split": len(y_split),
"n_samples_est": len(y_est),
}
def _predict_sample(self, x, node):
"""Predict a single sample by traversing the tree."""
if node["type"] == "leaf":
if self.task == "regression":
return node["value"]
else:
return node["proba"]
# Apply split
if node["type"] == "split_oblique" and node["oblique_weights"] is not None:
projection = x @ node["oblique_weights"]
go_left = projection <= node["threshold"]
else:
go_left = x[node["feature_idx"]] <= node["threshold"]
# Recurse
if go_left:
return self._predict_sample(x, node["left"])
else:
return self._predict_sample(x, node["right"])