Source code for gpytorch_qr.means

"""Mean modules."""

import torch
from gpytorch.means import Mean

__all__ = [
    "CenterGapMean",
]


[docs] class CenterGapMean(Mean): """Mean module for center-gap. Parameters ---------- center_mean : gpytorch.means.Mean Mean module for the central quantile. If GPQR treats quantiles as batches, this module should have batch shape ``(1, *B)`` where *B* is additional batch shape. If GPQR treats quantiles as task, this module should have batch shape ``(*B, 1)``. gap_mean : gpytorch.means.Mean Mean module for the quantile gaps. If GPQR treats quantiles as batches, this module should have batch shape ``(Q-1, *B)`` where *Q* is the number of quantiles and *B* is additional batch shape. If GPQR treats quantiles as task, this module should have batch shape ``(*B, L-1)`` where *L* is the number of latent GPs. latent_dim : {0, -1} The dimension along which the latent GPs are represented in module batch shape. ``0`` if quantiles are batches, ``-1`` if quantiles are tasks. Notes ----- If GPQR treats quantiles as batches, input predictors are expected to have shape ``(1, *B, N, D)``. If GPQR treats quantiles as tasks, input predictors are expected to have shape ``(*B, 1, N, D)``. *N* is the number of data points and *D* is the number of input dimensions. """ def __init__(self, center_mean, gap_mean, latent_dim): super().__init__() self.center_mean = center_mean self.gap_mean = gap_mean if latent_dim == 0: self.concat_dim = 0 elif latent_dim == -1: self.concat_dim = -2 else: raise ValueError("latent_dim should be either 0 or -1.")
[docs] def forward(self, x): """Compute the mean of center-gap representation. Parameters ---------- x : torch.Tensor in shape ``(1, *B, N, D)`` or ``(*B, 1, N, D)`` """ center_mean = self.center_mean(x) # (1, *B, N) or (*B, 1, N) gap_mean = self.gap_mean(x) return torch.concat([center_mean, gap_mean], dim=self.concat_dim)