{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {},
"source": [
"# Cross validation\n",
"\n",
"Model selection by 5-fold cross validation."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import torch\n",
"from torch.distributions import Normal\n",
"from gpytorch.variational import CholeskyVariationalDistribution\n",
"from gpytorch.variational import VariationalStrategy, LMCVariationalStrategy\n",
"from gpytorch.means import ConstantMean\n",
"from gpytorch.kernels import RBFKernel, ScaleKernel\n",
"from gpytorch.mlls import VariationalELBO\n",
"from sklearn.model_selection import KFold\n",
"import matplotlib.pyplot as plt\n",
"\n",
"try:\n",
" import sys\n",
"\n",
" sys.path.insert(0, os.path.abspath(\"..\"))\n",
"\n",
" import config_notebook\n",
"except ImportError:\n",
" print(\"Output will not be deterministic SVG.\")\n",
"\n",
"torch.manual_seed(42)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"n_epochs = int(os.getenv(\"GPYTORCHQR_N_EPOCHS\", 5000))"
]
},
{
"cell_type": "markdown",
"id": "2",
"metadata": {},
"source": [
"## Data preparation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"def mean(x):\n",
" return torch.cos(x * 2 * 3.14)\n",
"\n",
"\n",
"def std(x):\n",
" return x + 0.1\n",
"\n",
"\n",
"N = 100\n",
"x = torch.linspace(0, 1, N).reshape(-1, 1).to(device)\n",
"y = (mean(x) + torch.randn(x.shape, device=device).mul(std(x))).squeeze()\n",
"q = torch.linspace(0.1, 0.9, 9).to(device)\n",
"true_quantiles = mean(x) + std(x) * Normal(0, 1).icdf(q)\n",
"x_pred = torch.linspace(0, 1.5, 100).reshape(-1, 1).to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.scatter(x.cpu(), y.cpu(), c=\"k\", marker=\".\")\n",
"plt.plot(x.cpu(), true_quantiles.cpu(), \"--\", c=\"gray\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"K = 5\n",
"kf = KFold(n_splits=K, shuffle=True, random_state=42)\n",
"\n",
"x_train_list, y_train_list, x_test_list, y_test_list = [], [], [], []\n",
"for train_idx, test_idx in kf.split(x.cpu()):\n",
" x_train_list.append(x[train_idx])\n",
" y_train_list.append(y[train_idx])\n",
" x_test_list.append(x[test_idx])\n",
" y_test_list.append(y[test_idx])\n",
"\n",
"x_train_cv = torch.stack(x_train_list).to(device)\n",
"y_train_cv = torch.stack(y_train_list).to(device)\n",
"x_test_cv = torch.stack(x_test_list).to(device)\n",
"y_test_cv = torch.stack(y_test_list).to(device)"
]
},
{
"cell_type": "markdown",
"id": "6",
"metadata": {},
"source": [
"## Batch direct GPQR"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {},
"outputs": [],
"source": [
"from gpytorch_qr.models import DirectQuantileGP\n",
"from gpytorch_qr.likelihoods import BatchQuantileGPLikelihood\n",
"\n",
"\n",
"class CVBatchQuantileGP(DirectQuantileGP):\n",
" def __init__(self, inducing_points, num_quantiles, batch_shape):\n",
" N, D = inducing_points.size()\n",
" batch_shape = torch.Size([num_quantiles, *batch_shape])\n",
" variational_distribution = CholeskyVariationalDistribution(\n",
" N,\n",
" batch_shape=batch_shape,\n",
" )\n",
" variational_strategy = VariationalStrategy(\n",
" self,\n",
" inducing_points,\n",
" variational_distribution,\n",
" learn_inducing_locations=True,\n",
" )\n",
" mean = ConstantMean(batch_shape=batch_shape)\n",
" covar = ScaleKernel(\n",
" RBFKernel(ard_num_dims=D, batch_shape=batch_shape),\n",
" batch_shape=batch_shape,\n",
" )\n",
" super().__init__(variational_strategy, mean, covar, 0)\n",
"\n",
"\n",
"inducing_points = torch.linspace(0, 1, 10).reshape(-1, 1).to(device)\n",
"gpqr = CVBatchQuantileGP(inducing_points, len(q), (K,)).to(device)\n",
"q_expanded = q.unsqueeze(1).expand(-1, K)\n",
"likelihood_gpqr = BatchQuantileGPLikelihood(q_expanded).to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"metadata": {},
"outputs": [],
"source": [
"mll = VariationalELBO(likelihood_gpqr, gpqr, num_data=y_train_cv.shape[-1])\n",
"optimizer = torch.optim.Adam(\n",
" list(gpqr.parameters()) + list(likelihood_gpqr.parameters()),\n",
" lr=0.001,\n",
")\n",
"\n",
"gpqr_train_losses, gpqr_test_losses = [], []\n",
"for _ in range(n_epochs):\n",
" gpqr.train()\n",
" likelihood_gpqr.train()\n",
"\n",
" output = gpqr(x_train_cv)\n",
" train_loss = -mll(output, y_train_cv)\n",
" train_loss.sum().backward()\n",
" optimizer.step()\n",
" optimizer.zero_grad()\n",
"\n",
" gpqr.eval()\n",
" likelihood_gpqr.eval()\n",
" with torch.no_grad():\n",
" output = gpqr(x_test_cv)\n",
" test_loss = -mll(output, y_test_cv)\n",
"\n",
" gpqr_train_losses.append(train_loss.mean().item())\n",
" gpqr_test_losses.append(test_loss.mean().item())"
]
},
{
"cell_type": "markdown",
"id": "9",
"metadata": {},
"source": [
"## Batch center-gap GPQR"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10",
"metadata": {},
"outputs": [],
"source": [
"from gpytorch_qr.means import CenterGapMean\n",
"from gpytorch_qr.models import CenterGapQuantileGP\n",
"from gpytorch_qr.likelihoods import BatchCenterGapQuantileGPLikelihood\n",
"\n",
"\n",
"class CVBatchCenterGapQuantileGP(CenterGapQuantileGP):\n",
" def __init__(self, inducing_points, num_quantiles, num_lower_quantiles, num_folds):\n",
" N, D = inducing_points.size()\n",
" batch_shape = torch.Size([num_quantiles, num_folds])\n",
" variational_distribution = CholeskyVariationalDistribution(\n",
" N,\n",
" batch_shape=batch_shape,\n",
" )\n",
" variational_strategy = VariationalStrategy(\n",
" self,\n",
" inducing_points,\n",
" variational_distribution,\n",
" learn_inducing_locations=True,\n",
" )\n",
"\n",
" mean = CenterGapMean(\n",
" ConstantMean(batch_shape=torch.Size([1, num_folds])),\n",
" ConstantMean(batch_shape=torch.Size([num_quantiles - 1, num_folds])),\n",
" latent_dim=0,\n",
" )\n",
" covar = ScaleKernel(\n",
" RBFKernel(ard_num_dims=D, batch_shape=batch_shape),\n",
" batch_shape=batch_shape,\n",
" )\n",
" super().__init__(variational_strategy, mean, covar, 0, num_lower_quantiles)\n",
"\n",
"\n",
"inducing_points = torch.linspace(0, 1, 10).reshape(-1, 1).to(device)\n",
"central_q_index = (q - 0.5).abs().argmin().item()\n",
"num_lower_quantiles = central_q_index\n",
"gpqr_cg = CVBatchCenterGapQuantileGP(\n",
" inducing_points, len(q), num_lower_quantiles, K\n",
").to(device)\n",
"likelihood_gpqr_cg = BatchCenterGapQuantileGPLikelihood(\n",
" q.unsqueeze(1), central_q_index, torch.zeros((len(q), K))\n",
").to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11",
"metadata": {},
"outputs": [],
"source": [
"mll = VariationalELBO(likelihood_gpqr_cg, gpqr_cg, num_data=y_train_cv.shape[1])\n",
"optimizer = torch.optim.Adam(\n",
" list(gpqr_cg.parameters()) + list(likelihood_gpqr_cg.parameters()),\n",
" lr=0.001,\n",
")\n",
"\n",
"gpqr_cg_train_losses, gpqr_cg_test_losses = [], []\n",
"for _ in range(n_epochs):\n",
" gpqr_cg.train()\n",
" likelihood_gpqr_cg.train()\n",
"\n",
" output = gpqr_cg(x_train_cv)\n",
" train_loss = -mll(output, y_train_cv)\n",
" train_loss.sum().backward()\n",
" optimizer.step()\n",
" optimizer.zero_grad()\n",
"\n",
" gpqr_cg.eval()\n",
" likelihood_gpqr_cg.eval()\n",
" with torch.no_grad():\n",
" output = gpqr_cg(x_test_cv)\n",
" test_loss = -mll(output, y_test_cv)\n",
"\n",
" gpqr_cg_train_losses.append(train_loss.mean().item())\n",
" gpqr_cg_test_losses.append(test_loss.mean().item())"
]
},
{
"cell_type": "markdown",
"id": "12",
"metadata": {},
"source": [
"## Multitask direct GPQR"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "13",
"metadata": {},
"outputs": [],
"source": [
"from gpytorch_qr.models import DirectQuantileGP\n",
"from gpytorch_qr.likelihoods import MultitaskQuantileGPLikelihood\n",
"\n",
"\n",
"class CVMultitaskQuantileGP(DirectQuantileGP):\n",
" def __init__(self, inducing_points, num_quantiles, batch_shape, num_latents):\n",
" N, D = inducing_points.size()\n",
" batch_shape = torch.Size([*batch_shape, num_latents])\n",
" variational_distribution = CholeskyVariationalDistribution(\n",
" N,\n",
" batch_shape=batch_shape,\n",
" )\n",
" variational_strategy = LMCVariationalStrategy(\n",
" VariationalStrategy(\n",
" self,\n",
" inducing_points,\n",
" variational_distribution,\n",
" learn_inducing_locations=True,\n",
" ),\n",
" num_tasks=num_quantiles,\n",
" num_latents=num_latents,\n",
" )\n",
"\n",
" mean_module = ConstantMean(batch_shape=batch_shape)\n",
" covar_module = ScaleKernel(\n",
" RBFKernel(ard_num_dims=D, batch_shape=batch_shape),\n",
" batch_shape=batch_shape,\n",
" )\n",
" super().__init__(variational_strategy, mean_module, covar_module, -1)\n",
"\n",
"\n",
"inducing_points = torch.linspace(0, 1, 10).reshape(-1, 1).to(device)\n",
"num_latents = len(q) - 2\n",
"mtgpqr = CVMultitaskQuantileGP(inducing_points, len(q), (K,), num_latents).to(device)\n",
"q_expanded = q.unsqueeze(0).expand(K, -1)\n",
"likelihood_mtgpqr = MultitaskQuantileGPLikelihood(q_expanded).to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14",
"metadata": {},
"outputs": [],
"source": [
"mll = VariationalELBO(likelihood_mtgpqr, mtgpqr, num_data=y_train_cv.shape[-1])\n",
"optimizer = torch.optim.Adam(\n",
" list(mtgpqr.parameters()) + list(likelihood_mtgpqr.parameters()),\n",
" lr=0.001,\n",
")\n",
"\n",
"mtgpqr_train_losses, mtgpqr_test_losses = [], []\n",
"for _ in range(n_epochs):\n",
" mtgpqr.train()\n",
" likelihood_mtgpqr.train()\n",
"\n",
" output = mtgpqr(x_train_cv)\n",
" train_loss = -mll(output, y_train_cv)\n",
" train_loss.sum().backward()\n",
" optimizer.step()\n",
" optimizer.zero_grad()\n",
"\n",
" mtgpqr.eval()\n",
" likelihood_mtgpqr.eval()\n",
" with torch.no_grad():\n",
" output = mtgpqr(x_test_cv)\n",
" test_loss = -mll(output, y_test_cv)\n",
"\n",
" mtgpqr_train_losses.append(train_loss.mean().item() / len(q))\n",
" mtgpqr_test_losses.append(test_loss.mean().item() / len(q))"
]
},
{
"cell_type": "markdown",
"id": "15",
"metadata": {},
"source": [
"## Multitask center-gap GPQR"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16",
"metadata": {},
"outputs": [],
"source": [
"from gpytorch_qr.means import CenterGapMean\n",
"from gpytorch_qr.models import CenterGapQuantileGP\n",
"from gpytorch_qr.variational import CGLmcVariationalStrategy\n",
"from gpytorch_qr.likelihoods import MultitaskCenterGapQuantileGPLikelihood\n",
"\n",
"\n",
"class CVMultitaskCenterGapQuantileGP(CenterGapQuantileGP):\n",
" def __init__(\n",
" self,\n",
" inducing_points,\n",
" num_quantiles,\n",
" num_lower_quantiles,\n",
" num_latents,\n",
" num_folds,\n",
" ):\n",
" N, D = inducing_points.size()\n",
" batch_shape = torch.Size([num_folds, num_latents])\n",
" variational_distribution = CholeskyVariationalDistribution(\n",
" N,\n",
" batch_shape=batch_shape,\n",
" )\n",
" variational_strategy = CGLmcVariationalStrategy(\n",
" VariationalStrategy(\n",
" self,\n",
" inducing_points,\n",
" variational_distribution,\n",
" learn_inducing_locations=True,\n",
" ),\n",
" num_quantiles=num_quantiles,\n",
" num_latents=num_latents,\n",
" )\n",
"\n",
" mean = CenterGapMean(\n",
" ConstantMean(batch_shape=torch.Size([num_folds, 1])),\n",
" ConstantMean(batch_shape=torch.Size([num_folds, num_latents - 1])),\n",
" latent_dim=-1,\n",
" )\n",
" covar = ScaleKernel(\n",
" RBFKernel(ard_num_dims=D, batch_shape=batch_shape),\n",
" batch_shape=batch_shape,\n",
" )\n",
" super().__init__(variational_strategy, mean, covar, -1, num_lower_quantiles)\n",
"\n",
"\n",
"inducing_points = torch.linspace(0, 1, 10).reshape(-1, 1).to(device)\n",
"central_q_index = (q - 0.5).abs().argmin().item()\n",
"num_latents = len(q) - 2\n",
"mtgpqr_cg = CVMultitaskCenterGapQuantileGP(\n",
" inducing_points, len(q), central_q_index, num_latents, K\n",
").to(device)\n",
"likelihood_mtgpqr_cg = MultitaskCenterGapQuantileGPLikelihood(\n",
" q.unsqueeze(0), central_q_index, torch.zeros((K, len(q)))\n",
").to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "17",
"metadata": {},
"outputs": [],
"source": [
"mll = VariationalELBO(likelihood_mtgpqr_cg, mtgpqr_cg, num_data=y_train_cv.shape[1])\n",
"optimizer = torch.optim.Adam(\n",
" list(mtgpqr_cg.parameters()) + list(likelihood_mtgpqr_cg.parameters()),\n",
" lr=0.001,\n",
")\n",
"\n",
"mtgpqr_cg_train_losses, mtgpqr_cg_test_losses = [], []\n",
"for _ in range(n_epochs):\n",
" mtgpqr_cg.train()\n",
" likelihood_mtgpqr_cg.train()\n",
"\n",
" output = mtgpqr_cg(x_train_cv)\n",
" train_loss = -mll(output, y_train_cv)\n",
" train_loss.sum().backward()\n",
" optimizer.step()\n",
" optimizer.zero_grad()\n",
"\n",
" mtgpqr_cg.eval()\n",
" likelihood_mtgpqr_cg.eval()\n",
" with torch.no_grad():\n",
" output = mtgpqr_cg(x_test_cv)\n",
" test_loss = -mll(output, y_test_cv)\n",
"\n",
" mtgpqr_cg_train_losses.append(train_loss.mean().item() / len(q))\n",
" mtgpqr_cg_test_losses.append(test_loss.mean().item() / len(q))"
]
},
{
"cell_type": "markdown",
"id": "18",
"metadata": {},
"source": [
"## Plot loss by epoch"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from matplotlib.lines import Line2D\n",
"\n",
"colors = plt.cm.tab10.colors\n",
"\n",
"losses = [\n",
" [gpqr_train_losses, gpqr_test_losses],\n",
" [gpqr_cg_train_losses, gpqr_cg_test_losses],\n",
" [mtgpqr_train_losses, mtgpqr_test_losses],\n",
" [mtgpqr_cg_train_losses, mtgpqr_cg_test_losses],\n",
"]\n",
"\n",
"for (train_losses, test_losses), color in zip(losses, colors):\n",
" plt.plot(train_losses, color=color, linestyle=\"-.\", alpha=0.5)\n",
" plt.plot(test_losses, color=color, linestyle=\"--\")\n",
"\n",
"model_handles = [Line2D([0], [0], color=colors[i], linestyle=\"-\") for i in range(4)]\n",
"style_handles = [\n",
" Line2D([0], [0], color=\"k\", linestyle=\"-.\", alpha=0.5),\n",
" Line2D([0], [0], color=\"k\", linestyle=\"--\"),\n",
"]\n",
"\n",
"handles = model_handles + style_handles\n",
"labels = [\"GPQR\", \"GPQR_CG\", \"MTGPQR\", \"MTGPQR_CG\", \"Train\", \"Test\"]\n",
"plt.legend(handles=handles, labels=labels, ncols=4)\n",
"\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Negative ELBO\")\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "heavyedge",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}