{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {},
"source": [
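"# Multitask GPQR\n",
"\n",
"Multitask Gaussian-process quantile regression (GPQR) on synthetic heteroscedastic data: nine quantile levels are fitted jointly with an LMC variational strategy and a prior mean built on the true mean function."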
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"\n",
"import torch\n",
"from torch.distributions import Normal\n",
"from gpytorch.kernels import RBFKernel, ScaleKernel\n",
"from gpytorch.means import Mean\n",
"from gpytorch.mlls import VariationalELBO\n",
"from gpytorch.variational import (\n",
"    CholeskyVariationalDistribution,\n",
"    LMCVariationalStrategy,\n",
"    UnwhitenedVariationalStrategy,\n",
")\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from gpytorch_qr.likelihoods import MultitaskQuantileGPLikelihood\n",
"from gpytorch_qr.models import DirectQuantileGP\n",
"\n",
"# Make the parent directory importable so an optional `config_notebook` can be\n",
"# loaded; it is imported only for its side effects (the name is unused below).\n",
"sys.path.insert(0, os.path.abspath(\"..\"))\n",
"try:\n",
"    import config_notebook  # noqa: F401\n",
"except ImportError:\n",
"    print(\"Output will not be deterministic SVG.\")\n",
"\n",
"# Seed torch's RNG so data generation and model initialisation are repeatable.\n",
"torch.manual_seed(42)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Number of optimiser steps; override with the GPYTORCHQR_N_EPOCHS env var.\n",
"n_epochs = int(os.getenv(\"GPYTORCHQR_N_EPOCHS\", 10000))"
]
},
{
"cell_type": "markdown",
"id": "2",
"metadata": {},
"source": [
"## Data preparation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"def mean(x):\n",
"    \"\"\"True conditional mean of the synthetic data, cos(2*pi*x) (period 1).\"\"\"\n",
"    # Use the exact constant rather than the 3.14 approximation of pi.\n",
"    return torch.cos(2 * torch.pi * x)\n",
"\n",
"\n",
"def std(x):\n",
"    \"\"\"True conditional standard deviation; grows linearly with x (heteroscedastic noise).\"\"\"\n",
"    return x + 0.1\n",
"\n",
"\n",
"# 100 training inputs on [0, 1] as a column vector of shape (100, 1).\n",
"x = torch.linspace(0, 1, 100).reshape(-1, 1).to(device)\n",
"# Gaussian observations y ~ N(mean(x), std(x)^2), flattened to shape (100,).\n",
"y = (mean(x) + torch.randn(x.shape, device=device).mul(std(x))).squeeze()\n",
"# Nine quantile levels 0.1, 0.2, ..., 0.9.\n",
"q = torch.linspace(0.1, 0.9, 9).to(device)\n",
"# Analytic quantiles of N(mean(x), std(x)^2); (100, 1) x (9,) broadcasts to (100, 9).\n",
"true_quantiles = mean(x) + std(x) * Normal(0, 1).icdf(q)\n",
"# Prediction grid extends past the training range [0, 1] to show extrapolation.\n",
"x_pred = torch.linspace(0, 1.5, 100).reshape(-1, 1).to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"ax.scatter(x.cpu(), y.cpu(), c=\"k\", marker=\".\")  # noisy observations\n",
"ax.plot(x.cpu(), true_quantiles.cpu(), \"--\", c=\"gray\")  # analytic quantiles\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"class PriorMean(Mean):\n",
"    \"\"\"Fixed prior curve ``mean(x)`` plus a learnable constant offset per batch member.\n",
"\n",
"    Args:\n",
"        batch_shape: Shape of the learnable ``offset`` parameter (e.g. one offset\n",
"            per latent GP). Defaults to a single scalar offset.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, batch_shape=torch.Size([])):\n",
"        super().__init__()\n",
"        self.batch_shape = batch_shape\n",
"        # Pass the shape as one argument: unpacking the default empty Size\n",
"        # would call `torch.zeros()` with no size at all and fail.\n",
"        self.register_parameter(\"offset\", torch.nn.Parameter(torch.zeros(batch_shape)))\n",
"\n",
"    def forward(self, x):\n",
"        # Assumes x is (..., n, 1) so squeezing the last dim gives (..., n) - TODO confirm for D > 1.\n",
"        res = mean(x).squeeze(-1)\n",
"        # offset (*batch_shape,) -> (*batch_shape, 1) broadcasts across the n points.\n",
"        return res + self.offset.unsqueeze(-1)\n",
"\n",
"\n",
"prior_mean = PriorMean(batch_shape=torch.Size([len(q)])).to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# All len(q) curves coincide here because every offset is initialised to zero.\n",
"prior_curves = prior_mean(x_pred).detach().cpu().T\n",
"fig, ax = plt.subplots()\n",
"ax.scatter(x.cpu(), y.cpu(), c=\"k\", marker=\".\")\n",
"ax.plot(x_pred.cpu(), prior_curves)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "7",
"metadata": {},
"source": [
"## Define models and likelihoods"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"metadata": {},
"outputs": [],
"source": [
"class MyGP_PriorMean(DirectQuantileGP):\n",
"    \"\"\"Multitask quantile GP: `num_latents` latent GPs mixed into `num_quantiles`\n",
"    task outputs by an LMC variational strategy.\n",
"\n",
"    Each latent GP uses a `PriorMean` mean and a scaled ARD RBF kernel; the\n",
"    inducing point locations are not learned.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, inducing_points, num_quantiles, num_latents):\n",
"        num_inducing, input_dim = inducing_points.size()\n",
"        latent_batch = torch.Size([num_latents])\n",
"\n",
"        # One Cholesky-parameterised variational distribution per latent GP.\n",
"        per_latent_strategy = UnwhitenedVariationalStrategy(\n",
"            self,\n",
"            inducing_points,\n",
"            CholeskyVariationalDistribution(num_inducing, batch_shape=latent_batch),\n",
"            learn_inducing_locations=False,\n",
"        )\n",
"        # Mix the latent GPs linearly into one output per quantile task.\n",
"        variational_strategy = LMCVariationalStrategy(\n",
"            per_latent_strategy,\n",
"            num_tasks=num_quantiles,\n",
"            num_latents=num_latents,\n",
"        )\n",
"        mean_module = PriorMean(batch_shape=latent_batch)\n",
"        covar_module = ScaleKernel(\n",
"            RBFKernel(ard_num_dims=input_dim, batch_shape=latent_batch),\n",
"            batch_shape=latent_batch,\n",
"        )\n",
"        # NOTE(review): the meaning of the trailing -1 is defined by DirectQuantileGP - confirm.\n",
"        super().__init__(variational_strategy, mean_module, covar_module, -1)\n",
"\n",
"\n",
"# The training inputs double as the (fixed) inducing points.\n",
"inducing_points = x.detach().clone()\n",
"# Two fewer latent GPs than quantile tasks.\n",
"num_latents = len(q) - 2\n",
"gp_priormean = MyGP_PriorMean(inducing_points, len(q), num_latents).to(device)\n",
"likelihood_priormean = MultitaskQuantileGPLikelihood(q).to(device)"
]
},
{
"cell_type": "markdown",
"id": "9",
"metadata": {},
"source": [
"## Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10",
"metadata": {},
"outputs": [],
"source": [
"gp_priormean.train()\n",
"likelihood_priormean.train()\n",
"mll = VariationalELBO(likelihood_priormean, gp_priormean, num_data=y.numel())\n",
"trainable = [*gp_priormean.parameters(), *likelihood_priormean.parameters()]\n",
"optimizer = torch.optim.Adam(trainable, lr=0.001)\n",
"\n",
"# Full-batch minimisation of the negative ELBO for `n_epochs` steps.\n",
"for _epoch in range(n_epochs):\n",
"    neg_elbo = -mll(gp_priormean(x), y)\n",
"    neg_elbo.backward()\n",
"    optimizer.step()\n",
"    optimizer.zero_grad()"
]
},
{
"cell_type": "markdown",
"id": "11",
"metadata": {},
"source": [
"## Evaluate"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12",
"metadata": {},
"outputs": [],
"source": [
"# Put both modules in evaluation mode, mirroring the paired .train() calls above.\n",
"gp_priormean.eval()\n",
"likelihood_priormean.eval()\n",
"with torch.no_grad():\n",
"    # Predicted quantile curves on the (extended) prediction grid.\n",
"    quantiles_priormean = gp_priormean.mean_quantiles(x_pred)"
]
},
{
"cell_type": "markdown",
"id": "13",
"metadata": {},
"source": [
"## Plot result"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Wrap around tab10's colours so any number of quantile levels gets a distinct-as-possible colour.\n",
"colors = [plt.cm.tab10(i % plt.cm.tab10.N) for i in range(len(q))]\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.scatter(x.cpu(), y.cpu(), c=\"gray\", marker=\".\", alpha=0.1)\n",
"ax.plot(x.cpu(), true_quantiles.cpu(), \"--\", c=\"gray\", alpha=0.5)\n",
"\n",
"# Assumes quantiles_priormean is (len(x_pred), len(q)): one column per level - TODO confirm.\n",
"for i, level in enumerate(q.tolist()):\n",
"    ax.plot(\n",
"        x_pred.cpu(),\n",
"        quantiles_priormean[:, i].cpu(),\n",
"        color=colors[i],\n",
"        label=f\"{level:g}\",\n",
"    )\n",
"\n",
"ax.set(xlabel=\"x\", ylabel=\"y\", title=\"Predicted (solid) vs. true (dashed) quantiles\")\n",
"ax.legend(title=\"Quantile level\", fontsize=\"small\")\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "heavyedge",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}