Source code for gpytorch_qr.models

"""Gaussian process classes for quantile regression."""

import abc

import gpytorch
import torch

from .utils import centergap_to_quantiles, transform_centergap_posterior

__all__ = [
    "QuantileGP",
    "DirectQuantileGP",
    "CenterGapQuantileGP",
]


[docs] class QuantileGP(gpytorch.models.ApproximateGP, abc.ABC): """Base class for Gaussian process quantile regression. Parameters ---------- variational_strategy : gpytorch.variational.VariationalStrategy mean_module : gpytorch.means.Mean covar_module : gpytorch.kernels.Kernel latent_dim : {0, -1} The dimension along which the latent GPs are represented in module batch shape. ``0`` if quantiles are batches, ``-1`` if quantiles are multitasks. Notes ----- Input predictors are expected to have shape ``(*B, N, D)``, where ``*B`` are optional batch shapes (e.g., for cross validation), *N* is the number of data points and *D* is the number of input dimensions. Quantiles can be either batch dimension or task dimension with shape *Q*. .. rubric:: Batch quantiles - ``variational_strategy`` must wrap a variational distribution with batch shape ``(Q, *B)``. - ``mean_module`` and ``covar_module`` must have batch shape ``(Q, *B)``. - Posterior is :class:`gpytorch.distributions.MultivariateNormal` with batch shape ``(Q, *B)`` and event shape ``(N,)``. - MLL loss is a tensor of shape ``(Q, *B)``. .. rubric:: Multitask quantiles Quantiles are constructed by combination of *L* latent GPs. - ``variational_strategy`` must wrap a variational distribution with batch shape ``(*B, L)``. - ``mean_module`` and ``covar_module`` must have batch shape ``(*B, L)``. - Posterior is :class:`gpytorch.distributions.MultitaskMultivariateNormal` with batch shape ``(*B)`` and event shape ``(N, Q)``. - MLL loss is a tensor of shape ``(*B)``. """ def __init__(self, variational_strategy, mean_module, covar_module, latent_dim): super().__init__(variational_strategy) self.mean_module = mean_module self.covar_module = covar_module if latent_dim == 0: self.unsqueeze_dim = 0 elif latent_dim == -1: self.unsqueeze_dim = -3 else: raise ValueError("latent_dim should be either 0 or -1.")
[docs] def forward(self, x): # x : (*B, N, D) -> (1, *B, N, D) or (*B, 1, N, D) x = x.unsqueeze(self.unsqueeze_dim) mean = self.mean_module(x) covar = self.covar_module(x) return gpytorch.distributions.MultivariateNormal(mean, covar)
[docs] @abc.abstractmethod def joint_quantile_posterior(self, x): """Joint posterior over quantiles at input locations. Parameters ---------- x : torch.Tensor with shape ``(*B, N, D)`` The input locations. Returns ------- torch.distributions.Distribution """ pass
[docs] def marginal_quantile_posterior(self, x): """Marginal posterior over quantiles at input locations. Parameters ---------- x : torch.Tensor with shape ``(*B, N, D)`` The input locations. Returns ------- torch.distributions.Distribution """ raise NotImplementedError
[docs] def mean_quantiles(self, x): """Predict quantiles by analytical posterior mean. Parameters ---------- x : torch.Tensor with shape ``(*B, N, D)`` The input locations. Returns ------- quantiles : torch.Tensor The predicted quantiles at the input locations. """ raise NotImplementedError
[docs] def mean_quantiles_mc(self, x, num_samples=10): """Posterior mean of quantiles by Monte Carlo approximation. Parameters ---------- x : torch.Tensor with shape ``(*B, N, D)`` The input locations. num_samples : int, default=10 The number of Monte Carlo samples. Returns ------- quantiles : torch.Tensor The predicted quantiles at the input locations. """ dist = self.joint_quantile_posterior(x) samples = dist.rsample(torch.Size([num_samples])) return samples.mean(dim=0)
[docs] def mean_quantiles_delta(self, x): """Posterior mean of quantiles by 0th-order delta method. Parameters ---------- x : torch.Tensor with shape ``(*B, N, D)`` The input locations. Returns ------- quantiles : torch.Tensor The predicted quantiles at the input locations. """ raise NotImplementedError
[docs] def quantile_quantiles(self, x, q): """Analytic quantile of quantile posterior. Parameters ---------- x : torch.Tensor with shape ``(*B, N, D)`` The input locations. q : torch.Tensor with shape (q,) The quantile levels. Returns ------- quantiles : torch.Tensor The predicted quantiles at the input locations. """ raise NotImplementedError
[docs] def quantile_quantiles_mc(self, x, q, num_samples=10): """Quantile of quantile posterior by Monte Carlo approximation. Parameters ---------- x : torch.Tensor with shape ``(*B, N, D)`` The input locations. q : torch.Tensor with shape (q,) The quantile levels. num_samples : int, default=10 The number of Monte Carlo samples. Returns ------- quantiles : torch.Tensor The predicted quantiles at the input locations. """ dist = self.joint_quantile_posterior(x) samples = dist.rsample(torch.Size([num_samples])) return samples.quantile(q, dim=0)
[docs] class DirectQuantileGP(QuantileGP): """Gaussian process quantile regression with direct quantile representation."""
[docs] def joint_quantile_posterior(self, x): return self(x)
[docs] def marginal_quantile_posterior(self, x): dist = self(x) return torch.distributions.Normal(dist.mean, dist.variance.sqrt())
[docs] def mean_quantiles(self, x): return self(x).mean
[docs] def mean_quantiles_delta(self, x): return self(x).mean
[docs] def quantile_quantiles(self, x, q): dist = self.marginal_quantile_posterior(x) shape = [-1] + [1 for _ in range(len(dist.batch_shape))] return dist.icdf(q.reshape(*shape))
[docs] class CenterGapQuantileGP(QuantileGP): """Gaussian process quantile regression with center-gap quantile representation. Parameters ---------- variational_strategy mean_module : gpytorch_qr.centergap.CenterGapMean Mean module for center-gap representation. covar_module latent_dim num_lower_quantiles : int The number of lower quantiles in center-gap representation. """ def __init__( self, variational_strategy, mean_module, covar_module, latent_dim, num_lower_quantiles, ): super().__init__(variational_strategy, mean_module, covar_module, latent_dim) self.num_lower_quantiles = num_lower_quantiles
[docs] def joint_quantile_posterior(self, x): return transform_centergap_posterior(self(x), self.num_lower_quantiles)
[docs] def mean_quantiles_delta(self, x): latent_posterior = self(x) if isinstance( latent_posterior, gpytorch.distributions.MultitaskMultivariateNormal ): quantile_dim = -1 elif isinstance(latent_posterior, gpytorch.distributions.MultivariateNormal): quantile_dim = -2 else: raise ValueError("Posterior is not a multivariate normal.") latent_mean = latent_posterior.mean num_upper = latent_mean.shape[quantile_dim] - 1 - self.num_lower_quantiles center_mean = torch.narrow(latent_mean, quantile_dim, 0, 1) lower_gaps = torch.narrow( latent_mean, quantile_dim, 1, self.num_lower_quantiles ) upper_gaps = torch.narrow( latent_mean, quantile_dim, 1 + self.num_lower_quantiles, num_upper ) return centergap_to_quantiles( center_mean, lower_gaps, upper_gaps, quantile_dim=quantile_dim )