Cross-validation
Model selection by 5-fold cross-validation.
import math
import os

import torch
from torch.distributions import Normal
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import VariationalStrategy, LMCVariationalStrategy
from gpytorch.means import ConstantMean
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.mlls import VariationalELBO
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
# Make the parent directory importable so the optional shared notebook
# configuration can be picked up; without it, plots are not deterministic.
import sys

sys.path.insert(0, os.path.abspath(".."))
try:
    import config_notebook
except ImportError:
    print("Output will not be deterministic SVG.")

# Fixed seed so the synthetic data and CV training are reproducible.
torch.manual_seed(42)

# Prefer the GPU when one is present.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Training length can be overridden from the environment (e.g. to shorten
# documentation builds); defaults to 5000 epochs.
n_epochs = int(os.environ.get("GPYTORCHQR_N_EPOCHS", 5000))
Data preparation
def mean(x):
    """Return the true conditional mean ``cos(2 * pi * x)`` of the synthetic data.

    Args:
        x: Input tensor; the cosine is applied element-wise.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Use the exact value of pi (not an approximation such as 3.14) so the
    # cosine has a period of exactly 1 in x.
    return torch.cos(2 * math.pi * x)
def std(x):
    """Return the true noise standard deviation ``x + 0.1`` of the synthetic data.

    The noise level is 0.1 at ``x = 0`` and grows linearly with ``x``.

    Args:
        x: Input tensor (or scalar).

    Returns:
        ``x`` shifted by 0.1, element-wise.
    """
    noise_floor = 0.1
    return x + noise_floor
N = 100
# Training inputs on [0, 1], laid out as a column vector of shape (N, 1).
x = torch.linspace(0, 1, N).reshape(-1, 1).to(device)

# Heteroscedastic Gaussian noise: standard-normal draws scaled by std(x).
noise = torch.randn(x.shape, device=device).mul(std(x))
y = (mean(x) + noise).squeeze()

# Quantile levels 0.1, 0.2, ..., 0.9 and the matching true conditional
# quantiles mean(x) + std(x) * Phi^{-1}(q), one column per level.
q = torch.linspace(0.1, 0.9, 9).to(device)
standard_normal_quantiles = Normal(0, 1).icdf(q)
true_quantiles = mean(x) + std(x) * standard_normal_quantiles

# Prediction grid extends beyond the training range, to [0, 1.5].
x_pred = torch.linspace(0, 1.5, 100).reshape(-1, 1).to(device)

# Observed data as black dots, true quantile curves as dashed gray lines.
plt.scatter(x.cpu(), y.cpu(), c="k", marker=".")
plt.plot(x.cpu(), true_quantiles.cpu(), "--", c="gray")
plt.show()
K = 5
# torch.stack (below) requires every fold to have the same size, but KFold
# gives the first len(x) % K folds one extra sample. Fail early with a clear
# message rather than an opaque shape error from torch.stack.
if len(x) % K != 0:
    raise ValueError(
        f"len(x)={len(x)} must be divisible by K={K} so the CV folds can be "
        "stacked into batched tensors."
    )

kf = KFold(n_splits=K, shuffle=True, random_state=42)

# Collect the per-fold train/test subsets. The split is computed on a CPU
# copy of x; the returned indices then select rows of the device tensors.
x_train_list, y_train_list, x_test_list, y_test_list = [], [], [], []
for train_idx, test_idx in kf.split(x.cpu()):
    x_train_list.append(x[train_idx])
    y_train_list.append(y[train_idx])
    x_test_list.append(x[test_idx])
    y_test_list.append(y[test_idx])

# Stack the folds along a new leading dimension of length K.
x_train_cv = torch.stack(x_train_list).to(device)
y_train_cv = torch.stack(y_train_list).to(device)
x_test_cv = torch.stack(x_test_list).to(device)
y_test_cv = torch.stack(y_test_list).to(device)
Plot loss by epoch
from matplotlib.lines import Line2D

colors = plt.cm.tab10.colors
# Model names are listed in the same order as `losses`, so each legend label
# stays aligned with the curves it describes.
model_names = ["Direct GPQR", "CG GPQR"]
losses = [
    [mtgpqr_train_losses, mtgpqr_test_losses],
    [mtgpqr_cg_train_losses, mtgpqr_cg_test_losses],
]

# One color per model: train curves dash-dot and faded, test curves dashed.
for (train_losses, test_losses), color in zip(losses, colors):
    plt.plot(train_losses, color=color, linestyle="-.", alpha=0.5)
    plt.plot(test_losses, color=color, linestyle="--")

# Legend: solid colored swatches identify the model; black lines show which
# line style means train vs. test.
model_handles = [
    Line2D([0], [0], color=color, linestyle="-")
    for _, color in zip(losses, colors)
]
style_handles = [
    Line2D([0], [0], color="k", linestyle="-.", alpha=0.5),
    Line2D([0], [0], color="k", linestyle="--"),
]
handles = model_handles + style_handles
labels = model_names + ["Train", "Test"]
plt.legend(handles=handles, labels=labels, ncols=len(labels))
plt.xlabel("Epoch")
plt.ylabel("Negative ELBO")
plt.show()