{ "cells": [
 { "cell_type": "markdown", "id": "0", "metadata": {}, "source": [
  "# Cross validation\n",
  "\n",
  "Model selection by 5-fold cross validation. Four GP quantile regression (GPQR) variants (batch and multitask, each with a direct and a center-gap parameterization) are trained on the same folds, and their training and held-out negative ELBO are compared epoch by epoch."
 ] },
 { "cell_type": "code", "execution_count": null, "id": "1", "metadata": {}, "outputs": [], "source": [
  "import os\n",
  "\n",
  "import torch\n",
  "from torch.distributions import Normal\n",
  "from gpytorch.variational import CholeskyVariationalDistribution\n",
  "from gpytorch.variational import VariationalStrategy, LMCVariationalStrategy\n",
  "from gpytorch.means import ConstantMean\n",
  "from gpytorch.kernels import RBFKernel, ScaleKernel\n",
  "from gpytorch.mlls import VariationalELBO\n",
  "from sklearn.model_selection import KFold\n",
  "import matplotlib.pyplot as plt\n",
  "\n",
  "# Optional config module one directory up; without it the SVG output is not deterministic.\n",
  "try:\n",
  "    import sys\n",
  "\n",
  "    sys.path.insert(0, os.path.abspath(\"..\"))\n",
  "\n",
  "    import config_notebook\n",
  "except ImportError:\n",
  "    print(\"Output will not be deterministic SVG.\")\n",
  "\n",
  "torch.manual_seed(42)\n",
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
  "\n",
  "# Training epochs per model; override with the GPYTORCHQR_N_EPOCHS environment variable.\n",
  "n_epochs = int(os.getenv(\"GPYTORCHQR_N_EPOCHS\", 5000))"
 ] },
 { "cell_type": "markdown", "id": "2", "metadata": {}, "source": [
  "## Data preparation\n",
  "\n",
  "Synthetic heteroscedastic data $y = \\cos(2\\pi x) + (x + 0.1)\\,\\varepsilon$ with $\\varepsilon \\sim \\mathcal{N}(0, 1)$ and $x \\in [0, 1]$, so the true $\\tau$-quantile is $\\cos(2\\pi x) + (x + 0.1)\\,\\Phi^{-1}(\\tau)$ for $\\tau = 0.1, 0.2, \\dots, 0.9$."
 ] },
 { "cell_type": "code", "execution_count": null, "id": "3", "metadata": {}, "outputs": [], "source": [
  "def mean(x):\n",
  "    \"\"\"Noise-free mean of the synthetic data, cos(2*pi*x).\"\"\"\n",
  "    return torch.cos(2 * torch.pi * x)\n",
  "\n",
  "\n",
  "def std(x):\n",
  "    \"\"\"Noise standard deviation, growing linearly with x (heteroscedastic).\"\"\"\n",
  "    return x + 0.1\n",
  "\n",
  "\n",
  "N = 100\n",
  "x = torch.linspace(0, 1, N).reshape(-1, 1).to(device)\n",
  "y = (mean(x) + torch.randn(x.shape, device=device).mul(std(x))).squeeze()\n",
  "q = torch.linspace(0.1, 0.9, 9).to(device)\n",
  "# Gaussian noise, so the q-quantile is mean + std * Phi^{-1}(q); broadcasts to (N, len(q)).\n",
  "true_quantiles = mean(x) + std(x) * Normal(0, 1).icdf(q)\n",
  "x_pred = torch.linspace(0, 1.5, 100).reshape(-1, 1).to(device)"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "4", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.10.8, 
https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.scatter(x.cpu(), y.cpu(), c=\"k\", marker=\".\")\n", "plt.plot(x.cpu(), true_quantiles.cpu(), \"--\", c=\"gray\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "5", "metadata": {}, "outputs": [], "source": [ "K = 5\n", "kf = KFold(n_splits=K, shuffle=True, random_state=42)\n", "\n", "x_train_list, y_train_list, x_test_list, y_test_list = [], [], [], []\n", "for train_idx, test_idx in kf.split(x.cpu()):\n", " x_train_list.append(x[train_idx])\n", " y_train_list.append(y[train_idx])\n", " x_test_list.append(x[test_idx])\n", " y_test_list.append(y[test_idx])\n", "\n", "x_train_cv = torch.stack(x_train_list).to(device)\n", "y_train_cv = torch.stack(y_train_list).to(device)\n", "x_test_cv = torch.stack(x_test_list).to(device)\n", "y_test_cv = torch.stack(y_test_list).to(device)" ] }, { "cell_type": "markdown", "id": "6", "metadata": {}, "source": [ "## Batch direct GPQR" ] }, { "cell_type": "code", "execution_count": null, "id": "7", "metadata": {}, "outputs": [], "source": [ "from gpytorch_qr.models import DirectQuantileGP\n", "from gpytorch_qr.likelihoods import BatchQuantileGPLikelihood\n", "\n", "\n", "class CVBatchQuantileGP(DirectQuantileGP):\n", " def __init__(self, inducing_points, num_quantiles, batch_shape):\n", " N, D = inducing_points.size()\n", " batch_shape = torch.Size([num_quantiles, *batch_shape])\n", " variational_distribution = CholeskyVariationalDistribution(\n", " N,\n", " batch_shape=batch_shape,\n", " )\n", " variational_strategy = VariationalStrategy(\n", " self,\n", " inducing_points,\n", " variational_distribution,\n", " learn_inducing_locations=True,\n", " )\n", " mean = ConstantMean(batch_shape=batch_shape)\n", " covar = ScaleKernel(\n", " RBFKernel(ard_num_dims=D, batch_shape=batch_shape),\n", " batch_shape=batch_shape,\n", " )\n", " super().__init__(variational_strategy, mean, covar, 0)\n", "\n", "\n", "inducing_points = 
torch.linspace(0, 1, 10).reshape(-1, 1).to(device)\n",
  "gpqr = CVBatchQuantileGP(inducing_points, len(q), (K,)).to(device)\n",
  "# Quantile levels broadcast to the model's (num_quantiles, K) batch shape.\n",
  "q_expanded = q.unsqueeze(1).expand(-1, K)\n",
  "likelihood_gpqr = BatchQuantileGPLikelihood(q_expanded).to(device)"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "8", "metadata": {}, "outputs": [], "source": [
  "mll = VariationalELBO(likelihood_gpqr, gpqr, num_data=y_train_cv.shape[-1])\n",
  "optimizer = torch.optim.Adam(\n",
  "    list(gpqr.parameters()) + list(likelihood_gpqr.parameters()),\n",
  "    lr=0.001,\n",
  ")\n",
  "\n",
  "gpqr_train_losses, gpqr_test_losses = [], []\n",
  "for _ in range(n_epochs):\n",
  "    gpqr.train()\n",
  "    likelihood_gpqr.train()\n",
  "\n",
  "    output = gpqr(x_train_cv)\n",
  "    train_loss = -mll(output, y_train_cv)\n",
  "    # The loss is not a scalar (presumably one entry per batch member):\n",
  "    # it is summed for a single backward pass and averaged for logging.\n",
  "    train_loss.sum().backward()\n",
  "    optimizer.step()\n",
  "    optimizer.zero_grad()\n",
  "\n",
  "    # Held-out negative ELBO on each fold's test split, without gradients.\n",
  "    gpqr.eval()\n",
  "    likelihood_gpqr.eval()\n",
  "    with torch.no_grad():\n",
  "        output = gpqr(x_test_cv)\n",
  "        test_loss = -mll(output, y_test_cv)\n",
  "\n",
  "    gpqr_train_losses.append(train_loss.mean().item())\n",
  "    gpqr_test_losses.append(test_loss.mean().item())"
 ] },
 { "cell_type": "markdown", "id": "9", "metadata": {}, "source": [
  "## Batch center-gap GPQR\n",
  "\n",
  "Same `(num_quantiles, K)` batch layout, but the mean is built with `CenterGapMean` from one constant for the central quantile and `num_quantiles - 1` constants for the remaining components."
 ] },
 { "cell_type": "code", "execution_count": null, "id": "10", "metadata": {}, "outputs": [], "source": [
  "from gpytorch_qr.means import CenterGapMean\n",
  "from gpytorch_qr.models import CenterGapQuantileGP\n",
  "from gpytorch_qr.likelihoods import BatchCenterGapQuantileGPLikelihood\n",
  "\n",
  "\n",
  "class CVBatchCenterGapQuantileGP(CenterGapQuantileGP):\n",
  "    \"\"\"Batched center-gap quantile GP with batch shape (num_quantiles, num_folds).\n",
  "\n",
  "    The mean combines one constant for the central quantile with\n",
  "    ``num_quantiles - 1`` further constants, presumably joined along batch\n",
  "    dimension 0 (``latent_dim=0``).\n",
  "    \"\"\"\n",
  "\n",
  "    def __init__(self, inducing_points, num_quantiles, num_lower_quantiles, num_folds):\n",
  "        num_inducing, num_dims = inducing_points.size()\n",
  "        batch_shape = torch.Size([num_quantiles, num_folds])\n",
  "        variational_distribution = CholeskyVariationalDistribution(\n",
  "            num_inducing,\n",
  "            batch_shape=batch_shape,\n",
  "        )\n",
  "        variational_strategy = VariationalStrategy(\n",
  "            self,\n",
  "            inducing_points,\n",
  "            variational_distribution,\n",
  "            learn_inducing_locations=True,\n",
  "        )\n",
  "\n",
  "        mean_module = CenterGapMean(\n",
  "            ConstantMean(batch_shape=torch.Size([1, num_folds])),\n",
  "            ConstantMean(batch_shape=torch.Size([num_quantiles - 1, num_folds])),\n",
  "            latent_dim=0,\n",
  "        )\n",
  "        covar_module = ScaleKernel(\n",
  "            RBFKernel(ard_num_dims=num_dims, batch_shape=batch_shape),\n",
  "            batch_shape=batch_shape,\n",
  "        )\n",
  "        super().__init__(variational_strategy, mean_module, covar_module, 0, num_lower_quantiles)\n",
  "\n",
  "\n",
  "inducing_points = torch.linspace(0, 1, 10).reshape(-1, 1).to(device)\n",
  "# Index of the level closest to the median; the levels below it are the lower quantiles.\n",
  "central_q_index = (q - 0.5).abs().argmin().item()\n",
  "num_lower_quantiles = central_q_index\n",
  "gpqr_cg = CVBatchCenterGapQuantileGP(\n",
  "    inducing_points, len(q), num_lower_quantiles, K\n",
  ").to(device)\n",
  "# Create the (num_quantiles, K) zeros directly on `device`: `.to(device)` only moves\n",
  "# registered parameters/buffers, so a plain tensor attribute would stay on the CPU.\n",
  "likelihood_gpqr_cg = BatchCenterGapQuantileGPLikelihood(\n",
  "    q.unsqueeze(1), central_q_index, torch.zeros((len(q), K), device=device)\n",
  ").to(device)"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "11", "metadata": {}, "outputs": [], "source": [
  "mll = VariationalELBO(likelihood_gpqr_cg, gpqr_cg, num_data=y_train_cv.shape[-1])\n",
  "optimizer = torch.optim.Adam(\n",
  "    list(gpqr_cg.parameters()) + list(likelihood_gpqr_cg.parameters()),\n",
  "    lr=0.001,\n",
  ")\n",
  "\n",
  "gpqr_cg_train_losses, gpqr_cg_test_losses = [], []\n",
  "for _ in range(n_epochs):\n",
  "    gpqr_cg.train()\n",
  "    likelihood_gpqr_cg.train()\n",
  "\n",
  "    output = gpqr_cg(x_train_cv)\n",
  "    train_loss = -mll(output, y_train_cv)\n",
  "    # Summed for a single backward pass; averaged below for logging.\n",
  "    train_loss.sum().backward()\n",
  "    optimizer.step()\n",
  "    optimizer.zero_grad()\n",
  "\n",
  "    # Held-out negative ELBO on each fold's test split, without gradients.\n",
  "    gpqr_cg.eval()\n",
  "    likelihood_gpqr_cg.eval()\n",
  "    with torch.no_grad():\n",
  "        output = gpqr_cg(x_test_cv)\n",
  "        test_loss = -mll(output, y_test_cv)\n",
  "\n",
  "    gpqr_cg_train_losses.append(train_loss.mean().item())\n",
  "    gpqr_cg_test_losses.append(test_loss.mean().item())"
 ] },
 { "cell_type": "markdown", "id": "12", "metadata": {}, "source": [
  "## Multitask direct GPQR\n",
  "\n",
  "Per fold, `num_latents` latent GPs are mixed into the `num_quantiles` outputs by `LMCVariationalStrategy`."
 ] },
 { "cell_type": "code", "execution_count": null, 
"id": "13", "metadata": {}, "outputs": [], "source": [ "from gpytorch_qr.models import DirectQuantileGP\n", "from gpytorch_qr.likelihoods import MultitaskQuantileGPLikelihood\n", "\n", "\n", "class CVMultitaskQuantileGP(DirectQuantileGP):\n", " def __init__(self, inducing_points, num_quantiles, batch_shape, num_latents):\n", " N, D = inducing_points.size()\n", " batch_shape = torch.Size([*batch_shape, num_latents])\n", " variational_distribution = CholeskyVariationalDistribution(\n", " N,\n", " batch_shape=batch_shape,\n", " )\n", " variational_strategy = LMCVariationalStrategy(\n", " VariationalStrategy(\n", " self,\n", " inducing_points,\n", " variational_distribution,\n", " learn_inducing_locations=True,\n", " ),\n", " num_tasks=num_quantiles,\n", " num_latents=num_latents,\n", " )\n", "\n", " mean_module = ConstantMean(batch_shape=batch_shape)\n", " covar_module = ScaleKernel(\n", " RBFKernel(ard_num_dims=D, batch_shape=batch_shape),\n", " batch_shape=batch_shape,\n", " )\n", " super().__init__(variational_strategy, mean_module, covar_module, -1)\n", "\n", "\n", "inducing_points = torch.linspace(0, 1, 10).reshape(-1, 1).to(device)\n", "num_latents = len(q) - 2\n", "mtgpqr = CVMultitaskQuantileGP(inducing_points, len(q), (K,), num_latents).to(device)\n", "q_expanded = q.unsqueeze(0).expand(K, -1)\n", "likelihood_mtgpqr = MultitaskQuantileGPLikelihood(q_expanded).to(device)" ] }, { "cell_type": "code", "execution_count": null, "id": "14", "metadata": {}, "outputs": [], "source": [ "mll = VariationalELBO(likelihood_mtgpqr, mtgpqr, num_data=y_train_cv.shape[-1])\n", "optimizer = torch.optim.Adam(\n", " list(mtgpqr.parameters()) + list(likelihood_mtgpqr.parameters()),\n", " lr=0.001,\n", ")\n", "\n", "mtgpqr_train_losses, mtgpqr_test_losses = [], []\n", "for _ in range(n_epochs):\n", " mtgpqr.train()\n", " likelihood_mtgpqr.train()\n", "\n", " output = mtgpqr(x_train_cv)\n", " train_loss = -mll(output, y_train_cv)\n", " train_loss.sum().backward()\n", " 
optimizer.step()\n", " optimizer.zero_grad()\n", "\n", " mtgpqr.eval()\n", " likelihood_mtgpqr.eval()\n", " with torch.no_grad():\n", " output = mtgpqr(x_test_cv)\n", " test_loss = -mll(output, y_test_cv)\n", "\n", " mtgpqr_train_losses.append(train_loss.mean().item() / len(q))\n", " mtgpqr_test_losses.append(test_loss.mean().item() / len(q))" ] }, { "cell_type": "markdown", "id": "15", "metadata": {}, "source": [ "## Multitask center-gap GPQR" ] }, { "cell_type": "code", "execution_count": null, "id": "16", "metadata": {}, "outputs": [], "source": [ "from gpytorch_qr.means import CenterGapMean\n", "from gpytorch_qr.models import CenterGapQuantileGP\n", "from gpytorch_qr.variational import CGLmcVariationalStrategy\n", "from gpytorch_qr.likelihoods import MultitaskCenterGapQuantileGPLikelihood\n", "\n", "\n", "class CVMultitaskCenterGapQuantileGP(CenterGapQuantileGP):\n", " def __init__(\n", " self,\n", " inducing_points,\n", " num_quantiles,\n", " num_lower_quantiles,\n", " num_latents,\n", " num_folds,\n", " ):\n", " N, D = inducing_points.size()\n", " batch_shape = torch.Size([num_folds, num_latents])\n", " variational_distribution = CholeskyVariationalDistribution(\n", " N,\n", " batch_shape=batch_shape,\n", " )\n", " variational_strategy = CGLmcVariationalStrategy(\n", " VariationalStrategy(\n", " self,\n", " inducing_points,\n", " variational_distribution,\n", " learn_inducing_locations=True,\n", " ),\n", " num_quantiles=num_quantiles,\n", " num_latents=num_latents,\n", " )\n", "\n", " mean = CenterGapMean(\n", " ConstantMean(batch_shape=torch.Size([num_folds, 1])),\n", " ConstantMean(batch_shape=torch.Size([num_folds, num_latents - 1])),\n", " latent_dim=-1,\n", " )\n", " covar = ScaleKernel(\n", " RBFKernel(ard_num_dims=D, batch_shape=batch_shape),\n", " batch_shape=batch_shape,\n", " )\n", " super().__init__(variational_strategy, mean, covar, -1, num_lower_quantiles)\n", "\n", "\n", "inducing_points = torch.linspace(0, 1, 10).reshape(-1, 1).to(device)\n", 
"central_q_index = (q - 0.5).abs().argmin().item()\n", "num_latents = len(q) - 2\n", "mtgpqr_cg = CVMultitaskCenterGapQuantileGP(\n", " inducing_points, len(q), central_q_index, num_latents, K\n", ").to(device)\n", "likelihood_mtgpqr_cg = MultitaskCenterGapQuantileGPLikelihood(\n", " q.unsqueeze(0), central_q_index, torch.zeros((K, len(q)))\n", ").to(device)" ] }, { "cell_type": "code", "execution_count": null, "id": "17", "metadata": {}, "outputs": [], "source": [ "mll = VariationalELBO(likelihood_mtgpqr_cg, mtgpqr_cg, num_data=y_train_cv.shape[1])\n", "optimizer = torch.optim.Adam(\n", " list(mtgpqr_cg.parameters()) + list(likelihood_mtgpqr_cg.parameters()),\n", " lr=0.001,\n", ")\n", "\n", "mtgpqr_cg_train_losses, mtgpqr_cg_test_losses = [], []\n", "for _ in range(n_epochs):\n", " mtgpqr_cg.train()\n", " likelihood_mtgpqr_cg.train()\n", "\n", " output = mtgpqr_cg(x_train_cv)\n", " train_loss = -mll(output, y_train_cv)\n", " train_loss.sum().backward()\n", " optimizer.step()\n", " optimizer.zero_grad()\n", "\n", " mtgpqr_cg.eval()\n", " likelihood_mtgpqr_cg.eval()\n", " with torch.no_grad():\n", " output = mtgpqr_cg(x_test_cv)\n", " test_loss = -mll(output, y_test_cv)\n", "\n", " mtgpqr_cg_train_losses.append(train_loss.mean().item() / len(q))\n", " mtgpqr_cg_test_losses.append(test_loss.mean().item() / len(q))" ] }, { "cell_type": "markdown", "id": "18", "metadata": {}, "source": [ "## Plot loss by epoch" ] }, { "cell_type": "code", "execution_count": null, "id": "19", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.10.8, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from matplotlib.lines import Line2D\n", "\n", "colors = plt.cm.tab10.colors\n", "\n", "losses = [\n", " [gpqr_train_losses, gpqr_test_losses],\n", " [gpqr_cg_train_losses, gpqr_cg_test_losses],\n", " [mtgpqr_train_losses, mtgpqr_test_losses],\n", " [mtgpqr_cg_train_losses, mtgpqr_cg_test_losses],\n", "]\n", "\n", "for (train_losses, test_losses), color in zip(losses, colors):\n", " plt.plot(train_losses, color=color, linestyle=\"-.\", alpha=0.5)\n", " plt.plot(test_losses, color=color, linestyle=\"--\")\n", "\n", "model_handles = [Line2D([0], [0], color=colors[i], linestyle=\"-\") for i in range(4)]\n", "style_handles = [\n", " Line2D([0], [0], color=\"k\", linestyle=\"-.\", alpha=0.5),\n", " Line2D([0], [0], color=\"k\", linestyle=\"--\"),\n", "]\n", "\n", "handles = model_handles + style_handles\n", "labels = [\"GPQR\", \"GPQR_CG\", \"MTGPQR\", \"MTGPQR_CG\", \"Train\", \"Test\"]\n", "plt.legend(handles=handles, labels=labels, ncols=4)\n", "\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Negative ELBO\")\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "heavyedge", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.13" } }, "nbformat": 4, "nbformat_minor": 5 }