{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {},
"source": [
"# Multitask Independent GPQR"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import os\n",
"\n",
"import torch\n",
"from torch.distributions import Normal\n",
"from gpytorch.variational import CholeskyVariationalDistribution\n",
"from gpytorch.variational import (\n",
"    VariationalStrategy,\n",
"    IndependentMultitaskVariationalStrategy,\n",
")\n",
"from gpytorch.means import ConstantMean\n",
"from gpytorch.kernels import RBFKernel, ScaleKernel\n",
"from gpytorch.mlls import VariationalELBO\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from gpytorch_qr.models import DirectQuantileGP\n",
"from gpytorch_qr.likelihoods import MultitaskQuantileGPLikelihood\n",
"\n",
"# Optional notebook configuration from the parent directory. The fallback message\n",
"# suggests it makes the SVG output deterministic (NOTE(review): confirm).\n",
"try:\n",
"    import sys\n",
"\n",
"    sys.path.insert(0, os.path.abspath(\"..\"))\n",
"    import config_notebook  # not referenced below; imported only for its import-time effects\n",
"except ImportError:\n",
"    print(\"Output will not be deterministic SVG.\")\n",
"\n",
"# Reproducibility and hardware selection.\n",
"torch.manual_seed(42)\n",
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"\n",
"# Number of optimisation steps; override with the GPYTORCHQR_N_EPOCHS environment variable.\n",
"n_epochs = int(os.environ.get(\"GPYTORCHQR_N_EPOCHS\", \"5000\"))"
]
},
{
"cell_type": "markdown",
"id": "2",
"metadata": {},
"source": [
"## Data preparation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"def mean(x):\n",
"    \"\"\"True conditional mean of the toy data: cos(2*pi*x), one period on [0, 1].\"\"\"\n",
"    return torch.cos(2 * math.pi * x)\n",
"\n",
"\n",
"def std(x):\n",
"    \"\"\"True conditional standard deviation: x + 0.1 (noise grows with x).\"\"\"\n",
"    return x + 0.1\n",
"\n",
"\n",
"n_grid = 100  # distinct input locations on [0, 1]\n",
"n_repeats = 5  # noisy observations drawn at each input location\n",
"\n",
"x_range = torch.linspace(0, 1, n_grid).reshape(-1, 1).to(device)\n",
"x = x_range.repeat(n_repeats, 1)\n",
"# Gaussian observations: y ~ N(mean(x), std(x)^2).\n",
"y = (mean(x) + torch.randn(x.shape, device=device).mul(std(x))).squeeze()\n",
"# Quantile levels to estimate (one task per level).\n",
"q = torch.tensor([0.1, 0.25, 0.5, 0.75, 0.9]).to(device)\n",
"# Analytic quantiles of N(mean, std^2) on the grid; broadcasts to shape (n_grid, len(q)).\n",
"true_quantiles = mean(x_range) + std(x_range) * Normal(0, 1).icdf(q)\n",
"# Prediction grid extends past the training range [0, 1] to show extrapolation.\n",
"x_pred = torch.linspace(0, 1.5, 100).reshape(-1, 1).to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"ax.scatter(x.cpu(), y.cpu(), c=\"k\", marker=\".\")  # noisy observations\n",
"ax.plot(x_range.cpu(), true_quantiles.cpu(), \"--\", c=\"gray\")  # true quantile curves\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "5",
"metadata": {},
"source": [
"## Define model and likelihood\n",
"\n",
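"Here each quantile (task) is modelled by its own independent latent GP: `IndependentMultitaskVariationalStrategy` is used with `num_tasks` equal to the number of quantiles, so the correlation between quantiles is not modelled. To model that correlation, the number of latent GPs should instead be smaller than the number of tasks (= number of quantiles), e.g. with `LMCVariationalStrategy`."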
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"metadata": {},
"outputs": [],
"source": [
"class MyGP(DirectQuantileGP):\n",
"    \"\"\"Sparse variational GP with one independent latent GP per quantile level.\n",
"\n",
"    Args:\n",
"        inducing_points: ``(num_inducing, num_dims)`` tensor of initial inducing\n",
"            inputs; their locations are optimised during training.\n",
"        num_quantiles: number of quantile levels, i.e. number of tasks.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, inducing_points, num_quantiles):\n",
"        num_inducing, num_dims = inducing_points.shape\n",
"        # A batch dimension of size num_quantiles gives every quantile its own\n",
"        # variational posterior, constant mean and kernel hyperparameters.\n",
"        per_quantile = torch.Size([num_quantiles])\n",
"\n",
"        base_strategy = VariationalStrategy(\n",
"            self,\n",
"            inducing_points,\n",
"            CholeskyVariationalDistribution(num_inducing, batch_shape=per_quantile),\n",
"            learn_inducing_locations=True,\n",
"        )\n",
"        # Present the batch of independent latent GPs as a single multitask output.\n",
"        variational_strategy = IndependentMultitaskVariationalStrategy(\n",
"            base_strategy, num_tasks=num_quantiles\n",
"        )\n",
"\n",
"        mean_module = ConstantMean(batch_shape=per_quantile)\n",
"        covar_module = ScaleKernel(\n",
"            RBFKernel(ard_num_dims=num_dims, batch_shape=per_quantile),\n",
"            batch_shape=per_quantile,\n",
"        )\n",
"        # NOTE(review): the trailing -1 is forwarded positionally to DirectQuantileGP;\n",
"        # presumably the task/quantile dimension of the output -- confirm its meaning.\n",
"        super().__init__(variational_strategy, mean_module, covar_module, -1)\n",
"\n",
"\n",
"# 10 initial inducing inputs spread evenly over the training range [0, 1].\n",
"inducing_points = torch.linspace(0, 1, 10).reshape(-1, 1).to(device)\n",
"gp = MyGP(inducing_points, len(q)).to(device)\n",
"likelihood = MultitaskQuantileGPLikelihood(q).to(device)"
]
},
{
"cell_type": "markdown",
"id": "7",
"metadata": {},
"source": [
"## Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"metadata": {},
"outputs": [],
"source": [
"gp.train()\n",
"likelihood.train()\n",
"\n",
"# num_data is the total number of training observations.\n",
"mll = VariationalELBO(likelihood, gp, num_data=y.numel())\n",
"trainable_params = [*gp.parameters(), *likelihood.parameters()]\n",
"optimizer = torch.optim.Adam(trainable_params, lr=0.001)\n",
"\n",
"# Full-batch optimisation: every step uses all observations.\n",
"for epoch in range(n_epochs):\n",
"    loss = -mll(gp(x), y)  # minimise the negative ELBO\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    optimizer.zero_grad()"
]
},
{
"cell_type": "markdown",
"id": "9",
"metadata": {},
"source": [
"## Plot result"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10",
"metadata": {},
"outputs": [],
"source": [
"gp.eval()\n",
"# Levels 0.025 and 0.975 bracket a central 95% band; presumably these are quantiles\n",
"# of the posterior over each estimated quantile curve (NOTE(review): confirm).\n",
"band_levels = torch.tensor([0.025, 0.975], device=device)\n",
"with torch.no_grad():\n",
"    mean_q = gp.mean_quantiles(x_pred)\n",
"    lower_q, upper_q = gp.quantile_quantiles(x_pred, band_levels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"colors = plt.cm.tab10.colors\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.scatter(x.cpu(), y.cpu(), c=\"gray\", marker=\".\", alpha=0.1)  # observations\n",
"ax.plot(x_range.cpu(), true_quantiles.cpu(), \"--\", c=\"k\")  # true quantiles on [0, 1]\n",
"\n",
"x_pred_cpu = x_pred.cpu().squeeze()\n",
"for i, level in enumerate(q.tolist()):\n",
"    # Cycle the palette so more than len(colors) quantile levels do not raise IndexError.\n",
"    color = colors[i % len(colors)]\n",
"    # Estimated quantile curve and its 2.5%-97.5% band.\n",
"    ax.plot(x_pred_cpu, mean_q[:, i].cpu(), color=color, label=f\"q = {level:g}\")\n",
"    ax.fill_between(\n",
"        x_pred_cpu,\n",
"        lower_q[:, i].cpu(),\n",
"        upper_q[:, i].cpu(),\n",
"        color=color,\n",
"        alpha=0.3,\n",
"    )\n",
"ax.set(xlabel=\"x\", ylabel=\"y\", title=\"Estimated (solid) vs. true (dashed) quantiles\")\n",
"ax.legend()\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "heavyedge",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}