Source code for hessband.cv

"""
Cross-validation utilities for kernel regression and density estimation.

This module defines a CVScorer class that can be used to evaluate
leave-one-out cross-validation (LOOCV) or K-fold cross-validation
for kernel regression or density estimation.
"""

from __future__ import annotations

from collections.abc import Callable

import numpy as np
from sklearn.metrics import mean_squared_error  # type: ignore
from sklearn.model_selection import KFold  # type: ignore

__all__ = ["CVScorer"]


[docs] class CVScorer: """Cross-validation scorer for kernel regression. Args: X: Input values. y: Target values. folds: Number of folds for K-fold cross-validation. kernel: Kernel type ('gaussian' or 'epanechnikov'). """ def __init__( self, X: np.ndarray, y: np.ndarray, folds: int = 5, kernel: str = "gaussian" ) -> None: self.X = np.asarray(X).ravel() self.y = np.asarray(y).ravel() if not (2 <= folds <= len(self.X)): raise ValueError( f"`folds` must be between 2 and {len(self.X)}, got {folds}" ) self.kf = KFold(n_splits=folds, shuffle=True, random_state=0) self.kernel = kernel self.evals = 0
[docs] def score( self, predict_fn: Callable[ [np.ndarray, np.ndarray, np.ndarray, float, str], np.ndarray ], h: float, ) -> float: """Computes the cross-validation MSE for a given bandwidth. Args: predict_fn: Function that takes ``(X_train, y_train, X_test, h, kernel)`` and returns predictions. h: Bandwidth value. Returns: Cross-validation mean squared error. """ mses = [] for train_idx, test_idx in self.kf.split(self.X): Xtr, Xte = self.X[train_idx], self.X[test_idx] ytr, yte = self.y[train_idx], self.y[test_idx] ypred = predict_fn(Xtr, ytr, Xte, h, self.kernel) mses.append(mean_squared_error(yte, ypred)) self.evals += 1 return float(np.mean(mses))