{ "cells": [ { "cell_type": "markdown", "id": "0", "metadata": {}, "source": [ "# Multitask GPQR (Center-Gap)" ] }, { "cell_type": "code", "execution_count": null, "id": "1", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "import torch\n", "from torch.distributions import Normal\n", "from gpytorch.variational import CholeskyVariationalDistribution\n", "from gpytorch.variational import UnwhitenedVariationalStrategy\n", "from gpytorch.means import ConstantMean, Mean\n", "from gpytorch.kernels import RBFKernel, ScaleKernel\n", "from gpytorch.mlls import VariationalELBO\n", "import matplotlib.pyplot as plt\n", "\n", "from gpytorch_qr.means import CenterGapMean\n", "from gpytorch_qr.models import CenterGapQuantileGP\n", "from gpytorch_qr.variational import CGLmcVariationalStrategy\n", "from gpytorch_qr.likelihoods import MultitaskCenterGapQuantileGPLikelihood\n", "\n", "try:\n", " import sys\n", "\n", " sys.path.insert(0, os.path.abspath(\"..\"))\n", "\n", " import config_notebook\n", "except ImportError:\n", " print(\"Output will not be deterministic SVG.\")\n", "\n", "torch.manual_seed(42)\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "n_epochs = int(os.getenv(\"GPYTORCHQR_N_EPOCHS\", 10000))" ] }, { "cell_type": "markdown", "id": "2", "metadata": {}, "source": [ "## Data preparation" ] }, { "cell_type": "code", "execution_count": null, "id": "3", "metadata": {}, "outputs": [], "source": [ "def mean(x):\n", " return torch.cos(x * 2 * 3.14)\n", "\n", "\n", "def std(x):\n", " return x + 0.1\n", "\n", "\n", "x = torch.linspace(0, 1, 100).reshape(-1, 1).to(device)\n", "y = (mean(x) + torch.randn(x.shape, device=device).mul(std(x))).squeeze()\n", "q = torch.linspace(0.1, 0.9, 9).to(device)\n", "true_quantiles = mean(x) + std(x) * Normal(0, 1).icdf(q)\n", "x_pred = torch.linspace(0, 1.5, 100).reshape(-1, 1).to(device)" ] }, { "cell_type": "code", "execution_count": null, "id": "4", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.10.8, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.scatter(x.cpu(), y.cpu(), c=\"k\", marker=\".\")\n", "plt.plot(x.cpu(), true_quantiles.cpu(), \"--\", c=\"gray\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "5", "metadata": {}, "source": [ "## Prior mean" ] }, { "cell_type": "code", "execution_count": null, "id": "6", "metadata": {}, "outputs": [], "source": [ "class PriorMean(Mean):\n", " def __init__(self, batch_shape=torch.Size()):\n", " super().__init__()\n", " self.batch_shape = batch_shape\n", "\n", " def forward(self, x):\n", " return mean(x).squeeze(-1).expand(*self.batch_shape, x.shape[-2])\n", "\n", "\n", "prior_mean = PriorMean().to(device)" ] }, { "cell_type": "code", "execution_count": null, "id": "7", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.10.8, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.scatter(x.cpu(), y.cpu(), c=\"k\", marker=\".\")\n", "plt.plot(x_pred.cpu(), prior_mean(x_pred).detach().cpu())\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "8", "metadata": {}, "source": [ "## Define models and likelihoods" ] }, { "cell_type": "code", "execution_count": null, "id": "9", "metadata": {}, "outputs": [], "source": [ "class MyGP_PriorMean(CenterGapQuantileGP):\n", " def __init__(\n", " self,\n", " inducing_points,\n", " num_quantiles,\n", " num_lower_quantiles,\n", " num_latents,\n", " ):\n", " N, D = inducing_points.size()\n", " variational_distribution = CholeskyVariationalDistribution(\n", " N,\n", " batch_shape=torch.Size([num_latents]),\n", " )\n", " variational_strategy = CGLmcVariationalStrategy(\n", " UnwhitenedVariationalStrategy(\n", " self,\n", " inducing_points,\n", " variational_distribution,\n", " learn_inducing_locations=True,\n", " ),\n", " num_quantiles=num_quantiles,\n", " num_latents=num_latents,\n", " )\n", "\n", " mean = CenterGapMean(\n", " PriorMean(batch_shape=torch.Size([1])),\n", " ConstantMean(batch_shape=torch.Size([num_latents - 1])),\n", " latent_dim=-1,\n", " )\n", " covar = ScaleKernel(\n", " RBFKernel(ard_num_dims=D, batch_shape=torch.Size([num_latents])),\n", " batch_shape=torch.Size([num_latents]),\n", " )\n", " super().__init__(variational_strategy, mean, covar, -1, num_lower_quantiles)\n", "\n", "\n", "inducing_points = x.detach().clone()\n", "central_q_index = (q - 0.5).abs().argmin().item()\n", "num_latents = len(q) - 2\n", "gp_priormean = MyGP_PriorMean(inducing_points, len(q), central_q_index, num_latents).to(\n", " device\n", ")\n", "likelihood_priormean = MultitaskCenterGapQuantileGPLikelihood(q, central_q_index).to(\n", " device\n", ")" ] }, { "cell_type": "markdown", "id": "10", "metadata": {}, "source": [ "## Train" ] }, { "cell_type": "code", "execution_count": null, "id": "11", "metadata": {}, "outputs": [], "source": [ "gp_priormean.train()\n", "likelihood_priormean.train()\n", "mll = VariationalELBO(likelihood_priormean, gp_priormean, num_data=y.numel())\n", "optimizer = torch.optim.Adam(\n", " list(gp_priormean.parameters()) + list(likelihood_priormean.parameters()),\n", " lr=0.001,\n", ")\n", "\n", "for _ in range(n_epochs):\n", " output = gp_priormean(x)\n", " loss = -mll(output, y)\n", " loss.backward()\n", " optimizer.step()\n", " optimizer.zero_grad()" ] }, { "cell_type": "markdown", "id": "12", "metadata": {}, "source": [ "## Evaluate" ] }, { "cell_type": "code", "execution_count": null, "id": "13", "metadata": {}, "outputs": [], "source": [ "gp_priormean.eval()\n", "with torch.no_grad():\n", " quantiles_priormean = gp_priormean.mean_quantiles_mc(x_pred)" ] }, { "cell_type": "markdown", "id": "14", "metadata": {}, "source": [ "## Plot result" ] }, { "cell_type": "code", "execution_count": null, "id": "15", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.10.8, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "colors = plt.cm.tab10(range(len(q)))\n", "\n", "plt.scatter(x.cpu(), y.cpu(), c=\"gray\", marker=\".\", alpha=0.1)\n", "plt.plot(x.cpu(), true_quantiles.cpu(), \"--\", c=\"gray\", alpha=0.5)\n", "\n", "for i in range(len(q)):\n", " plt.plot(x_pred.cpu(), quantiles_priormean[:, i].cpu(), color=colors[i])\n", "\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "heavyedge", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.13" } }, "nbformat": 4, "nbformat_minor": 5 }