{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {},
"source": [
"# Length Scale Constraint\n",
"\n",
"When training data are scarce, the fitted quantile curves are prone to overfitting.\n",
"This can be mitigated by placing lower bounds on the kernel length scales, which prevents the fitted quantiles from varying over arbitrarily short distances.\n",
"\n",
"In addition, we use a block-diagonal linear model of coregionalization (LMC) structure to reduce the number of parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"from gpytorch.constraints import Interval\n",
"from gpytorch.kernels import RBFKernel, ScaleKernel\n",
"from gpytorch.means import ConstantMean\n",
"from gpytorch.mlls import VariationalELBO\n",
"from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy\n",
"from torch.distributions import Normal\n",
"\n",
"from gpytorch_qr.likelihoods import MultitaskCenterGapQuantileGPLikelihood\n",
"from gpytorch_qr.means import CenterGapMean\n",
"from gpytorch_qr.models import CenterGapQuantileGP\n",
"from gpytorch_qr.variational import CGBlkdiagLmcVariationalStrategy\n",
"\n",
"# Optional notebook configuration from the parent directory, imported only for its\n",
"# side effects (per the fallback message, it makes the SVG output deterministic).\n",
"# The membership check keeps re-runs of this cell from growing sys.path.\n",
"_parent_dir = os.path.abspath(\"..\")\n",
"if _parent_dir not in sys.path:\n",
"    sys.path.insert(0, _parent_dir)\n",
"try:\n",
"    import config_notebook  # noqa: F401\n",
"except ImportError:\n",
"    print(\"Output will not be deterministic SVG.\")\n",
"\n",
"torch.manual_seed(42)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Optimizer steps per model; set GPYTORCHQR_N_EPOCHS to override (e.g. for quick runs).\n",
"n_epochs = int(os.getenv(\"GPYTORCHQR_N_EPOCHS\", 10000))"
]
},
{
"cell_type": "markdown",
"id": "2",
"metadata": {},
"source": [
"## Data preparation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"def mean(x):\n",
"    \"\"\"True conditional mean cos(2*pi*x1) * cos(2*pi*x2), with pi written as 3.14.\"\"\"\n",
"    x1, x2 = x[..., 0], x[..., 1]\n",
"    return torch.cos(x1 * 2 * 3.14) * torch.cos(x2 * 2 * 3.14)\n",
"\n",
"\n",
"def std(x):\n",
"    \"\"\"True noise standard deviation, growing linearly as x1 + x2 + 0.1.\"\"\"\n",
"    x1, x2 = x[..., 0], x[..., 1]\n",
"    return x1 + x2 + 0.1\n",
"\n",
"\n",
"# Sparse design: x2 only takes three levels, each paired with n_per_x2 uniform x1 draws.\n",
"x2_values = torch.tensor([0.1, 0.5, 0.9]).to(device)\n",
"n_per_x2 = 20\n",
"x1_samples = torch.rand(n_per_x2 * len(x2_values)).to(device)\n",
"x2_samples = x2_values.repeat_interleave(n_per_x2)\n",
"x = torch.stack([x1_samples, x2_samples], dim=1)\n",
"\n",
"# Heteroscedastic Gaussian observations around the true mean.\n",
"noise = torch.randn(x.shape[0], device=device).mul(std(x))\n",
"y = (mean(x) + noise).squeeze()\n",
"\n",
"# Quantile levels to estimate.\n",
"q = torch.tensor([0.1, 0.5, 0.9]).to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"metadata": {},
"outputs": [],
"source": [
"# Prediction grid: 100 evenly spaced x1 values on each observed x2 level.\n",
"x1_pred = torch.linspace(0, 1, 100).to(device)\n",
"x2_grid, x1_grid = torch.meshgrid(x2_values, x1_pred, indexing=\"ij\")\n",
"x_pred = torch.stack([x1_grid, x2_grid], dim=-1)  # (len(x2_values), len(x1_pred), 2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, axes = plt.subplots(1, len(x2_values), sharey=True)\n",
"z_q = Normal(0, 1).icdf(q)  # standard-normal quantiles of the target levels\n",
"\n",
"for i, (ax, x2_val, x_line) in enumerate(zip(axes, x2_values, x_pred)):\n",
"    # Observations lying on this x2 level.\n",
"    on_level = x[:, 1] == x2_val\n",
"    ax.scatter(x[on_level, 0].cpu(), y[on_level].cpu(), c=\"k\", marker=\".\")\n",
"\n",
"    # True conditional quantiles mean + std * z_q, drawn dashed.\n",
"    true_q_lines = mean(x_line).unsqueeze(-1) + std(x_line).unsqueeze(-1) * z_q\n",
"    ax.plot(x_line[..., 0].cpu(), true_q_lines.cpu(), \"--\", color=\"gray\")\n",
"\n",
"    ax.set_title(f\"x2 = {x2_val:.1f}\")\n",
"    ax.set_xlabel(\"x1\")\n",
"    if i == 0:\n",
"        ax.set_ylabel(\"y\")\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "6",
"metadata": {},
"source": [
"## Without constraint"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {},
"outputs": [],
"source": [
"class MyGP_NoConstraint(CenterGapQuantileGP):\n",
"    \"\"\"Center-gap quantile GP with a block-diagonal LMC variational strategy.\n",
"\n",
"    Baseline model: no extra constraint is registered on the RBF length scales.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        inducing_points,\n",
"        num_quantiles,\n",
"        num_lower_quantiles,\n",
"        num_latents,\n",
"        num_lower_latents,\n",
"    ):\n",
"        num_inducing, input_dim = inducing_points.size()\n",
"        latent_batch = torch.Size([num_latents])\n",
"\n",
"        # A batch of num_latents variational GPs over the same inducing inputs; the\n",
"        # block-diagonal LMC strategy maps these latents to the quantile outputs.\n",
"        base_strategy = VariationalStrategy(\n",
"            self,\n",
"            inducing_points,\n",
"            CholeskyVariationalDistribution(num_inducing, batch_shape=latent_batch),\n",
"            learn_inducing_locations=True,\n",
"        )\n",
"        variational_strategy = CGBlkdiagLmcVariationalStrategy(\n",
"            base_strategy,\n",
"            num_quantiles=num_quantiles,\n",
"            num_latents=num_latents,\n",
"            num_lower_quantiles=num_lower_quantiles,\n",
"            num_lower_latents=num_lower_latents,\n",
"        )\n",
"\n",
"        # One constant mean for a single latent plus a batch of num_latents - 1\n",
"        # constant means for the remaining latents.\n",
"        mean_module = CenterGapMean(\n",
"            ConstantMean(batch_shape=torch.Size([1])),\n",
"            ConstantMean(batch_shape=torch.Size([num_latents - 1])),\n",
"            latent_dim=-1,\n",
"        )\n",
"        # ARD RBF kernel (one length scale per input dimension) with an output\n",
"        # scale for every latent.\n",
"        covar_module = ScaleKernel(\n",
"            RBFKernel(ard_num_dims=input_dim, batch_shape=latent_batch),\n",
"            batch_shape=latent_batch,\n",
"        )\n",
"        super().__init__(\n",
"            variational_strategy, mean_module, covar_module, -1, num_lower_quantiles\n",
"        )\n",
"\n",
"\n",
"# Inducing inputs: every (x1, x2) pair from a 10-point x1 grid and five x2 levels.\n",
"inducing_points = torch.cartesian_prod(\n",
"    torch.linspace(0, 1, 10),\n",
"    torch.tensor([0.1, 0.3, 0.5, 0.7, 0.9]),\n",
").to(device)\n",
"# Index of the level closest to 0.5; with sorted q it also counts the levels below it.\n",
"central_q_index = (q - 0.5).abs().argmin().item()\n",
"num_latents = len(q)\n",
"gp_noconstraint = MyGP_NoConstraint(\n",
"    inducing_points,\n",
"    len(q),\n",
"    central_q_index,\n",
"    num_latents,\n",
"    num_latents // 2,\n",
").to(device)\n",
"likelihood_noconstraint = MultitaskCenterGapQuantileGPLikelihood(\n",
"    q, central_q_index\n",
").to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"metadata": {},
"outputs": [],
"source": [
"gp_noconstraint.train()\n",
"likelihood_noconstraint.train()\n",
"mll = VariationalELBO(likelihood_noconstraint, gp_noconstraint, num_data=y.numel())\n",
"optimizer = torch.optim.Adam(\n",
"    [*gp_noconstraint.parameters(), *likelihood_noconstraint.parameters()],\n",
"    lr=0.001,\n",
")\n",
"\n",
"# Full-batch training: each step minimizes the negative ELBO on all observations.\n",
"for _ in range(n_epochs):\n",
"    loss = -mll(gp_noconstraint(x), y)\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    optimizer.zero_grad()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {},
"outputs": [],
"source": [
"# Predicted quantile curves on the grid: the model sees x_pred flattened to\n",
"# (points, D) and the result is reshaped back to the grid layout of x_pred.\n",
"# NOTE(review): the second argument of mean_quantiles_mc (1) is presumably the\n",
"# number of Monte Carlo samples -- confirm against gpytorch_qr.\n",
"gp_noconstraint.eval()\n",
"with torch.no_grad():\n",
"    flat_q_noconstraint = gp_noconstraint.mean_quantiles_mc(x_pred.flatten(0, -2), 1)\n",
"    mean_q_noconstraint = flat_q_noconstraint.unflatten(-2, x_pred.shape[:-1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"colors = plt.cm.tab10.colors\n",
"z_q = Normal(0, 1).icdf(q)  # standard-normal quantiles of the target levels\n",
"\n",
"fig, axes = plt.subplots(1, len(x2_values), sharey=True)\n",
"fig.suptitle(\"Without length-scale constraint\")\n",
"\n",
"for i, x2_val in enumerate(x2_values):\n",
"    ax = axes[i]\n",
"\n",
"    mask = x[:, 1] == x2_val\n",
"    ax.scatter(x[mask, 0].cpu(), y[mask].cpu(), c=\"gray\", marker=\".\", alpha=0.1)\n",
"\n",
"    x_line = x_pred[i]\n",
"    x1_line = x_line[..., 0].cpu()\n",
"    true_q_lines = mean(x_line).unsqueeze(-1) + std(x_line).unsqueeze(-1) * z_q\n",
"\n",
"    # Dashed black: true quantiles (labelled once so the legend has one entry).\n",
"    for j in range(len(q)):\n",
"        ax.plot(\n",
"            x1_line,\n",
"            true_q_lines[..., j].cpu(),\n",
"            \"--\",\n",
"            color=\"k\",\n",
"            label=\"true\" if j == 0 else \"_nolegend_\",\n",
"        )\n",
"    # Solid: fitted quantiles, one colour per level (cycled if there are many).\n",
"    for j in range(len(q)):\n",
"        ax.plot(\n",
"            x1_line,\n",
"            mean_q_noconstraint[i, ..., j].cpu(),\n",
"            color=colors[j % len(colors)],\n",
"            label=f\"q = {q[j].item():g}\",\n",
"        )\n",
"\n",
"    ax.set_title(f\"x2 = {x2_val:.1f}\")\n",
"    ax.set_xlabel(\"x1\")\n",
"    if i == 0:\n",
"        ax.set_ylabel(\"y\")\n",
"\n",
"axes[-1].legend(fontsize=\"small\")\n",
"plt.show()"
]
},
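{
"cell_type": "markdown",
"id": "11",
"metadata": {},
"source": [
"## With constraint\n",
"\n",
"The model is otherwise identical, but each RBF length scale is restricted to an interval: between 0.2 and 100 along $x_1$, and between 0 and 100 along $x_2$. The lower bound along $x_1$ keeps the fitted quantiles from varying over arbitrarily short distances in $x_1$."
]
},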
{
"cell_type": "code",
"execution_count": null,
"id": "12",
"metadata": {},
"outputs": [],
"source": [
"class MyGP_Constraint(CenterGapQuantileGP):\n",
"    \"\"\"Center-gap quantile GP whose RBF length scales are confined to intervals.\n",
"\n",
"    Same architecture as ``MyGP_NoConstraint``; the difference is that the ARD\n",
"    length scales get an ``Interval`` constraint and an explicit initial value.\n",
"\n",
"    The optional arguments take one value per input dimension D (the number of\n",
"    columns of ``inducing_points``):\n",
"\n",
"    - ``lengthscale_lower``: lower bounds, default ``(0.2, 0.0)``.\n",
"    - ``lengthscale_upper``: upper bounds, default ``(1e2, 1e2)``.\n",
"    - ``initial_lengthscale``: starting length scales, default ``(0.5, 0.5)``;\n",
"      must lie strictly between the bounds.\n",
"\n",
"    Raises:\n",
"        ValueError: if any of these does not have exactly D entries, or the\n",
"            initial length scale is not strictly inside the bounds.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        inducing_points,\n",
"        num_quantiles,\n",
"        num_lower_quantiles,\n",
"        num_latents,\n",
"        num_lower_latents,\n",
"        lengthscale_lower=(0.2, 0.0),\n",
"        lengthscale_upper=(1e2, 1e2),\n",
"        initial_lengthscale=(0.5, 0.5),\n",
"    ):\n",
"        N, D = inducing_points.size()\n",
"\n",
"        # Validate the per-dimension settings up front so a mismatch is reported\n",
"        # here rather than surfacing as a broadcasting issue inside the kernel.\n",
"        dtype = torch.get_default_dtype()\n",
"        lower = torch.as_tensor(lengthscale_lower, dtype=dtype)\n",
"        upper = torch.as_tensor(lengthscale_upper, dtype=dtype)\n",
"        init = torch.as_tensor(initial_lengthscale, dtype=dtype)\n",
"        for arg_name, value in (\n",
"            (\"lengthscale_lower\", lower),\n",
"            (\"lengthscale_upper\", upper),\n",
"            (\"initial_lengthscale\", init),\n",
"        ):\n",
"            if value.shape != (D,):\n",
"                raise ValueError(\n",
"                    f\"{arg_name} must have {D} entries (one per input dimension), \"\n",
"                    f\"got shape {tuple(value.shape)}\"\n",
"                )\n",
"        if not bool(((lower < init) & (init < upper)).all()):\n",
"            raise ValueError(\n",
"                \"initial_lengthscale must lie strictly between \"\n",
"                \"lengthscale_lower and lengthscale_upper\"\n",
"            )\n",
"\n",
"        variational_distribution = CholeskyVariationalDistribution(\n",
"            N,\n",
"            batch_shape=torch.Size([num_latents]),\n",
"        )\n",
"        variational_strategy = CGBlkdiagLmcVariationalStrategy(\n",
"            VariationalStrategy(\n",
"                self,\n",
"                inducing_points,\n",
"                variational_distribution,\n",
"                learn_inducing_locations=True,\n",
"            ),\n",
"            num_quantiles=num_quantiles,\n",
"            num_latents=num_latents,\n",
"            num_lower_quantiles=num_lower_quantiles,\n",
"            num_lower_latents=num_lower_latents,\n",
"        )\n",
"\n",
"        mean_module = CenterGapMean(\n",
"            ConstantMean(batch_shape=torch.Size([1])),\n",
"            ConstantMean(batch_shape=torch.Size([num_latents - 1])),\n",
"            latent_dim=-1,\n",
"        )\n",
"        covar_module = ScaleKernel(\n",
"            RBFKernel(ard_num_dims=D, batch_shape=torch.Size([num_latents])),\n",
"            batch_shape=torch.Size([num_latents]),\n",
"        )\n",
"\n",
"        # Register the interval constraint on the raw length scale, then set the\n",
"        # starting length scales (must be inside the interval, checked above).\n",
"        covar_module.base_kernel.register_constraint(\n",
"            \"raw_lengthscale\", Interval(lower, upper)\n",
"        )\n",
"        with torch.no_grad():\n",
"            covar_module.base_kernel.lengthscale = init\n",
"\n",
"        super().__init__(\n",
"            variational_strategy, mean_module, covar_module, -1, num_lower_quantiles\n",
"        )\n",
"\n",
"\n",
"# Inducing inputs: every (x1, x2) pair from a 10-point x1 grid and five x2 levels.\n",
"inducing_points = torch.cartesian_prod(\n",
"    torch.linspace(0, 1, 10),\n",
"    torch.tensor([0.1, 0.3, 0.5, 0.7, 0.9]),\n",
").to(device)\n",
"central_q_index = (q - 0.5).abs().argmin().item()\n",
"num_latents = len(q)\n",
"gp_constraint = MyGP_Constraint(\n",
"    inducing_points,\n",
"    len(q),\n",
"    central_q_index,\n",
"    num_latents,\n",
"    num_latents // 2,\n",
"    # x1 length scale bounded below by 0.2; x2 only confined to (0, 100).\n",
"    lengthscale_lower=(0.2, 0.0),\n",
"    lengthscale_upper=(1e2, 1e2),\n",
"    initial_lengthscale=(0.5, 0.5),\n",
").to(device)\n",
"likelihood_constraint = MultitaskCenterGapQuantileGPLikelihood(q, central_q_index).to(\n",
"    device\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "13",
"metadata": {},
"outputs": [],
"source": [
"gp_constraint.train()\n",
"likelihood_constraint.train()\n",
"mll = VariationalELBO(likelihood_constraint, gp_constraint, num_data=y.numel())\n",
"optimizer = torch.optim.Adam(\n",
"    [*gp_constraint.parameters(), *likelihood_constraint.parameters()],\n",
"    lr=0.001,\n",
")\n",
"\n",
"# Full-batch training: each step minimizes the negative ELBO on all observations.\n",
"for _ in range(n_epochs):\n",
"    loss = -mll(gp_constraint(x), y)\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    optimizer.zero_grad()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14",
"metadata": {},
"outputs": [],
"source": [
"# Predicted quantile curves on the grid: the model sees x_pred flattened to\n",
"# (points, D) and the result is reshaped back to the grid layout of x_pred.\n",
"# NOTE(review): the second argument of mean_quantiles_mc (1) is presumably the\n",
"# number of Monte Carlo samples -- confirm against gpytorch_qr.\n",
"gp_constraint.eval()\n",
"with torch.no_grad():\n",
"    flat_q_constraint = gp_constraint.mean_quantiles_mc(x_pred.flatten(0, -2), 1)\n",
"    mean_q_constraint = flat_q_constraint.unflatten(-2, x_pred.shape[:-1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"colors = plt.cm.tab10.colors\n",
"z_q = Normal(0, 1).icdf(q)  # standard-normal quantiles of the target levels\n",
"\n",
"fig, axes = plt.subplots(1, len(x2_values), sharey=True)\n",
"fig.suptitle(\"With length-scale constraint\")\n",
"\n",
"for i, x2_val in enumerate(x2_values):\n",
"    ax = axes[i]\n",
"\n",
"    mask = x[:, 1] == x2_val\n",
"    ax.scatter(x[mask, 0].cpu(), y[mask].cpu(), c=\"gray\", marker=\".\", alpha=0.1)\n",
"\n",
"    x_line = x_pred[i]\n",
"    x1_line = x_line[..., 0].cpu()\n",
"    true_q_lines = mean(x_line).unsqueeze(-1) + std(x_line).unsqueeze(-1) * z_q\n",
"\n",
"    # Dashed black: true quantiles (labelled once so the legend has one entry).\n",
"    for j in range(len(q)):\n",
"        ax.plot(\n",
"            x1_line,\n",
"            true_q_lines[..., j].cpu(),\n",
"            \"--\",\n",
"            color=\"k\",\n",
"            label=\"true\" if j == 0 else \"_nolegend_\",\n",
"        )\n",
"    # Solid: fitted quantiles, one colour per level (cycled if there are many).\n",
"    for j in range(len(q)):\n",
"        ax.plot(\n",
"            x1_line,\n",
"            mean_q_constraint[i, ..., j].cpu(),\n",
"            color=colors[j % len(colors)],\n",
"            label=f\"q = {q[j].item():g}\",\n",
"        )\n",
"\n",
"    ax.set_title(f\"x2 = {x2_val:.1f}\")\n",
"    ax.set_xlabel(\"x1\")\n",
"    if i == 0:\n",
"        ax.set_ylabel(\"y\")\n",
"\n",
"axes[-1].legend(fontsize=\"small\")\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "heavyedge",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}