{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {},
"source": [
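"# Cross-validation\n",
"\n",
"Model selection by 5-fold cross-validation: a direct GPQR and a center-gap GPQR are fit to the same folds, and their train and test negative-ELBO curves are compared."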
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import torch\n",
"from torch.distributions import Normal\n",
"from gpytorch.variational import CholeskyVariationalDistribution\n",
"from gpytorch.variational import VariationalStrategy, LMCVariationalStrategy\n",
"from gpytorch.means import ConstantMean\n",
"from gpytorch.kernels import RBFKernel, ScaleKernel\n",
"from gpytorch.mlls import VariationalELBO\n",
"from sklearn.model_selection import KFold\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import sys\n",
"\n",
"# Optional project config one directory up; imported only for its side effects\n",
"# (presumably deterministic SVG figure settings, per the fallback message).\n",
"sys.path.insert(0, os.path.abspath(\"..\"))\n",
"try:\n",
"    import config_notebook  # noqa: F401\n",
"except ImportError:\n",
"    print(\"Output will not be deterministic SVG.\")\n",
"\n",
"torch.manual_seed(42)\n",
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"\n",
"# Number of optimizer steps; override via the environment for quick runs.\n",
"n_epochs = int(os.environ.get(\"GPYTORCHQR_N_EPOCHS\", 5000))"
]
},
{
"cell_type": "markdown",
"id": "2",
"metadata": {},
"source": [
"## Data preparation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"def mean(x):\n",
"    \"\"\"True conditional mean of the synthetic data: cos(2*pi*x).\"\"\"\n",
"    # Use the exact constant rather than the 3.14 approximation of pi.\n",
"    return torch.cos(2 * torch.pi * x)\n",
"\n",
"\n",
"def std(x):\n",
"    \"\"\"Noise standard deviation of the synthetic data; grows linearly with x.\"\"\"\n",
"    return 0.1 + x\n",
"\n",
"\n",
"N = 100  # number of observations\n",
"x = torch.linspace(0, 1, N).unsqueeze(-1).to(device)  # shape (N, 1)\n",
"noise = torch.randn(x.shape, device=device)\n",
"y = (mean(x) + noise.mul(std(x))).squeeze()  # shape (N,)\n",
"\n",
"# Nine quantile levels 0.1, 0.2, ..., 0.9 and their analytic values under the\n",
"# Gaussian data-generating model (one column per level).\n",
"q = torch.linspace(0.1, 0.9, 9).to(device)\n",
"true_quantiles = mean(x) + std(x) * Normal(0, 1).icdf(q)\n",
"# Grid extending past the training range [0, 1] (not used further in this notebook).\n",
"x_pred = torch.linspace(0, 1.5, 100).reshape(-1, 1).to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"ax.scatter(x.cpu(), y.cpu(), c=\"k\", marker=\".\")\n",
"# One dashed curve per quantile level in q.\n",
"ax.plot(x.cpu(), true_quantiles.cpu(), \"--\", c=\"gray\")\n",
"ax.set(xlabel=\"x\", ylabel=\"y\", title=\"Synthetic data and true quantiles\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"K = 5\n",
"kf = KFold(n_splits=K, shuffle=True, random_state=42)\n",
"\n",
"# torch.stack below needs every fold to have the same size; KFold gives the first\n",
"# len(x) % K folds one extra sample, so fail early with a clear message instead.\n",
"assert len(x) % K == 0, f\"N={len(x)} must be divisible by K={K} to stack the folds\"\n",
"\n",
"x_train_list, y_train_list, x_test_list, y_test_list = [], [], [], []\n",
"for train_idx, test_idx in kf.split(x.cpu()):\n",
"    x_train_list.append(x[train_idx])\n",
"    y_train_list.append(y[train_idx])\n",
"    x_test_list.append(x[test_idx])\n",
"    y_test_list.append(y[test_idx])\n",
"\n",
"# Leading dimension is the fold: x_*_cv is (K, n, 1) and y_*_cv is (K, n).\n",
"# x and y already live on `device`, and indexing keeps them there, so no .to() is needed.\n",
"x_train_cv = torch.stack(x_train_list)\n",
"y_train_cv = torch.stack(y_train_list)\n",
"x_test_cv = torch.stack(x_test_list)\n",
"y_test_cv = torch.stack(y_test_list)"
]
},
{
"cell_type": "markdown",
"id": "6",
"metadata": {},
"source": [
"## Direct GPQR (correlated quantiles)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {},
"outputs": [],
"source": [
"from gpytorch_qr.models import DirectQuantileGP\n",
"from gpytorch_qr.likelihoods import MultitaskQuantileGPLikelihood\n",
"\n",
"\n",
"class CVMultitaskQuantileGP(DirectQuantileGP):\n",
"    \"\"\"Direct quantile GP batched over cross-validation folds.\n",
"\n",
"    For each entry of ``batch_shape`` a batch of ``num_latents`` latent GPs is\n",
"    mixed by an LMC variational strategy into ``num_quantiles`` outputs.\n",
"\n",
"    Args:\n",
"        inducing_points: ``(M, D)`` tensor of initial inducing locations, learned\n",
"            during training. NOTE(review): a single set is shared by the whole\n",
"            batch, i.e. by every fold.\n",
"        num_quantiles: number of quantile levels (LMC tasks).\n",
"        batch_shape: leading batch shape, e.g. ``(K,)`` for K folds.\n",
"        num_latents: number of latent GPs mixed into the quantile outputs.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, inducing_points, num_quantiles, batch_shape, num_latents):\n",
"        num_inducing, input_dim = inducing_points.size()\n",
"        # The latent GPs form the last batch dimension.\n",
"        latent_batch = torch.Size((*batch_shape, num_latents))\n",
"\n",
"        base_strategy = VariationalStrategy(\n",
"            self,\n",
"            inducing_points,\n",
"            CholeskyVariationalDistribution(num_inducing, batch_shape=latent_batch),\n",
"            learn_inducing_locations=True,\n",
"        )\n",
"        variational_strategy = LMCVariationalStrategy(\n",
"            base_strategy,\n",
"            num_tasks=num_quantiles,\n",
"            num_latents=num_latents,\n",
"        )\n",
"\n",
"        mean_module = ConstantMean(batch_shape=latent_batch)\n",
"        kernel = RBFKernel(ard_num_dims=input_dim, batch_shape=latent_batch)\n",
"        covar_module = ScaleKernel(kernel, batch_shape=latent_batch)\n",
"        super().__init__(variational_strategy, mean_module, covar_module)\n",
"\n",
"\n",
"num_quantiles = len(q)\n",
"num_latents = num_quantiles - 2  # latent GPs mixed into the quantile outputs\n",
"# 10 evenly spaced initial inducing locations over the training range [0, 1].\n",
"# NOTE(review): these learnable locations are one set shared by all K folds, and the\n",
"# fold losses are summed during training, so a fold's held-out points can influence\n",
"# its model through the other folds' training data. Per-fold inducing points would\n",
"# remove this coupling.\n",
"inducing_points = torch.linspace(0, 1, 10).unsqueeze(-1).to(device)\n",
"mtgpqr = CVMultitaskQuantileGP(\n",
"    inducing_points, num_quantiles, batch_shape=(K,), num_latents=num_latents\n",
").to(device)\n",
"\n",
"# The likelihood gets one row of quantile levels per fold: shape (K, len(q)).\n",
"q_expanded = q[None, :].expand(K, -1)\n",
"likelihood_mtgpqr = MultitaskQuantileGPLikelihood(q_expanded).to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"metadata": {},
"outputs": [],
"source": [
"# One Adam step per epoch on the summed per-fold negative ELBO; after every step\n",
"# the held-out folds are scored with the same objective.\n",
"# NOTE(review): the test loss reuses the training objective, whose KL penalty is\n",
"# scaled by num_data (the training-fold size); read it as a relative score.\n",
"mll = VariationalELBO(likelihood_mtgpqr, mtgpqr, num_data=y_train_cv.shape[-1])\n",
"optimizer = torch.optim.Adam(\n",
"    [*mtgpqr.parameters(), *likelihood_mtgpqr.parameters()], lr=0.001\n",
")\n",
"modules = (mtgpqr, likelihood_mtgpqr)\n",
"\n",
"mtgpqr_train_losses, mtgpqr_test_losses = [], []\n",
"for epoch in range(n_epochs):\n",
"    for module in modules:\n",
"        module.train()\n",
"    train_loss = -mll(mtgpqr(x_train_cv), y_train_cv)\n",
"    train_loss.sum().backward()\n",
"    optimizer.step()\n",
"    optimizer.zero_grad()\n",
"\n",
"    for module in modules:\n",
"        module.eval()\n",
"    with torch.no_grad():\n",
"        test_loss = -mll(mtgpqr(x_test_cv), y_test_cv)\n",
"\n",
"    # Average over folds and normalize per quantile level.\n",
"    mtgpqr_train_losses.append(train_loss.mean().item() / len(q))\n",
"    mtgpqr_test_losses.append(test_loss.mean().item() / len(q))"
]
},
{
"cell_type": "markdown",
"id": "9",
"metadata": {},
"source": [
"## Center-gap GPQR (correlated gaps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10",
"metadata": {},
"outputs": [],
"source": [
"from gpytorch_qr.means import CenterGapMean\n",
"from gpytorch_qr.models import CenterGapQuantileGP\n",
"from gpytorch_qr.variational import CGLmcVariationalStrategy\n",
"from gpytorch_qr.likelihoods import MultitaskCenterGapQuantileGPLikelihood\n",
"\n",
"\n",
"class CVMultitaskCenterGapQuantileGP(CenterGapQuantileGP):\n",
"    \"\"\"Center-gap quantile GP batched over cross-validation folds.\n",
"\n",
"    Args:\n",
"        inducing_points: ``(M, D)`` tensor of initial inducing locations, learned\n",
"            during training. NOTE(review): one set is shared by every fold.\n",
"        num_quantiles: number of quantile levels.\n",
"        num_lower_quantiles: forwarded to ``CenterGapQuantileGP``; the caller in\n",
"            this notebook passes the index of the level closest to 0.5.\n",
"        num_latents: number of latent GPs per fold. The first gets its own constant\n",
"            mean and the remaining ``num_latents - 1`` share a second one\n",
"            (presumably center and gaps -- TODO confirm against CenterGapMean).\n",
"        num_folds: number of CV folds (leading batch dimension).\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        inducing_points,\n",
"        num_quantiles,\n",
"        num_lower_quantiles,\n",
"        num_latents,\n",
"        num_folds,\n",
"    ):\n",
"        num_inducing, input_dim = inducing_points.size()\n",
"        fold_latent_batch = torch.Size((num_folds, num_latents))\n",
"\n",
"        base_strategy = VariationalStrategy(\n",
"            self,\n",
"            inducing_points,\n",
"            CholeskyVariationalDistribution(num_inducing, batch_shape=fold_latent_batch),\n",
"            learn_inducing_locations=True,\n",
"        )\n",
"        variational_strategy = CGLmcVariationalStrategy(\n",
"            base_strategy,\n",
"            num_quantiles=num_quantiles,\n",
"            num_latents=num_latents,\n",
"        )\n",
"\n",
"        # Constant means with batch (num_folds, 1) for the first latent and\n",
"        # (num_folds, num_latents - 1) for the rest.\n",
"        center_mean = ConstantMean(batch_shape=torch.Size((num_folds, 1)))\n",
"        gap_mean = ConstantMean(batch_shape=torch.Size((num_folds, num_latents - 1)))\n",
"        mean = CenterGapMean(center_mean, gap_mean)\n",
"\n",
"        kernel = RBFKernel(ard_num_dims=input_dim, batch_shape=fold_latent_batch)\n",
"        covar = ScaleKernel(kernel, batch_shape=fold_latent_batch)\n",
"        super().__init__(variational_strategy, mean, covar, num_lower_quantiles)\n",
"\n",
"\n",
"inducing_points = torch.linspace(0, 1, 10).reshape(-1, 1).to(device)\n",
"# Index of the level closest to the median, passed as num_lower_quantiles\n",
"# (with q = 0.1, ..., 0.9 this is 4, the number of levels below 0.5).\n",
"central_q_index = (q - 0.5).abs().argmin().item()\n",
"num_latents = len(q) - 2\n",
"mtgpqr_cg = CVMultitaskCenterGapQuantileGP(\n",
"    inducing_points, len(q), central_q_index, num_latents, K\n",
").to(device)\n",
"# Create the (K, len(q)) zeros tensor directly on `device`: Module.to() only moves\n",
"# parameters and buffers, so a plain tensor attribute would otherwise stay on CPU.\n",
"# NOTE(review): the direct model passes q expanded to (K, len(q)); here q is\n",
"# (1, len(q)) and presumably broadcast across folds -- confirm.\n",
"likelihood_mtgpqr_cg = MultitaskCenterGapQuantileGPLikelihood(\n",
"    q.unsqueeze(0), central_q_index, torch.zeros((K, len(q)), device=device)\n",
").to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11",
"metadata": {},
"outputs": [],
"source": [
"# Same schedule as the direct GPQR (one Adam step per epoch on the summed per-fold\n",
"# negative ELBO, then scoring of the held-out folds) so the two curves are comparable.\n",
"# NOTE(review): the test loss reuses the training objective, whose KL penalty is\n",
"# scaled by num_data (the training-fold size); read it as a relative score.\n",
"mll = VariationalELBO(likelihood_mtgpqr_cg, mtgpqr_cg, num_data=y_train_cv.shape[-1])\n",
"optimizer = torch.optim.Adam(\n",
"    list(mtgpqr_cg.parameters()) + list(likelihood_mtgpqr_cg.parameters()),\n",
"    lr=0.001,\n",
")\n",
"\n",
"mtgpqr_cg_train_losses, mtgpqr_cg_test_losses = [], []\n",
"for _ in range(n_epochs):\n",
"    mtgpqr_cg.train()\n",
"    likelihood_mtgpqr_cg.train()\n",
"\n",
"    output = mtgpqr_cg(x_train_cv)\n",
"    train_loss = -mll(output, y_train_cv)\n",
"    # Per-fold losses are summed so a single backward pass updates every fold.\n",
"    train_loss.sum().backward()\n",
"    optimizer.step()\n",
"    optimizer.zero_grad()\n",
"\n",
"    mtgpqr_cg.eval()\n",
"    likelihood_mtgpqr_cg.eval()\n",
"    with torch.no_grad():\n",
"        output = mtgpqr_cg(x_test_cv)\n",
"        test_loss = -mll(output, y_test_cv)\n",
"\n",
"    # Average over folds and normalize per quantile level.\n",
"    mtgpqr_cg_train_losses.append(train_loss.mean().item() / len(q))\n",
"    mtgpqr_cg_test_losses.append(test_loss.mean().item() / len(q))"
]
},
{
"cell_type": "markdown",
"id": "12",
"metadata": {},
"source": [
"## Plot loss by epoch"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "13",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from matplotlib.lines import Line2D\n",
"\n",
"colors = plt.cm.tab10.colors\n",
"# Keep names next to their curves so the legend cannot drift out of sync.\n",
"model_names = [\"Direct GPQR\", \"CG GPQR\"]\n",
"losses = [\n",
"    [mtgpqr_train_losses, mtgpqr_test_losses],\n",
"    [mtgpqr_cg_train_losses, mtgpqr_cg_test_losses],\n",
"]\n",
"\n",
"fig, ax = plt.subplots()\n",
"# Colour encodes the model; line style encodes the split (train vs. test).\n",
"for (train_losses, test_losses), color in zip(losses, colors):\n",
"    ax.plot(train_losses, color=color, linestyle=\"-.\", alpha=0.5)\n",
"    ax.plot(test_losses, color=color, linestyle=\"--\")\n",
"\n",
"model_handles = [\n",
"    Line2D([0], [0], color=color, linestyle=\"-\")\n",
"    for _, color in zip(model_names, colors)\n",
"]\n",
"style_handles = [\n",
"    Line2D([0], [0], color=\"k\", linestyle=\"-.\", alpha=0.5),\n",
"    Line2D([0], [0], color=\"k\", linestyle=\"--\"),\n",
"]\n",
"ax.legend(\n",
"    handles=model_handles + style_handles,\n",
"    labels=model_names + [\"Train\", \"Test\"],\n",
"    ncols=4,\n",
")\n",
"\n",
"# Losses were averaged over folds and divided by len(q) during training.\n",
"ax.set(\n",
"    xlabel=\"Epoch\",\n",
"    ylabel=\"Negative ELBO per quantile\",\n",
"    title=f\"{K}-fold CV loss (mean over folds)\",\n",
")\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "heavyedge",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}