Source code for gpytorch_qr.variational

"""Variational strategies for GPQR."""

import gpytorch
import torch

__all__ = [
    "CGLmcVariationalStrategy",
    "CGBlkdiagLmcVariationalStrategy",
]


[docs] class CGLmcVariationalStrategy(gpytorch.variational.LMCVariationalStrategy): """LMC variational strategy for the center-gap quantile regression model. This class allows all gaps to be correlated. Parameters ---------- base_variational_strategy : gpytorch.variational.VariationalStrategy The base variational strategy to wrap. num_quantiles : int The number of quantiles. num_latents : int The number of latent functions. latent_dim : int, optional The dimension along which the latent functions are defined. Default is -1. jitter_val : float, optional The jitter value to add to the covariance matrix for numerical stability. Notes ----- This class modifies the standard LMC coefficients to fit the center-gap representation. The first latent function directly represents the central quantile, and it does not form any linear combinations with the other latent functions. The remaining latent functions are linearly combined to model the gap functions between quantiles. Subclass can extend :meth:`construct_lmc_mask` to further restrict the linear combinations, e.g., to model upper and lower gap functions separately by block diagonal matrices. """ def __init__( self, base_variational_strategy, num_quantiles, # Q num_latents, # T latent_dim=-1, jitter_val=None, ): super().__init__( base_variational_strategy, num_quantiles, num_latents, latent_dim, jitter_val, ) # lmc_coefficients: ([batch_shape], T, Q) lmc_coefficients = self.lmc_coefficients.detach().clone() del self.lmc_coefficients g0_mask = torch.zeros_like(lmc_coefficients) g0_mask[..., 0, 0] = 1 self.register_buffer("g0_mask", g0_mask) lmc_mask = torch.zeros_like(lmc_coefficients) lmc_mask[..., 1:, 1:] = self.construct_lmc_mask( torch.Size( list(lmc_coefficients.shape[:-2]) + [lmc_coefficients.shape[-2] - 1] + [lmc_coefficients.shape[-1] - 1] ) ) self.register_buffer("lmc_mask", lmc_mask) self.register_parameter( "_lmc_coefficients", torch.nn.Parameter(lmc_coefficients) )
[docs] def construct_lmc_mask(self, shape): """Construct a mask to restrict the LMC structure. Parameters ---------- shape : torch.Size The shape of the LMC coefficients. Must be ``([batch_shape], T - 1, Q - 1)``, where ``T`` is the number of latent functions and ``Q`` is the number of quantiles. Returns ------- lmc_mask : torch.Tensor with shape ``shape`` A binary mask of the same shape as the LMC coefficients, where 1 indicates the positions of the LMC coefficients to be learned, and 0 indicates the positions of the LMC coefficients to be fixed at 0. """ return torch.ones(shape)
@property def lmc_coefficients(self): return self._lmc_coefficients * self.lmc_mask + self.g0_mask
[docs] class CGBlkdiagLmcVariationalStrategy(CGLmcVariationalStrategy): """LMC variational strategy for the center-gap quantile regression model. This class does not allow correlations between upper and lower gap functions. Parameters ---------- base_variational_strategy num_quantiles num_latents num_lower_quantiles : int The number of lower quantiles. num_lower_latents : int The number of lower latent functions. latent_dim jitter_val """ def __init__( self, base_variational_strategy, num_quantiles, # Q num_latents, # T num_lower_quantiles, num_lower_latents, latent_dim=-1, jitter_val=None, ): num_upper_quantiles = num_quantiles - num_lower_quantiles - 1 num_upper_latents = num_latents - num_lower_latents - 1 self.num_lower_quantiles = num_lower_quantiles self.num_lower_latents = num_lower_latents self.num_upper_quantiles = num_upper_quantiles self.num_upper_latents = num_upper_latents super().__init__( base_variational_strategy, num_quantiles, num_latents, latent_dim, jitter_val, )
[docs] def construct_lmc_mask(self, shape): mask = super().construct_lmc_mask(shape) mask[..., : self.num_lower_latents, -self.num_upper_quantiles :] = 0 mask[..., -self.num_upper_latents :, : self.num_lower_quantiles] = 0 return mask
@property def lmc_coefficients(self): return self._lmc_coefficients * self.lmc_mask + self.g0_mask