Batch-Independent GPQR

import math
import os

import torch
from torch.distributions import Normal
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import UnwhitenedVariationalStrategy
from gpytorch.means import Mean
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.mlls import VariationalELBO
import matplotlib.pyplot as plt

from gpytorch_qr.models import DirectQuantileGP
from gpytorch_qr.likelihoods import BatchQuantileGPLikelihood

# The optional notebook config lives one directory up. The message below says
# it is only needed for deterministic SVG output, so its absence is tolerated.
try:
    import sys

    sys.path.insert(0, os.path.abspath(".."))
    import config_notebook
except ImportError:
    print("Output will not be deterministic SVG.")

# Seed the global RNG so data generation and training are reproducible.
torch.manual_seed(42)

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Number of optimizer steps; overridable via the environment for quick runs.
n_epochs = int(os.environ.get("GPYTORCHQR_N_EPOCHS", 10000))

Data preparation

def mean(x):
    """Return the true conditional mean ``cos(2 * pi * x)`` of the synthetic data.

    Applied elementwise; the result has the same shape as ``x``.
    """
    # Use the exact constant instead of the 3.14 approximation so the curve is
    # exactly 1-periodic (e.g. mean(0.25) == 0, mean(1) == 1).
    return torch.cos(2 * math.pi * x)


def std(x):
    """Return the true noise standard deviation: ``x`` shifted up by 0.1."""
    floor = 0.1
    return floor + x


# 100 equally spaced training inputs on [0, 1], shaped as an (N, 1) column.
x = torch.linspace(0, 1, 100).reshape(-1, 1).to(device)
# Heteroscedastic targets: true mean plus Gaussian noise scaled by std(x).
eps = torch.randn(x.shape, device=device)
y = (mean(x) + eps * std(x)).squeeze()
# Nine quantile levels 0.1, 0.2, ..., 0.9.
q = torch.linspace(0.1, 0.9, 9).to(device)
# With Gaussian noise the tau-quantile is mean + std * Phi^{-1}(tau); the
# (N, 1) columns broadcast against the (len(q),) z-scores to give (N, len(q)).
z_scores = Normal(0, 1).icdf(q)
true_quantiles = mean(x) + std(x) * z_scores
# The prediction grid runs past the training range (up to 1.5).
x_pred = torch.linspace(0, 1.5, 100).reshape(-1, 1).to(device)

plt.scatter(x.cpu(), y.cpu(), c="k", marker=".")
plt.plot(x.cpu(), true_quantiles.cpu(), "--", c="gray")
plt.show()
../../_images/f909c7fb9e6989ed8fcac3f9b25504e02bdd3c200bd63c0ea7769bb2c0d5dbf4.svg

Prior mean

class PriorMean(Mean):
    """Prior mean equal to the known ``mean(x)`` plus a learnable offset.

    Every batch member evaluates the same ``mean(x)`` curve but has its own
    additive ``offset`` parameter.

    Args:
        batch_shape: Batch shape of the learnable offset. The default empty
            shape gives a single scalar (0-dim) offset.
    """

    def __init__(self, batch_shape=torch.Size([])):
        super().__init__()
        self.batch_shape = batch_shape
        # Pass the shape as a single argument. ``torch.zeros(*torch.Size([]))``
        # would call ``torch.zeros()`` with no size and raise, whereas
        # ``torch.zeros(torch.Size([]))`` returns a 0-dim tensor.
        self.register_parameter(
            "offset", torch.nn.Parameter(torch.zeros(batch_shape))
        )

    def forward(self, x):
        """Return ``mean(x)`` with the trailing feature axis dropped, plus the offset.

        For an (n, 1) input the result has shape ``(*batch_shape, n)``.
        Assumes a single input feature; TODO confirm for other input shapes.
        """
        res = mean(x).squeeze(-1)
        # Add a trailing axis so each batch offset broadcasts along the data axis.
        return res + self.offset.unsqueeze(-1)


# One prior-mean curve per quantile level, evaluated on the prediction grid.
prior_mean = PriorMean(batch_shape=torch.Size([len(q)])).to(device)
with torch.no_grad():
    prior_curves = prior_mean(x_pred).cpu()

plt.scatter(x.cpu(), y.cpu(), c="k", marker=".")
# Transpose to (n_points, n_quantiles) so matplotlib draws one line per column.
plt.plot(x_pred.cpu(), prior_curves.transpose(0, 1))
plt.show()
../../_images/0f9ec080c486be1365158a05d17c826d56e09f13bf4b5ebb139568cd52926e16.svg

Define the model and likelihood

class MyGP_PriorMean(DirectQuantileGP):
    """Batch of variational GPs with batch size ``num_quantiles``.

    Each batch member has its own Cholesky variational distribution, its own
    ``PriorMean`` offset and its own scaled ARD RBF kernel.

    Args:
        inducing_points: ``(N, D)`` tensor of inducing inputs, held fixed
            during training (``learn_inducing_locations=False``).
        num_quantiles: Batch size shared by the variational distribution,
            mean and kernel.
    """

    def __init__(self, inducing_points, num_quantiles):
        num_inducing, num_dims = inducing_points.size()
        quantile_batch = torch.Size([num_quantiles])

        q_u = CholeskyVariationalDistribution(
            num_inducing,
            batch_shape=quantile_batch,
        )
        strategy = UnwhitenedVariationalStrategy(
            self,
            inducing_points,
            q_u,
            learn_inducing_locations=False,
        )
        # Named mean_module so it does not shadow the module-level mean() function.
        mean_module = PriorMean(batch_shape=quantile_batch)
        base_kernel = RBFKernel(ard_num_dims=num_dims, batch_shape=quantile_batch)
        covar_module = ScaleKernel(base_kernel, batch_shape=quantile_batch)

        # NOTE(review): DirectQuantileGP's constructor (not visible here) defines
        # the trailing ``0`` argument; confirm its meaning.
        super().__init__(strategy, mean_module, covar_module, 0)


# Use a detached copy of the training inputs as inducing points.
inducing_points = x.clone().detach()
gp_priormean = MyGP_PriorMean(inducing_points, num_quantiles=len(q)).to(device)
likelihood_priormean = BatchQuantileGPLikelihood(q).to(device)

Train

# Put the model and likelihood in training mode before maximizing the ELBO.
gp_priormean.train()
likelihood_priormean.train()
mll = VariationalELBO(likelihood_priormean, gp_priormean, num_data=y.numel())
trainable = [*gp_priormean.parameters(), *likelihood_priormean.parameters()]
optimizer = torch.optim.Adam(trainable, lr=0.001)

for _ in range(n_epochs):
    output = gp_priormean(x)
    # mll presumably returns one ELBO per batch member (hence the .sum());
    # minimize the negated total.
    elbo = mll(output, y)
    loss = -elbo.sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
/home/jisoosong/miniconda3/envs/heavyedge/lib/python3.13/site-packages/gpytorch/distributions/multivariate_normal.py:375: NumericalWarning: Negative variance values detected. This is likely due to numerical instabilities. Rounding negative variances up to 1e-06.
  warnings.warn(

Evaluate

# Switch the GP to evaluation mode and compute its quantile predictions on the
# prediction grid without tracking gradients. The result is later indexed
# once per quantile level for plotting.
# NOTE(review): likelihood_priormean is left in training mode; presumably
# mean_quantiles does not use the likelihood. Confirm.
gp_priormean.eval()
with torch.no_grad():
    quantiles_priormean = gp_priormean.mean_quantiles(x_pred)

Plot result

# One tab10 colour per quantile level. tab10 has only 10 entries, so indices
# beyond 9 would not get distinct colours; len(q) == 9 here.
colors = plt.cm.tab10(range(len(q)))

# Faint background: noisy observations and the true quantile curves.
plt.scatter(x.cpu(), y.cpu(), c="gray", marker=".", alpha=0.1)
plt.plot(x.cpu(), true_quantiles.cpu(), "--", c="gray", alpha=0.5)

# Overlay each fitted quantile curve in its own colour. zip() replaces the
# manual index loop; the CPU copy of the grid is hoisted out of the loop.
x_pred_cpu = x_pred.cpu()
for curve, color in zip(quantiles_priormean, colors):
    plt.plot(x_pred_cpu, curve.cpu(), color=color)

plt.show()
../../_images/0b38d3a63508cd8bcfa9807e7e88be17d1379c72050ba16df3c58d35a8761e0b.svg