{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {},
"source": [
"# Batch Independent GPQR"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import torch\n",
"from torch.distributions import Normal\n",
"from gpytorch.variational import CholeskyVariationalDistribution\n",
"from gpytorch.variational import VariationalStrategy\n",
"from gpytorch.means import ConstantMean\n",
"from gpytorch.kernels import RBFKernel, ScaleKernel\n",
"from gpytorch.mlls import VariationalELBO\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from gpytorch_qr.models import DirectQuantileGP\n",
"from gpytorch_qr.likelihoods import BatchQuantileGPLikelihood\n",
"\n",
"try:\n",
" import sys\n",
"\n",
" sys.path.insert(0, os.path.abspath(\"..\"))\n",
"\n",
" import config_notebook\n",
"except ImportError:\n",
" print(\"Output will not be deterministic SVG.\")\n",
"\n",
"torch.manual_seed(42)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"n_epochs = int(os.getenv(\"GPYTORCHQR_N_EPOCHS\", 5000))"
]
},
{
"cell_type": "markdown",
"id": "2",
"metadata": {},
"source": [
"## Data preparation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"def mean(x):\n",
" return torch.cos(x[..., 0] * 2 * 3.14) * torch.cos(x[..., 1] * 2 * 3.14)\n",
"\n",
"\n",
"def std(x):\n",
" return x[..., 0] + x[..., 1] + 0.1\n",
"\n",
"\n",
"x2_values = torch.tensor([0.1, 0.5, 0.9]).to(device)\n",
"n_per_x2 = 500\n",
"x = torch.stack(\n",
" [\n",
" torch.rand(n_per_x2 * len(x2_values)).to(device),\n",
" x2_values.repeat_interleave(n_per_x2),\n",
" ],\n",
" dim=1,\n",
")\n",
"\n",
"y = (mean(x) + torch.randn(x.shape[0], device=device).mul(std(x))).squeeze()\n",
"\n",
"q = torch.tensor([0.1, 0.25, 0.5, 0.75, 0.9]).to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"metadata": {},
"outputs": [],
"source": [
"x1_pred = torch.linspace(0, 1, 100).to(device)\n",
"x_pred = torch.stack(\n",
" [\n",
" x1_pred.unsqueeze(0).expand(len(x2_values), -1),\n",
" x2_values.unsqueeze(1).expand(-1, len(x1_pred)),\n",
" ],\n",
" dim=-1,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, axes = plt.subplots(1, len(x2_values), sharey=True)\n",
"\n",
"for i, x2_val in enumerate(x2_values):\n",
" ax = axes[i]\n",
"\n",
" mask = x[:, 1] == x2_val\n",
" ax.scatter(x[mask, 0].cpu(), y[mask].cpu(), c=\"k\", marker=\".\")\n",
"\n",
" x_line = x_pred[i]\n",
" true_q_lines = mean(x_line).unsqueeze(-1) + std(x_line).unsqueeze(-1) * Normal(\n",
" 0, 1\n",
" ).icdf(q)\n",
" ax.plot(x_line[..., 0].cpu(), true_q_lines.cpu(), \"--\", color=\"gray\")\n",
"\n",
" ax.set_title(f\"x2 = {x2_val:.1f}\")\n",
" ax.set_xlabel(\"x1\")\n",
" if i == 0:\n",
" ax.set_ylabel(\"y\")\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "6",
"metadata": {},
"source": [
"## Define model and likelihood"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {},
"outputs": [],
"source": [
"class MyGP(DirectQuantileGP):\n",
" def __init__(self, inducing_points, num_quantiles):\n",
" N, D = inducing_points.size()\n",
" variational_distribution = CholeskyVariationalDistribution(\n",
" N,\n",
" batch_shape=torch.Size([num_quantiles]),\n",
" )\n",
" variational_strategy = VariationalStrategy(\n",
" self,\n",
" inducing_points,\n",
" variational_distribution,\n",
" learn_inducing_locations=True,\n",
" )\n",
" mean = ConstantMean(batch_shape=torch.Size([num_quantiles]))\n",
" covar = ScaleKernel(\n",
" RBFKernel(ard_num_dims=D, batch_shape=torch.Size([num_quantiles])),\n",
" batch_shape=torch.Size([num_quantiles]),\n",
" )\n",
" super().__init__(variational_strategy, mean, covar, 0)\n",
"\n",
"\n",
"g1, g2 = torch.meshgrid(\n",
" torch.linspace(0, 1, 10),\n",
" torch.tensor([0.1, 0.3, 0.5, 0.7, 0.9]),\n",
" indexing=\"ij\",\n",
")\n",
"inducing_points = torch.stack([g1.flatten(), g2.flatten()], dim=1).to(device)\n",
"gp = MyGP(inducing_points, len(q)).to(device)\n",
"likelihood = BatchQuantileGPLikelihood(q).to(device)"
]
},
{
"cell_type": "markdown",
"id": "8",
"metadata": {},
"source": [
"## Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {},
"outputs": [],
"source": [
"gp.train()\n",
"likelihood.train()\n",
"mll = VariationalELBO(likelihood, gp, num_data=y.numel())\n",
"optimizer = torch.optim.Adam(\n",
" list(gp.parameters()) + list(likelihood.parameters()),\n",
" lr=0.001,\n",
")\n",
"\n",
"for _ in range(n_epochs):\n",
" output = gp(x)\n",
" loss = -mll(output, y).sum()\n",
" loss.backward()\n",
" optimizer.step()\n",
" optimizer.zero_grad()"
]
},
{
"cell_type": "markdown",
"id": "10",
"metadata": {},
"source": [
"## Plot result"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11",
"metadata": {},
"outputs": [],
"source": [
"gp.eval()\n",
"with torch.no_grad():\n",
" mean_q = gp.mean_quantiles(x_pred.flatten(0, -2)).unflatten(-1, x_pred.shape[:-1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"colors = plt.cm.tab10.colors\n",
"\n",
"fig, axes = plt.subplots(1, len(x2_values), sharey=True)\n",
"\n",
"for i, x2_val in enumerate(x2_values):\n",
" ax = axes[i]\n",
"\n",
" mask = x[:, 1] == x2_val\n",
" ax.scatter(x[mask, 0].cpu(), y[mask].cpu(), c=\"gray\", marker=\".\", alpha=0.1)\n",
"\n",
" x_line = x_pred[i]\n",
" true_q_lines = mean(x_line).unsqueeze(-1) + std(x_line).unsqueeze(-1) * Normal(\n",
" 0, 1\n",
" ).icdf(q)\n",
"\n",
" ax.plot(x_line[..., 0].cpu(), true_q_lines.cpu(), \"--\", color=\"k\")\n",
"\n",
" for j in range(len(q)):\n",
" ax.plot(x_line[..., 0].cpu(), mean_q[j, i, ...].cpu(), color=colors[j])\n",
"\n",
" ax.set_title(f\"x2 = {x2_val:.1f}\")\n",
" ax.set_xlabel(\"x1\")\n",
" if i == 0:\n",
" ax.set_ylabel(\"y\")\n",
"\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "heavyedge",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}