{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {},
"source": [
"# GPQR Tutorial\n",
"\n",
"Gaussian process quantile regression (GPQR) estimates the conditional quantile function of a response variable using Gaussian process regression.\n",
"\n",
"In this tutorial, we estimate multiple quantile levels for data with heteroscedastic noise. To avoid quantile crossing (an estimated lower quantile exceeding a higher one), we use a correlated multitask GPQR model with a center-gap representation and an informative prior."
]
},
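{
"cell_type": "markdown",
"id": "sketch-center-gap",
"metadata": {},
"source": [
"The center-gap idea can be illustrated in a few lines. This is a minimal sketch, not the `gpytorch_qr` implementation. It assumes that the median (the center) is unconstrained and that each gap between adjacent quantile levels is made positive with a softplus transform. The helpers `center_gap_to_quantiles` and `count_crossings` are hypothetical names introduced only for this example.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"def center_gap_to_quantiles(center, raw_gaps, num_lower):\n",
"    # Hypothetical helper. center: (n,), raw_gaps: (n, k - 1) for k quantile levels;\n",
"    # num_lower is the number of quantile levels below the center.\n",
"    gaps = F.softplus(raw_gaps)  # assumption: softplus keeps every gap positive\n",
"    lower, upper = gaps[:, :num_lower], gaps[:, num_lower:]\n",
"    # Step down from the center for the lower levels and up for the upper levels.\n",
"    below = center.unsqueeze(-1) - torch.cumsum(lower, dim=-1)\n",
"    above = center.unsqueeze(-1) + torch.cumsum(upper, dim=-1)\n",
"    # Order columns from the lowest to the highest quantile level.\n",
"    return torch.cat([below.flip(-1), center.unsqueeze(-1), above], dim=-1)\n",
"\n",
"\n",
"def count_crossings(quantiles):\n",
"    # Number of adjacent quantile pairs that are out of order.\n",
"    return int((quantiles.diff(dim=-1) < 0).sum())\n",
"\n",
"\n",
"torch.manual_seed(0)\n",
"center = torch.randn(100)\n",
"raw_gaps = 3 * torch.randn(100, 4)  # unconstrained, possibly very negative values\n",
"quantiles = center_gap_to_quantiles(center, raw_gaps, num_lower=2)\n",
"print(quantiles.shape, count_crossings(quantiles))\n",
"```\n",
"\n",
"Because every gap is positive, the resulting quantiles increase for any values of the latent functions, so crossing is ruled out by construction. Using `num_lower=2` with five levels mirrors this tutorial, where the median `q[2]` has two levels below it."
]
},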
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import gpytorch\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from gpytorch_qr.means import CenterGapMean\n",
"from gpytorch_qr.models import CenterGapQuantileGP\n",
"from gpytorch_qr.variational import CGLmcVariationalStrategy\n",
"from gpytorch_qr.likelihoods import MultitaskCenterGapQuantileGPLikelihood\n",
"\n",
"try:\n",
" import config_notebook\n",
"except ImportError:\n",
" print(\"Output will not be deterministic SVG.\")\n",
"\n",
"torch.manual_seed(42)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"n_epochs = int(os.getenv(\"GPYTORCHQR_N_EPOCHS\", 5000))"
]
},
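{
"cell_type": "markdown",
"id": "2",
"metadata": {},
"source": [
"We generate sinusoidal data with heteroscedastic noise, where the noise standard deviation grows linearly with `x`. For reference, we also compute the true quantiles at each level in `q`."
]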
},
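{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"def mean(x):\n",
"    # Mean of the response: a cosine with period of about 1 (pi approximated by 3.14)\n",
"    return torch.cos(x * 2 * 3.14)\n",
"\n",
"\n",
"def std(x):\n",
"    # Noise standard deviation grows linearly with x (heteroscedastic noise)\n",
"    return x + 0.1\n",
"\n",
"\n",
"# 100 evenly spaced inputs on [0, 1], each observed 5 times\n",
"x_range = torch.linspace(0, 1, 100).reshape(-1, 1).to(device)\n",
"x = x_range.repeat(5, 1)\n",
"y = (mean(x) + torch.randn(x.shape, device=device).mul(std(x))).squeeze()\n",
"# Quantile levels to estimate; q[2] = 0.5 is the median\n",
"q = torch.tensor([0.1, 0.25, 0.5, 0.75, 0.9]).to(device)\n",
"# True quantiles under Gaussian noise: mean + std * Phi^{-1}(q)\n",
"true_quantiles = mean(x_range) + std(x_range) * torch.distributions.Normal(0, 1).icdf(q)\n",
"# Prediction grid extends past the training range [0, 1] to show extrapolation\n",
"x_pred = torch.linspace(0, 1.5, 100).reshape(-1, 1).to(device)"
]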
},
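{
"cell_type": "markdown",
"id": "sketch-true-quantile-coverage",
"metadata": {},
"source": [
"Because the noise is Gaussian, the true $\\tau$-quantile at input $x$ is $q_\\tau(x) = \\mu(x) + \\sigma(x)\\,\\Phi^{-1}(\\tau)$. Here $\\mu$ and $\\sigma$ are `mean` and `std` above, and $\\Phi^{-1}$ is the standard normal quantile function. This is what `true_quantiles` computes. As a quick sanity check, the sketch below draws many responses at one hypothetical input `x0` and confirms that roughly a fraction $\\tau$ of them fall below $q_\\tau(x_0)$. It is not part of the tutorial's workflow.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"x0 = torch.tensor(0.3)  # hypothetical input location\n",
"mu = torch.cos(x0 * 2 * 3.14)  # same mean function as mean(x) above\n",
"sigma = x0 + 0.1  # same noise standard deviation as std(x) above\n",
"tau = torch.tensor([0.1, 0.25, 0.5, 0.75, 0.9])\n",
"q_true = mu + sigma * torch.distributions.Normal(0.0, 1.0).icdf(tau)\n",
"\n",
"# Draw many responses at x0 and measure the fraction below each true quantile.\n",
"samples = mu + sigma * torch.randn(100_000)\n",
"coverage = (samples.unsqueeze(-1) <= q_true).float().mean(dim=0)\n",
"print(coverage)\n",
"```\n",
"\n",
"Each entry of `coverage` should be close to the corresponding level in `tau`, up to Monte Carlo error."
]
},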
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"metadata": {},
"outputs": [],
"source": [
"plt.scatter(x.cpu(), y.cpu(), c=\"k\", marker=\".\")\n",
"plt.plot(x_range.cpu(), true_quantiles.cpu(), \"--\", c=\"gray\")\n",
"plt.show()"
]
},
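{
"cell_type": "markdown",
"id": "5",
"metadata": {},
"source": [
"## Placing the prior mean\n",
"\n",
"We place an informative prior mean on the median quantile. The noise is Gaussian and therefore symmetric, so the true median equals the true mean function `mean(x)`, and we use that function as the prior mean.\n",
"The other quantiles have prior means that are constant offsets from the median."
]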
},
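{
"cell_type": "markdown",
"id": "sketch-prior-offsets",
"metadata": {},
"source": [
"Here is a minimal sketch of this prior structure. It assumes that each non-median quantile's prior mean is the median prior plus a fixed offset. The offsets are hypothetical: we set them to $0.5\\,\\Phi^{-1}(\\tau)$, which would be exact for homoscedastic Gaussian noise with standard deviation 0.5. In the model below, the remaining latent GPs use learnable `gpytorch.means.ConstantMean` modules, so those constants are fitted rather than fixed.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"\n",
"def median_prior(x):\n",
"    # Same informative mean function as mean(x) in this tutorial\n",
"    return torch.cos(x * 2 * 3.14)\n",
"\n",
"\n",
"tau = torch.tensor([0.1, 0.25, 0.5, 0.75, 0.9])\n",
"# Hypothetical constant offsets from the median (exactly zero at tau = 0.5)\n",
"offsets = 0.5 * torch.distributions.Normal(0.0, 1.0).icdf(tau)\n",
"\n",
"x = torch.linspace(0, 1.5, 100)\n",
"# Prior mean of every quantile level, shape (100, 5)\n",
"prior_means = median_prior(x).unsqueeze(-1) + offsets\n",
"print(prior_means.shape)\n",
"```\n",
"\n",
"Because the offsets do not depend on `x`, the prior quantile curves are vertical shifts of the median curve and never cross."
]
},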
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"metadata": {},
"outputs": [],
"source": [
"class PriorMean(gpytorch.means.Mean):\n",
"    # Informative prior mean for the median: the known mean function mean(x)\n",
"    def __init__(self, batch_shape=torch.Size()):\n",
"        super().__init__()\n",
"        self.batch_shape = batch_shape\n",
"\n",
"    def forward(self, x):\n",
"        # Evaluate mean(x) at each input and broadcast it over batch_shape\n",
"        return mean(x).squeeze(-1).expand(*self.batch_shape, x.shape[-2])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {},
"outputs": [],
"source": [
"plt.scatter(x.cpu(), y.cpu(), c=\"k\", marker=\".\")\n",
"plt.plot(x_pred.cpu(), PriorMean()(x_pred).detach().cpu(), c=\"r\")\n",
"plt.show()"
]
},
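{
"cell_type": "markdown",
"id": "8",
"metadata": {},
"source": [
"## Defining the model and likelihood\n",
"\n",
"To model the correlation between quantiles, the number of latent GPs should be smaller than the number of tasks (here, the number of quantiles). The quantiles then share latent functions and are therefore correlated. Below, 3 latent GPs are shared among 5 quantiles. `central_q_index = 2` marks the median `q[2] = 0.5`, which has two quantile levels below it."
]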
},
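{
"cell_type": "markdown",
"id": "sketch-lmc-correlation",
"metadata": {},
"source": [
"In a plain linear model of coregionalization (LMC), each task is a weighted sum of independent latent GPs, $f_t(x) = \\sum_{l} W_{tl}\\, g_l(x)$. We assume, without having checked its internals, that `CGLmcVariationalStrategy` follows this idea. If each latent GP has unit prior variance, the task covariance at a single input is $W W^\\top$, which has rank at most the number of latents. With 3 latents and 5 tasks, the quantiles are therefore forced to share structure. The mixing matrix `W` below is random and purely illustrative.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"num_tasks, num_latents = 5, 3\n",
"\n",
"# Hypothetical mixing weights: W[t, l] is the weight of latent GP l in task t.\n",
"W = torch.randn(num_tasks, num_latents, dtype=torch.float64)\n",
"\n",
"# Task covariance at a single input, assuming independent unit-variance latent GPs.\n",
"task_covar = W @ W.T\n",
"\n",
"# Count numerically nonzero eigenvalues: the rank is at most num_latents.\n",
"eigvals = torch.linalg.eigvalsh(task_covar)\n",
"num_nonzero = int((eigvals > 1e-8 * eigvals.max()).sum())\n",
"\n",
"# Correlation between tasks (quantile levels) implied by the shared latents.\n",
"d = task_covar.diagonal().sqrt()\n",
"task_corr = task_covar / torch.outer(d, d)\n",
"print(num_nonzero)\n",
"print(task_corr)\n",
"```\n",
"\n",
"In gpytorch's `LMCVariationalStrategy`, the mixing coefficients are learnable parameters. We expect, but have not verified, the same for `CGLmcVariationalStrategy`, in which case the correlation between quantile levels is fitted to the data."
]
},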
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {},
"outputs": [],
"source": [
"class MyGP(CenterGapQuantileGP):\n",
"    def __init__(\n",
"        self,\n",
"        inducing_points,\n",
"        num_quantiles,\n",
"        num_lower_quantiles,\n",
"        num_latents,\n",
"    ):\n",
"        N, D = inducing_points.size()\n",
"        # One variational distribution over the inducing values of each latent GP\n",
"        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(\n",
"            N,\n",
"            batch_shape=torch.Size([num_latents]),\n",
"        )\n",
"        # Mixes the num_latents latent GPs into num_quantiles correlated tasks\n",
"        variational_strategy = CGLmcVariationalStrategy(\n",
"            gpytorch.variational.VariationalStrategy(\n",
"                self,\n",
"                inducing_points,\n",
"                variational_distribution,\n",
"                learn_inducing_locations=True,\n",
"            ),\n",
"            num_quantiles=num_quantiles,\n",
"            num_latents=num_latents,\n",
"        )\n",
"\n",
"        # PriorMean for one latent GP; learnable constant means for the other num_latents - 1\n",
"        mean_module = CenterGapMean(\n",
"            PriorMean(batch_shape=torch.Size([1])),\n",
"            gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents - 1])),\n",
"            latent_dim=-1,\n",
"        )\n",
"        # A separate scaled RBF kernel (with ARD lengthscales) for each latent GP\n",
"        covar_module = gpytorch.kernels.ScaleKernel(\n",
"            gpytorch.kernels.RBFKernel(\n",
"                ard_num_dims=D, batch_shape=torch.Size([num_latents])\n",
"            ),\n",
"            batch_shape=torch.Size([num_latents]),\n",
"        )\n",
"        super().__init__(variational_strategy, mean_module, covar_module, -1, num_lower_quantiles)\n",
"\n",
"\n",
"inducing_points = torch.linspace(0, 1, 10).reshape(-1, 1).to(device)\n",
"# q[2] = 0.5 is the median, so two quantile levels lie below it\n",
"central_q_index = 2\n",
"# Fewer latent GPs (3) than quantiles (5) induces correlation between quantiles\n",
"num_latents = 3\n",
"gp = MyGP(inducing_points, len(q), central_q_index, num_latents).to(device)\n",
"likelihood = MultitaskCenterGapQuantileGPLikelihood(q, central_q_index).to(device)"
]
},
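{
"cell_type": "markdown",
"id": "sketch-pinball-loss",
"metadata": {},
"source": [
"The tutorial does not spell out the quantile likelihood. In quantile regression, the usual starting point is the pinball (check) loss $\\rho_\\tau(r) = \\max(\\tau r, (\\tau - 1) r)$ for a residual $r = y - f(x)$, which GP quantile regression often embeds in an asymmetric Laplace likelihood. We assume, without having verified the `gpytorch_qr` source, that `MultitaskCenterGapQuantileGPLikelihood` builds on this idea. The sketch below only illustrates the defining property of the loss: among constant predictions, the one that minimizes the average pinball loss is approximately the empirical $\\tau$-quantile of the data.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"\n",
"def pinball_loss(residual, tau):\n",
"    # Pinball (check) loss: rho_tau(r) = max(tau * r, (tau - 1) * r)\n",
"    return torch.maximum(tau * residual, (tau - 1) * residual)\n",
"\n",
"\n",
"torch.manual_seed(0)\n",
"y = torch.randn(10_000)  # hypothetical data: standard normal draws\n",
"tau = 0.9\n",
"\n",
"# Average loss for each constant prediction c on a grid with spacing 0.01.\n",
"candidates = torch.linspace(-3, 3, 601)\n",
"residuals = y.unsqueeze(0) - candidates.unsqueeze(1)  # shape (601, 10000)\n",
"losses = pinball_loss(residuals, tau).mean(dim=1)\n",
"best = candidates[losses.argmin()]\n",
"\n",
"# The minimizer should be close to the empirical 0.9-quantile of y.\n",
"print(float(best), float(torch.quantile(y, tau)))\n",
"```\n",
"\n",
"Because the loss penalizes under-prediction with weight $\\tau$ and over-prediction with weight $1 - \\tau$, its minimizer sits where a fraction $\\tau$ of the data lies below it."
]
},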
{
"cell_type": "markdown",
"id": "10",
"metadata": {},
"source": [
"## Training the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11",
"metadata": {},
"outputs": [],
"source": [
"gp.train()\n",
"likelihood.train()\n",
"# Variational ELBO; num_data is the total number of training observations\n",
"mll = gpytorch.mlls.VariationalELBO(likelihood, gp, num_data=y.numel())\n",
"optimizer = torch.optim.Adam(\n",
"    list(gp.parameters()) + list(likelihood.parameters()),\n",
"    lr=0.001,\n",
")\n",
"\n",
"# Full-batch training: each epoch uses all training points\n",
"for _ in range(n_epochs):\n",
"    output = gp(x)\n",
"    loss = -mll(output, y)\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    optimizer.zero_grad()"
]
},
{
"cell_type": "markdown",
"id": "12",
"metadata": {},
"source": [
"## Estimating the quantiles\n",
"\n",
"The center-gap representation makes the posterior of the quantiles non-analytic. We therefore estimate the posterior mean and the 95% credible interval (the 2.5% and 97.5% posterior quantiles) of each quantile by Monte Carlo sampling."
]
},
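{
"cell_type": "markdown",
"id": "sketch-mc-transform",
"metadata": {},
"source": [
"The need for sampling can be seen with a single Gaussian latent value passed through a positive transform. For illustration only, assume a softplus gap transform. The posterior mean of the gap is then not the transform of the latent mean, so the mean and credible interval have to be estimated from samples. The latent mean and standard deviation below are hypothetical. `mean_quantiles_mc` and `quantile_quantiles_mc` appear to apply the same idea to the full quantile curves.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"latent_mean, latent_std = 0.0, 1.0  # hypothetical Gaussian posterior of one latent value\n",
"num_samples = 100_000\n",
"\n",
"# Sample the latent value and push each sample through the assumed softplus transform.\n",
"latent = latent_mean + latent_std * torch.randn(num_samples)\n",
"gap = F.softplus(latent)\n",
"\n",
"# Monte Carlo estimates of the posterior mean and 95% credible interval of the gap.\n",
"mc_mean = gap.mean()\n",
"mc_lower, mc_upper = torch.quantile(gap, torch.tensor([0.025, 0.975]))\n",
"\n",
"# Naive plug-in value: transform of the latent mean (biased low, since softplus is convex).\n",
"plug_in = F.softplus(torch.tensor(latent_mean))\n",
"print(float(mc_mean), float(plug_in))\n",
"print(float(mc_lower), float(mc_upper))\n",
"```\n",
"\n",
"By Jensen's inequality, the Monte Carlo mean exceeds the plug-in value for any strictly convex transform. The same samples also give the credible interval directly."
]
},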
{
"cell_type": "code",
"execution_count": null,
"id": "13",
"metadata": {},
"outputs": [],
"source": [
"gp.eval()\n",
"with torch.no_grad():\n",
" mean_q = gp.mean_quantiles_mc(x_pred)\n",
" lower_q, upper_q = gp.quantile_quantiles_mc(\n",
" x_pred, torch.tensor([0.025, 0.975]).to(device)\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14",
"metadata": {},
"outputs": [],
"source": [
"colors = plt.cm.tab10.colors\n",
"\n",
"plt.scatter(x.cpu(), y.cpu(), c=\"gray\", marker=\".\", alpha=0.1)\n",
"plt.plot(x_range.cpu(), true_quantiles.cpu(), \"--\", c=\"k\")\n",
"\n",
"for i in range(len(q)):\n",
"    plt.plot(x_pred.cpu(), mean_q[:, i].cpu(), color=colors[i])\n",
"    plt.fill_between(\n",
"        x_pred.cpu().squeeze(),\n",
"        lower_q[:, i].cpu(),\n",
"        upper_q[:, i].cpu(),\n",
"        color=colors[i],\n",
"        alpha=0.3,\n",
"    )\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "15",
"metadata": {},
"source": [
"### Predictive posterior of the response variable\n",
"\n",
"As in ordinary Gaussian process regression, we can also evaluate the predictive posterior of the response variable from the quantile regression model. Each quantile level yields its own estimate of the response, which we summarize by the sample mean and the 2.5% and 97.5% sample quantiles.\n",
"\n",
"Estimates based on extreme quantiles are less credible than those based on quantiles near the median."
]
},
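{
"cell_type": "markdown",
"id": "sketch-quantile-variance",
"metadata": {},
"source": [
"One standard way to see why tail quantiles are harder to estimate comes from general statistics; it is not a description of how `gpytorch_qr` computes the predictive posterior. For $n$ independent observations with density $f$, the sample $\\tau$-quantile has asymptotic variance $\\tau(1 - \\tau) / \\big(n\\, f(q_\\tau)^2\\big)$. The factor $\\tau(1 - \\tau)$ shrinks in the tails, but for Gaussian noise the squared density shrinks faster, so the variance grows as $\\tau$ moves away from 0.5. The sketch evaluates the formula for standard normal noise with $n = 500$, the number of training points in this tutorial.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"normal = torch.distributions.Normal(0.0, 1.0)\n",
"n = 500  # number of observations (the tutorial's 5 x 100 training points)\n",
"tau = torch.tensor([0.1, 0.25, 0.5, 0.75, 0.9])\n",
"\n",
"q_tau = normal.icdf(tau)\n",
"density = normal.log_prob(q_tau).exp()\n",
"# Asymptotic variance of the sample tau-quantile: tau (1 - tau) / (n f(q_tau)^2)\n",
"asym_var = tau * (1 - tau) / (n * density**2)\n",
"print(asym_var.sqrt())  # approximate standard error per quantile level\n",
"```\n",
"\n",
"The standard errors at $\\tau = 0.1$ and $0.9$ come out larger than at the median, which is consistent with extreme quantiles carrying less information per observation."
]
},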
{
"cell_type": "code",
"execution_count": null,
"id": "16",
"metadata": {},
"outputs": [],
"source": [
"with torch.no_grad(), gpytorch.settings.num_likelihood_samples(1000):\n",
" pp = likelihood.predictive_posterior(gp(x_pred))\n",
"pp_mean = pp.mean(dim=0)\n",
"pp_lower = pp.quantile(0.025, dim=0)\n",
"pp_upper = pp.quantile(0.975, dim=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "17",
"metadata": {},
"outputs": [],
"source": [
"plt.plot(x_range.cpu(), mean(x_range).cpu(), \"--\", c=\"k\")\n",
"\n",
"for i in range(len(q)):\n",
"    plt.plot(x_pred.cpu(), pp_mean[:, i].cpu(), color=colors[i])\n",
"    plt.fill_between(\n",
"        x_pred.cpu().squeeze(),\n",
"        pp_lower[:, i].cpu(),\n",
"        pp_upper[:, i].cpu(),\n",
"        color=colors[i],\n",
"        alpha=0.3,\n",
"    )\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "heavyedge",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}