{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Nadaraya-Watson Regression with `hessband`\n",
    "\n",
    "This notebook demonstrates how to use `hessband` to perform Nadaraya-Watson kernel regression. We select the bandwidth with two methods, the analytic method and grid search, and compare the resulting fits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from hessband import nw_predict, select_nw_bandwidth"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Generate Synthetic Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Noisy samples of sin(2*pi*x) at 200 evenly spaced points in [0, 1]\n",
    "np.random.seed(0)\n",
    "X = np.linspace(0, 1, 200)\n",
    "true_y = np.sin(2 * np.pi * X)\n",
    "y = true_y + 0.2 * np.random.randn(200)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Bandwidth Selection\n",
    "\n",
    "We now select the bandwidth with two methods, `method=\"analytic\"` and `method=\"grid\"` (grid search), and time each call so the two can be compared."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analytic method\n",
    "start_time = time.perf_counter()\n",
    "h_analytic = select_nw_bandwidth(X, y, method=\"analytic\")\n",
    "analytic_time = time.perf_counter() - start_time\n",
    "\n",
    "# Grid search method\n",
    "start_time = time.perf_counter()\n",
    "h_grid = select_nw_bandwidth(X, y, method=\"grid\")\n",
    "grid_time = time.perf_counter() - start_time\n",
    "\n",
    "print(f\"Analytic method: h = {h_analytic:.4f}, time = {analytic_time:.4f}s\")\n",
    "print(f\"Grid search method: h = {h_grid:.4f}, time = {grid_time:.4f}s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The analytic method is typically much faster than grid search while selecting a comparable bandwidth; compare the bandwidths and timings printed above."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Perform Regression and Plot Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nadaraya-Watson fits with each selected bandwidth\n",
    "y_pred_analytic = nw_predict(X, y, X, h_analytic)\n",
    "y_pred_grid = nw_predict(X, y, X, h_grid)\n",
    "\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.scatter(X, y, label=\"Data\", alpha=0.5, s=10)\n",
    "plt.plot(X, true_y, label=\"True function\", color=\"black\", linestyle=\"--\")\n",
    "plt.plot(X, y_pred_analytic, label=f\"Analytic (h={h_analytic:.4f})\", color=\"red\")\n",
    "plt.plot(\n",
    "    X, y_pred_grid, label=f\"Grid search (h={h_grid:.4f})\", color=\"green\", linestyle=\":\"\n",
    ")\n",
    "plt.legend()\n",
    "plt.title(\"Nadaraya-Watson Regression\")\n",
    "plt.xlabel(\"X\")\n",
    "plt.ylabel(\"y\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}