{ "cells": [ { "cell_type": "markdown", "id": "0", "metadata": {}, "source": [ "# Length Scale Constraint\n", "\n", "When data is deficient, quantiles are likely to overfit.\n", "This can be mitigated by setting lower bounds to kernel length scales.\n", "\n", "In addition, we use block diagonal LMC structure to reduce the number of parameters." ] }, { "cell_type": "code", "execution_count": null, "id": "1", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "import torch\n", "from gpytorch.constraints import Interval\n", "from torch.distributions import Normal\n", "from gpytorch.variational import CholeskyVariationalDistribution\n", "from gpytorch.variational import VariationalStrategy\n", "from gpytorch.means import ConstantMean\n", "from gpytorch.kernels import RBFKernel, ScaleKernel\n", "from gpytorch.mlls import VariationalELBO\n", "import matplotlib.pyplot as plt\n", "\n", "from gpytorch_qr.means import CenterGapMean\n", "from gpytorch_qr.models import CenterGapQuantileGP\n", "from gpytorch_qr.variational import CGBlkdiagLmcVariationalStrategy\n", "from gpytorch_qr.likelihoods import MultitaskCenterGapQuantileGPLikelihood\n", "\n", "try:\n", " import sys\n", "\n", " sys.path.insert(0, os.path.abspath(\"..\"))\n", "\n", " import config_notebook\n", "except ImportError:\n", " print(\"Output will not be deterministic SVG.\")\n", "\n", "torch.manual_seed(42)\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "n_epochs = int(os.getenv(\"GPYTORCHQR_N_EPOCHS\", 10000))" ] }, { "cell_type": "markdown", "id": "2", "metadata": {}, "source": [ "## Data preparation" ] }, { "cell_type": "code", "execution_count": null, "id": "3", "metadata": {}, "outputs": [], "source": [ "def mean(x):\n", " return torch.cos(x[..., 0] * 2 * 3.14) * torch.cos(x[..., 1] * 2 * 3.14)\n", "\n", "\n", "def std(x):\n", " return x[..., 0] + x[..., 1] + 0.1\n", "\n", "\n", "x2_values = torch.tensor([0.1, 0.5, 0.9]).to(device)\n", "n_per_x2 = 20\n", "x = torch.stack(\n", " [\n", " torch.rand(n_per_x2 * len(x2_values)).to(device),\n", " x2_values.repeat_interleave(n_per_x2),\n", " ],\n", " dim=1,\n", ")\n", "\n", "y = (mean(x) + torch.randn(x.shape[0], device=device).mul(std(x))).squeeze()\n", "\n", "q = torch.tensor([0.1, 0.5, 0.9]).to(device)" ] }, { "cell_type": "code", "execution_count": null, "id": "4", "metadata": {}, "outputs": [], "source": [ "x1_pred = torch.linspace(0, 1, 100).to(device)\n", "x_pred = torch.stack(\n", " [\n", " x1_pred.unsqueeze(0).expand(len(x2_values), -1),\n", " x2_values.unsqueeze(1).expand(-1, len(x1_pred)),\n", " ],\n", " dim=-1,\n", ") # *(X1, X2, D)" ] }, { "cell_type": "code", "execution_count": null, "id": "5", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.10.8, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, axes = plt.subplots(1, len(x2_values), sharey=True)\n", "\n", "for i, x2_val in enumerate(x2_values):\n", " ax = axes[i]\n", "\n", " mask = x[:, 1] == x2_val\n", " ax.scatter(x[mask, 0].cpu(), y[mask].cpu(), c=\"k\", marker=\".\")\n", "\n", " x_line = x_pred[i]\n", " true_q_lines = mean(x_line).unsqueeze(-1) + std(x_line).unsqueeze(-1) * Normal(\n", " 0, 1\n", " ).icdf(q)\n", " ax.plot(x_line[..., 0].cpu(), true_q_lines.cpu(), \"--\", color=\"gray\")\n", "\n", " ax.set_title(f\"x2 = {x2_val:.1f}\")\n", " ax.set_xlabel(\"x1\")\n", " if i == 0:\n", " ax.set_ylabel(\"y\")\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "6", "metadata": {}, "source": [ "## Without constraint" ] }, { "cell_type": "code", "execution_count": null, "id": "7", "metadata": {}, "outputs": [], "source": [ "class MyGP_NoConstraint(CenterGapQuantileGP):\n", " def __init__(\n", " self,\n", " inducing_points,\n", " num_quantiles,\n", " num_lower_quantiles,\n", " num_latents,\n", " num_lower_latents,\n", " ):\n", " N, D = inducing_points.size()\n", " variational_distribution = CholeskyVariationalDistribution(\n", " N,\n", " batch_shape=torch.Size([num_latents]),\n", " )\n", " variational_strategy = CGBlkdiagLmcVariationalStrategy(\n", " VariationalStrategy(\n", " self,\n", " inducing_points,\n", " variational_distribution,\n", " learn_inducing_locations=True,\n", " ),\n", " num_quantiles=num_quantiles,\n", " num_latents=num_latents,\n", " num_lower_quantiles=num_lower_quantiles,\n", " num_lower_latents=num_lower_latents,\n", " )\n", "\n", " mean = CenterGapMean(\n", " ConstantMean(batch_shape=torch.Size([1])),\n", " ConstantMean(batch_shape=torch.Size([num_latents - 1])),\n", " latent_dim=-1,\n", " )\n", " covar = ScaleKernel(\n", " RBFKernel(ard_num_dims=D, batch_shape=torch.Size([num_latents])),\n", " batch_shape=torch.Size([num_latents]),\n", " )\n", " super().__init__(variational_strategy, mean, covar, -1, num_lower_quantiles)\n", "\n", "\n", "g1, g2 = torch.meshgrid(\n", " torch.linspace(0, 1, 10),\n", " torch.tensor([0.1, 0.3, 0.5, 0.7, 0.9]),\n", " indexing=\"ij\",\n", ")\n", "inducing_points = torch.stack([g1.flatten(), g2.flatten()], dim=1).to(device)\n", "central_q_index = (q - 0.5).abs().argmin().item()\n", "num_latents = len(q)\n", "gp_noconstraint = MyGP_NoConstraint(\n", " inducing_points, len(q), central_q_index, num_latents, num_latents // 2\n", ").to(device)\n", "likelihood_noconstraint = MultitaskCenterGapQuantileGPLikelihood(q, central_q_index).to(\n", " device\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "8", "metadata": {}, "outputs": [], "source": [ "gp_noconstraint.train()\n", "likelihood_noconstraint.train()\n", "mll = VariationalELBO(likelihood_noconstraint, gp_noconstraint, num_data=y.numel())\n", "optimizer = torch.optim.Adam(\n", " list(gp_noconstraint.parameters()) + list(likelihood_noconstraint.parameters()),\n", " lr=0.001,\n", ")\n", "\n", "for _ in range(n_epochs):\n", " output = gp_noconstraint(x)\n", " loss = -mll(output, y)\n", " loss.backward()\n", " optimizer.step()\n", " optimizer.zero_grad()" ] }, { "cell_type": "code", "execution_count": null, "id": "9", "metadata": {}, "outputs": [], "source": [ "gp_noconstraint.eval()\n", "with torch.no_grad():\n", " mean_q_noconstraint = gp_noconstraint.mean_quantiles_mc(\n", " x_pred.flatten(0, -2), 1\n", " ).unflatten(-2, x_pred.shape[:-1])" ] }, { "cell_type": "code", "execution_count": null, "id": "10", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.10.8, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "colors = plt.cm.tab10.colors\n", "\n", "fig, axes = plt.subplots(1, len(x2_values), sharey=True)\n", "\n", "for i, x2_val in enumerate(x2_values):\n", " ax = axes[i]\n", "\n", " mask = x[:, 1] == x2_val\n", " ax.scatter(x[mask, 0].cpu(), y[mask].cpu(), c=\"gray\", marker=\".\", alpha=0.1)\n", "\n", " x_line = x_pred[i]\n", " true_q_lines = mean(x_line).unsqueeze(-1) + std(x_line).unsqueeze(-1) * Normal(\n", " 0, 1\n", " ).icdf(q)\n", "\n", " ax.plot(x_line[..., 0].cpu(), true_q_lines.cpu(), \"--\", color=\"k\")\n", "\n", " for j in range(len(q)):\n", " ax.plot(\n", " x_line[..., 0].cpu(), mean_q_noconstraint[i, ..., j].cpu(), color=colors[j]\n", " )\n", "\n", " ax.set_title(f\"x2 = {x2_val:.1f}\")\n", " ax.set_xlabel(\"x1\")\n", " if i == 0:\n", " ax.set_ylabel(\"y\")\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "11", "metadata": {}, "source": [ "## With constraint" ] }, { "cell_type": "code", "execution_count": null, "id": "12", "metadata": {}, "outputs": [], "source": [ "class MyGP_Constraint(CenterGapQuantileGP):\n", " def __init__(\n", " self,\n", " inducing_points,\n", " num_quantiles,\n", " num_lower_quantiles,\n", " num_latents,\n", " num_lower_latents,\n", " ):\n", " N, D = inducing_points.size()\n", " variational_distribution = CholeskyVariationalDistribution(\n", " N,\n", " batch_shape=torch.Size([num_latents]),\n", " )\n", " variational_strategy = CGBlkdiagLmcVariationalStrategy(\n", " VariationalStrategy(\n", " self,\n", " inducing_points,\n", " variational_distribution,\n", " learn_inducing_locations=True,\n", " ),\n", " num_quantiles=num_quantiles,\n", " num_latents=num_latents,\n", " num_lower_quantiles=num_lower_quantiles,\n", " num_lower_latents=num_lower_latents,\n", " )\n", "\n", " mean = CenterGapMean(\n", " ConstantMean(batch_shape=torch.Size([1])),\n", " ConstantMean(batch_shape=torch.Size([num_latents - 1])),\n", " latent_dim=-1,\n", " )\n", " covar = ScaleKernel(\n", " RBFKernel(ard_num_dims=D, batch_shape=torch.Size([num_latents])),\n", " batch_shape=torch.Size([num_latents]),\n", " )\n", "\n", " lower_constraint = torch.tensor([0.2, 0])\n", " upper_constraint = torch.tensor([1e2, 1e2])\n", " initial_lengthscale = torch.tensor([0.5, 0.5])\n", " covar.base_kernel.register_constraint(\n", " \"raw_lengthscale\", Interval(lower_constraint, upper_constraint)\n", " )\n", " with torch.no_grad():\n", " covar.base_kernel.lengthscale = initial_lengthscale\n", "\n", " super().__init__(variational_strategy, mean, covar, -1, num_lower_quantiles)\n", "\n", "\n", "g1, g2 = torch.meshgrid(\n", " torch.linspace(0, 1, 10),\n", " torch.tensor([0.1, 0.3, 0.5, 0.7, 0.9]),\n", " indexing=\"ij\",\n", ")\n", "inducing_points = torch.stack([g1.flatten(), g2.flatten()], dim=1).to(device)\n", "central_q_index = (q - 0.5).abs().argmin().item()\n", "num_latents = len(q)\n", "gp_constraint = MyGP_Constraint(\n", " inducing_points, len(q), central_q_index, num_latents, num_latents // 2\n", ").to(device)\n", "likelihood_constraint = MultitaskCenterGapQuantileGPLikelihood(q, central_q_index).to(\n", " device\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "13", "metadata": {}, "outputs": [], "source": [ "gp_constraint.train()\n", "likelihood_constraint.train()\n", "mll = VariationalELBO(likelihood_constraint, gp_constraint, num_data=y.numel())\n", "optimizer = torch.optim.Adam(\n", " list(gp_constraint.parameters()) + list(likelihood_constraint.parameters()),\n", " lr=0.001,\n", ")\n", "\n", "for _ in range(n_epochs):\n", " output = gp_constraint(x)\n", " loss = -mll(output, y)\n", " loss.backward()\n", " optimizer.step()\n", " optimizer.zero_grad()" ] }, { "cell_type": "code", "execution_count": null, "id": "14", "metadata": {}, "outputs": [], "source": [ "gp_constraint.eval()\n", "with torch.no_grad():\n", " mean_q_constraint = gp_constraint.mean_quantiles_mc(\n", " x_pred.flatten(0, -2), 1\n", " ).unflatten(-2, x_pred.shape[:-1])" ] }, { "cell_type": "code", "execution_count": null, "id": "15", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.10.8, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "colors = plt.cm.tab10.colors\n", "\n", "fig, axes = plt.subplots(1, len(x2_values), sharey=True)\n", "\n", "for i, x2_val in enumerate(x2_values):\n", " ax = axes[i]\n", "\n", " mask = x[:, 1] == x2_val\n", " ax.scatter(x[mask, 0].cpu(), y[mask].cpu(), c=\"gray\", marker=\".\", alpha=0.1)\n", "\n", " x_line = x_pred[i]\n", " true_q_lines = mean(x_line).unsqueeze(-1) + std(x_line).unsqueeze(-1) * Normal(\n", " 0, 1\n", " ).icdf(q)\n", "\n", " ax.plot(x_line[..., 0].cpu(), true_q_lines.cpu(), \"--\", color=\"k\")\n", "\n", " for j in range(len(q)):\n", " ax.plot(\n", " x_line[..., 0].cpu(), mean_q_constraint[i, ..., j].cpu(), color=colors[j]\n", " )\n", "\n", " ax.set_title(f\"x2 = {x2_val:.1f}\")\n", " ax.set_xlabel(\"x1\")\n", " if i == 0:\n", " ax.set_ylabel(\"y\")\n", "\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "heavyedge", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.13" } }, "nbformat": 4, "nbformat_minor": 5 }