{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {},
"source": [
"# Scaling with Prior Mean\n",
"\n",
"When input dimensions have different scales, they need to be standardized for GP regression.\n",
"On the other hand, an informative prior mean often requires the inputs to be in their original scale.\n",
"\n",
"This notebook shows how to combine the two requirements: the GP is trained on standardized inputs, and the prior mean maps them back to the original scale before evaluating the known mean function."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"\n",
"import torch\n",
"from torch.distributions import Normal\n",
"from gpytorch.variational import CholeskyVariationalDistribution\n",
"from gpytorch.variational import VariationalStrategy\n",
"from gpytorch.means import Mean, ConstantMean\n",
"from gpytorch.kernels import RBFKernel, ScaleKernel\n",
"from gpytorch.mlls import VariationalELBO\n",
"from sklearn.preprocessing import StandardScaler\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from gpytorch_qr.means import CenterGapMean\n",
"from gpytorch_qr.models import CenterGapQuantileGP\n",
"from gpytorch_qr.variational import CGLmcVariationalStrategy\n",
"from gpytorch_qr.likelihoods import MultitaskCenterGapQuantileGPLikelihood\n",
"\n",
"# Optional notebook configuration from the parent directory; without it the\n",
"# rendered SVG output is not deterministic, but the notebook still runs.\n",
"sys.path.insert(0, os.path.abspath(\"..\"))\n",
"try:\n",
"    import config_notebook\n",
"except ImportError:\n",
"    print(\"Output will not be deterministic SVG.\")\n",
"\n",
"# Reproducibility and device selection.\n",
"torch.manual_seed(42)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Number of training iterations (override via GPYTORCHQR_N_EPOCHS for quick runs).\n",
"n_epochs = int(os.getenv(\"GPYTORCHQR_N_EPOCHS\", 5000))"
]
},
{
"cell_type": "markdown",
"id": "2",
"metadata": {},
"source": [
"## Data preparation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"def mean(x):\n",
"    \"\"\"Ground-truth conditional mean of y.\n",
"\n",
"    The period is about 0.1 along x1 and about 1 along x2, i.e. roughly one\n",
"    full cycle over each input's sampling range.\n",
"    \"\"\"\n",
"    x1, x2 = x[..., 0], x[..., 1]\n",
"    return torch.cos(x1 * 20 * 3.14) * torch.cos(x2 * 2 * 3.14)\n",
"\n",
"\n",
"def std(x):\n",
"    \"\"\"Ground-truth conditional standard deviation of y (grows with x1 + x2).\"\"\"\n",
"    return x[..., 0] + x[..., 1] + 0.1\n",
"\n",
"\n",
"# x1 ~ U(0, 0.1) and x2 ~ U(0, 1): the inputs live on very different scales.\n",
"N = 500\n",
"x1_train = torch.rand(N, device=device) * 0.1\n",
"x2_train = torch.rand(N, device=device) * 1.0\n",
"x = torch.stack([x1_train, x2_train], dim=1)\n",
"\n",
"# Heteroscedastic Gaussian noise around the true mean.\n",
"noise = torch.randn(N, device=device)\n",
"y = (mean(x) + noise * std(x)).squeeze()\n",
"\n",
"# Quantile levels to estimate.\n",
"q = torch.tensor([0.1, 0.25, 0.5, 0.75, 0.9], device=device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Ground-truth mean surface on the original (unscaled) input domain,\n",
"# with the noisy training observations overlaid.\n",
"n_grid = 50\n",
"x1_grid = torch.linspace(0, 0.1, n_grid)\n",
"x2_grid = torch.linspace(0, 1.0, n_grid)\n",
"X1, X2 = torch.meshgrid(x1_grid, x2_grid, indexing=\"ij\")\n",
"x_grid = torch.stack([X1.flatten(), X2.flatten()], dim=1)\n",
"Z = mean(x_grid).reshape(n_grid, n_grid)\n",
"\n",
"# Normalize Z to [0, 1] so it can be mapped onto the colormap.\n",
"Z_np = Z.numpy()\n",
"Z_normalized = (Z_np - Z_np.min()) / (Z_np.max() - Z_np.min())\n",
"\n",
"fig = plt.figure()\n",
"ax = fig.add_subplot(111, projection=\"3d\")\n",
"ax.plot_surface(\n",
"    X1.numpy(),\n",
"    X2.numpy(),\n",
"    Z_np,\n",
"    facecolors=plt.cm.coolwarm(Z_normalized),\n",
"    alpha=0.4,\n",
")\n",
"ax.scatter(\n",
"    x[:, 0].cpu().numpy(),\n",
"    x[:, 1].cpu().numpy(),\n",
"    y.cpu().numpy(),\n",
"    s=5,\n",
"    color=\"gray\",\n",
")\n",
"\n",
"ax.set_xlabel(\"x1\")\n",
"ax.set_ylabel(\"x2\")\n",
"ax.set_zlabel(\"y\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "5",
"metadata": {},
"source": [
"## Prior mean"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"metadata": {},
"outputs": [],
"source": [
"class PriorMean(Mean):\n",
"    \"\"\"Informative prior mean that evaluates the known ``mean`` function.\n",
"\n",
"    ``mean`` is defined on the original input scale, so inputs passed to this\n",
"    module must not be standardized.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, batch_shape=torch.Size()):\n",
"        super().__init__()\n",
"        # Kept as an attribute only; ``forward`` does not use it.\n",
"        self.batch_shape = batch_shape\n",
"\n",
"    def forward(self, x):\n",
"        return mean(x)"
]
},
{
"cell_type": "markdown",
"id": "7",
"metadata": {},
"source": [
"## Define models and likelihoods"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"metadata": {},
"outputs": [],
"source": [
"class Unscaler(torch.nn.Module):\n",
"    \"\"\"Inverse of StandardScaler: maps standardized input back to original scale.\n",
"\n",
"    Computes ``x * X_scale + X_mean`` along the last (feature) dimension.\n",
"\n",
"    Args:\n",
"        X_scale: Per-feature scale, e.g. ``StandardScaler.scale_`` as a tensor.\n",
"        X_mean: Per-feature mean, e.g. ``StandardScaler.mean_`` as a tensor.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, X_scale, X_mean):\n",
"        super().__init__()\n",
"        # Buffers follow ``.to(device)`` but are not trainable parameters.\n",
"        self.register_buffer(\"X_scale\", X_scale)\n",
"        self.register_buffer(\"X_mean\", X_mean)\n",
"\n",
"    def forward(self, x):\n",
"        # Broadcasting over the trailing feature dimension handles any batch\n",
"        # shape directly and, unlike ``x.view(-1, D)``, also works for\n",
"        # non-contiguous inputs such as expanded tensors.\n",
"        return x * self.X_scale + self.X_mean"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {},
"outputs": [],
"source": [
"class MyGP(CenterGapQuantileGP):\n",
"    \"\"\"Center-gap quantile GP with an informative prior mean on standardized inputs.\n",
"\n",
"    The model is fed standardized inputs. ``X_scale`` and ``X_mean`` (the fitted\n",
"    ``StandardScaler`` statistics) let the mean module un-standardize them before\n",
"    calling ``PriorMean``; when omitted they default to the identity transform.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        inducing_points,\n",
"        num_quantiles,\n",
"        num_lower_quantiles,\n",
"        num_latents,\n",
"        X_scale=None,\n",
"        X_mean=None,\n",
"    ):\n",
"        num_inducing, num_dims = inducing_points.size()\n",
"        latent_shape = torch.Size([num_latents])\n",
"\n",
"        # One variational distribution per latent (batch dimension num_latents).\n",
"        variational_strategy = CGLmcVariationalStrategy(\n",
"            VariationalStrategy(\n",
"                self,\n",
"                inducing_points,\n",
"                CholeskyVariationalDistribution(num_inducing, batch_shape=latent_shape),\n",
"                learn_inducing_locations=True,\n",
"            ),\n",
"            num_quantiles=num_quantiles,\n",
"            num_latents=num_latents,\n",
"        )\n",
"\n",
"        # Identity transform unless scaler statistics are supplied.\n",
"        unscaler = Unscaler(\n",
"            X_scale=torch.ones(num_dims) if X_scale is None else X_scale,\n",
"            X_mean=torch.zeros(num_dims) if X_mean is None else X_mean,\n",
"        )\n",
"        # PriorMean (batch of 1, fed un-standardized inputs) is passed first;\n",
"        # a ConstantMean covers the other num_latents - 1 latents.\n",
"        mean_module = CenterGapMean(\n",
"            torch.nn.Sequential(unscaler, PriorMean(batch_shape=torch.Size([1]))),\n",
"            ConstantMean(batch_shape=torch.Size([num_latents - 1])),\n",
"            latent_dim=-1,\n",
"        )\n",
"        covar_module = ScaleKernel(\n",
"            RBFKernel(ard_num_dims=num_dims, batch_shape=latent_shape),\n",
"            batch_shape=latent_shape,\n",
"        )\n",
"        super().__init__(\n",
"            variational_strategy, mean_module, covar_module, -1, num_lower_quantiles\n",
"        )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10",
"metadata": {},
"outputs": [],
"source": [
"scaler = StandardScaler().fit(x.detach().cpu())\n",
"X_scale = torch.from_numpy(scaler.scale_).float()\n",
"X_mean = torch.from_numpy(scaler.mean_).float()\n",
"# NOTE: x is rebound to its standardized version, so re-running this cell\n",
"# without re-running the data cell would standardize it twice.\n",
"x = torch.from_numpy(scaler.transform(x.detach().cpu())).float().to(device)\n",
"\n",
"# The GP is trained on standardized inputs, so the inducing points must be in\n",
"# the standardized space too. A fixed [0, 1]^2 grid would cover only part of\n",
"# the data (each uniform input spans roughly [-1.73, 1.73] once standardized),\n",
"# so spread the grid over the observed range of each scaled dimension.\n",
"num_inducing_per_dim = 10\n",
"x_min = x.min(dim=0).values.cpu()\n",
"x_max = x.max(dim=0).values.cpu()\n",
"g1, g2 = torch.meshgrid(\n",
"    torch.linspace(x_min[0].item(), x_max[0].item(), num_inducing_per_dim),\n",
"    torch.linspace(x_min[1].item(), x_max[1].item(), num_inducing_per_dim),\n",
"    indexing=\"ij\",\n",
")\n",
"inducing_points = torch.stack([g1.flatten(), g2.flatten()], dim=1)\n",
"\n",
"# Index of the quantile level closest to the median.\n",
"central_q_index = (q - 0.5).abs().argmin().item()\n",
"num_latents = len(q) - 2\n",
"gp = MyGP(\n",
"    inducing_points,\n",
"    len(q),\n",
"    central_q_index,\n",
"    num_latents,\n",
"    X_scale,\n",
"    X_mean,\n",
").to(device)\n",
"likelihood = MultitaskCenterGapQuantileGPLikelihood(q, central_q_index).to(device)"
]
},
{
"cell_type": "markdown",
"id": "11",
"metadata": {},
"source": [
"## Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12",
"metadata": {},
"outputs": [],
"source": [
"gp.train()\n",
"likelihood.train()\n",
"\n",
"# Full-batch optimization of the variational ELBO over GP and likelihood.\n",
"mll = VariationalELBO(likelihood, gp, num_data=y.numel())\n",
"trainable_params = [*gp.parameters(), *likelihood.parameters()]\n",
"optimizer = torch.optim.Adam(trainable_params, lr=0.001)\n",
"\n",
"for _ in range(n_epochs):\n",
"    output = gp(x)\n",
"    loss = mll(output, y).neg()\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    optimizer.zero_grad()"
]
},
{
"cell_type": "markdown",
"id": "13",
"metadata": {},
"source": [
"## Plot result"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = plt.figure()\n",
"ax = fig.add_subplot(111, projection=\"3d\")\n",
"\n",
"x1_grid = torch.linspace(0, 0.1, 50)\n",
"x2_grid = torch.linspace(0, 1.0, 50)\n",
"X1, X2 = torch.meshgrid(x1_grid, x2_grid, indexing=\"ij\")\n",
"x_grid = torch.stack([X1.flatten(), X2.flatten()], dim=1)\n",
"\n",
"# Standardize the grid exactly like the training inputs. StandardScaler returns\n",
"# float64, so cast to float32 to match the model and the training data.\n",
"x_grid_scaled = (\n",
"    torch.from_numpy(scaler.transform(x_grid.cpu().numpy())).float().to(device)\n",
")\n",
"# Plotting only: no autograd graph needed. ``[0, ...]`` keeps the first entry\n",
"# of the leading output dimension (presumably the prior-mean latent; confirm).\n",
"with torch.no_grad():\n",
"    Z = gp.mean_module(x_grid_scaled.unsqueeze(0))[0, ...]\n",
"\n",
"X1_scaled = x_grid_scaled[:, 0].reshape(50, 50).cpu().numpy()\n",
"X2_scaled = x_grid_scaled[:, 1].reshape(50, 50).cpu().numpy()\n",
"Z_grid = Z.reshape(50, 50).detach().cpu().numpy()\n",
"\n",
"ax.plot_surface(\n",
"    X1_scaled,\n",
"    X2_scaled,\n",
"    Z_grid,\n",
"    facecolors=plt.cm.coolwarm((Z_grid - Z_grid.min()) / (Z_grid.max() - Z_grid.min())),\n",
"    alpha=0.4,\n",
")\n",
"\n",
"ax.set_title(\"Mean module on standardized inputs\")\n",
"ax.set_xlabel(\"x1 (scaled)\")\n",
"ax.set_ylabel(\"x2 (scaled)\")\n",
"ax.set_zlabel(\"y\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15",
"metadata": {},
"outputs": [],
"source": [
"# Evaluation slices: x1 sweeps its range at three fixed values of x2.\n",
"x1_pred = torch.linspace(0, 0.1, 100).to(device)\n",
"x2_values = torch.tensor([0.1, 0.5, 0.9]).to(device)\n",
"n_slices, n_points = len(x2_values), len(x1_pred)\n",
"x_pred = torch.stack(\n",
"    [\n",
"        x1_pred.expand(n_slices, n_points),\n",
"        x2_values[:, None].expand(n_slices, n_points),\n",
"    ],\n",
"    dim=-1,\n",
")  # shape (n_slices, n_points, 2)\n",
"\n",
"gp.eval()\n",
"with torch.no_grad():\n",
"    # Standardize with the training scaler, predict on the flattened points,\n",
"    # then restore the (slice, point) layout.\n",
"    x_pred_scaled_flat = (\n",
"        torch.from_numpy(scaler.transform(x_pred.detach().cpu().flatten(0, -2)))\n",
"        .float()\n",
"        .to(device)\n",
"    )\n",
"    mean_q = gp.mean_quantiles_mc(x_pred_scaled_flat).unflatten(\n",
"        -2, x_pred.shape[:-1]\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"colors = plt.cm.tab10.colors\n",
"# Standard-normal z-scores of the requested quantile levels.\n",
"z_q = Normal(0, 1).icdf(q)\n",
"\n",
"fig, axes = plt.subplots(1, len(x2_values), sharey=True)\n",
"\n",
"for i, (ax, x2_val) in enumerate(zip(axes, x2_values)):\n",
"    x_line = x_pred[i]\n",
"    x1_line = x_line[..., 0].cpu()\n",
"\n",
"    # True conditional quantiles, mean + std * z (dashed black).\n",
"    true_q_lines = mean(x_line).unsqueeze(-1) + std(x_line).unsqueeze(-1) * z_q\n",
"    ax.plot(x1_line, true_q_lines.cpu(), \"--\", color=\"k\")\n",
"\n",
"    # Fitted quantiles, one colour per quantile level.\n",
"    for j in range(len(q)):\n",
"        ax.plot(x1_line, mean_q[i, ..., j].cpu(), color=colors[j])\n",
"\n",
"    ax.set_title(f\"x2 = {x2_val:.1f}\")\n",
"    ax.set_xlabel(\"x1\")\n",
"\n",
"axes[0].set_ylabel(\"y\")\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "heavyedge",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}