Source code for gpytorch_qr.utils

"""Utility functions."""

import gpytorch
import torch
import torch.nn.functional as F

# Public names exported by ``from <this module> import *``.
__all__ = [
    "centergap_to_quantiles",
    "CenterGapToQuantileTransform",
    "transform_centergap_posterior",
]


def centergap_to_quantiles(central, lower_gaps, upper_gaps, quantile_dim=-1):
    """Turn a center-gap representation into ordered quantile values.

    The raw gap inputs are passed through softplus to make them strictly
    positive, then accumulated away from the central quantile: downward for
    the lower gaps and upward for the upper gaps.

    Parameters
    ----------
    central : torch.Tensor with shape (..., 1, ...)
        The central quantile values.
    lower_gaps : torch.Tensor with shape (..., L, ...)
        Pre-transformed (pre-softplus) lower gap values.
    upper_gaps : torch.Tensor with shape (..., U, ...)
        Pre-transformed (pre-softplus) upper gap values.
    quantile_dim : int, default=-1
        The dimension along which the quantiles are laid out.

    Returns
    -------
    quantiles : torch.Tensor with shape (..., Q, ...)
        Quantile values with ``Q = L + U + 1`` entries along *quantile_dim*,
        ordered from lowest to highest.
    """
    flip_dims = [quantile_dim]

    # Lower side: a reversed cumulative sum, so that entry ``i`` holds the
    # total distance from quantile ``i`` up to the central quantile.
    positive_lower = F.softplus(lower_gaps)
    reversed_running_total = positive_lower.flip(dims=flip_dims).cumsum(
        dim=quantile_dim
    )
    distance_below = reversed_running_total.flip(dims=flip_dims)
    below = central - distance_below

    # Upper side: a plain cumulative sum of the positive gaps.
    positive_upper = F.softplus(upper_gaps)
    above = central + positive_upper.cumsum(dim=quantile_dim)

    return torch.concat([below, central, above], dim=quantile_dim)
def _softplus_inverse(y): return y + torch.log(-torch.expm1(-y))
class CenterGapToQuantileTransform(torch.distributions.transforms.Transform):
    """Bijective transform from center-gap distribution to quantile distribution.

    Parameters
    ----------
    L : int
        Number of lower quantile gaps in the center-gap representation.
    quantile_dim : {-1, -2}
        The dimension along which the quantiles are represented.

    Notes
    -----
    If *quantile_dim* is -1, shape of input tensor is either ``(N, Q)`` or
    ``(S, N, Q)``. If *quantile_dim* is -2, shape of input tensor is either
    ``(Q, N)`` or ``(S, Q, N)``. Here, *Q* is the number of quantiles, *N* is
    the number of data points, and *S* is the number of samples.

    The center-gap components along the quantile dimension are ordered as a
    central quantile, *L* lower pre-gaps, and *U* upper pre-gaps
    (``Q = 1 + L + U``).

    The domain and codomain are declared with two event dimensions because
    :meth:`log_abs_det_jacobian` reduces over the last two dimensions of its
    input. Keeping the declared event rank equal to the number of reduced
    dimensions is what lets ``TransformedDistribution.log_prob`` combine the
    Jacobian term with the base log-probability without summing over, or
    broadcasting against, batch/sample dimensions.
    """

    # Two event dims: the quantile axis and the data-point axis are both
    # summed out by ``log_abs_det_jacobian``.
    domain = torch.distributions.constraints.independent(
        torch.distributions.constraints.real, 2
    )
    codomain = torch.distributions.constraints.independent(
        torch.distributions.constraints.real, 2
    )
    bijective = True

    def __init__(self, L, quantile_dim=-2):
        super().__init__()
        self.L = L
        self.quantile_dim = quantile_dim

    def _call(self, x):
        """Map center-gap values ``x`` to ordered quantiles."""
        qdim = self.quantile_dim
        n_upper = x.shape[qdim] - 1 - self.L
        # Layout along qdim: [central | L lower pre-gaps | U upper pre-gaps].
        central = torch.narrow(x, qdim, 0, 1)
        lower_pregaps = torch.narrow(x, qdim, 1, self.L)
        upper_pregaps = torch.narrow(x, qdim, 1 + self.L, n_upper)
        return centergap_to_quantiles(
            central, lower_pregaps, upper_pregaps, quantile_dim=qdim
        )

    def _inverse(self, y):
        """Map ordered quantiles ``y`` back to center-gap values."""
        L = self.L
        qdim = self.quantile_dim
        # The central quantile sits right after the L lower quantiles.
        central = torch.narrow(y, qdim, L, 1)
        # Consecutive differences recover the positive (post-softplus) gaps;
        # both slices include the central quantile as an endpoint.
        lower_gaps_linear = torch.narrow(y, qdim, 0, L + 1).diff(dim=qdim)
        upper_gaps_linear = torch.narrow(y, qdim, L, y.shape[qdim] - L).diff(dim=qdim)
        return torch.cat(
            [
                central,
                _softplus_inverse(lower_gaps_linear),
                _softplus_inverse(upper_gaps_linear),
            ],
            dim=qdim,
        )

    def log_abs_det_jacobian(self, x, y):
        """Return the log absolute Jacobian determinant of the forward map.

        Parameters
        ----------
        x : torch.Tensor
            Center-gap input of the forward transform.
        y : torch.Tensor
            Corresponding quantile output (unused).

        Returns
        -------
        torch.Tensor
            ``x`` with its last two dimensions summed out.

        Notes
        -----
        Only the pre-gap entries contribute: the derivative of softplus is
        the sigmoid, so each pre-gap adds ``logsigmoid(pre_gap)``, while the
        central entry contributes nothing.
        """
        qdim = self.quantile_dim
        # Skip the central entry at index 0 along the quantile dimension.
        gaps = torch.narrow(x, qdim, 1, x.shape[qdim] - 1)
        return F.logsigmoid(gaps).sum(dim=(-2, -1))
def transform_centergap_posterior(posterior, L):
    """Wrap a center-gap posterior so that it yields quantiles.

    Parameters
    ----------
    posterior : gpytorch.distributions.MultivariateNormal
        The center-gap posterior distribution.
    L : int
        Number of lower gaps in the center-gap representation.

    Returns
    -------
    quantile_posterior : torch.distributions.TransformedDistribution
        *posterior* transformed by :class:`CenterGapToQuantileTransform`.

    Raises
    ------
    ValueError
        If *posterior* is neither a ``MultitaskMultivariateNormal`` nor a
        ``MultivariateNormal``.

    Notes
    -----
    The quantile dimension is taken to be -1 for a
    ``MultitaskMultivariateNormal`` and -2 for any other
    ``MultivariateNormal``. Along that dimension the components are the
    central quantile, followed by *L* lower gaps and *U* upper gaps, where
    *U = Q - L - 1*.
    """
    # Checked in order: the multitask type is tested before the plain one.
    quantile_dim_by_type = (
        (gpytorch.distributions.MultitaskMultivariateNormal, -1),
        (gpytorch.distributions.MultivariateNormal, -2),
    )
    for dist_type, quantile_dim in quantile_dim_by_type:
        if isinstance(posterior, dist_type):
            break
    else:
        raise ValueError("Posterior is not a multivariate normal.")

    return torch.distributions.TransformedDistribution(
        posterior,
        CenterGapToQuantileTransform(L, quantile_dim=quantile_dim),
    )