# Source code for h0rton.h0_inference.gaussian_bnn_posterior

from abc import ABC, abstractmethod
import random
import numpy as np
import torch
# Public API of this module: the abstract base posterior and its concrete
# single-Gaussian and Gaussian-mixture variants.
__all__ = [
    'BaseGaussianBNNPosterior',
    'DiagonalGaussianBNNPosterior',
    'LowRankGaussianBNNPosterior',
    'DoubleLowRankGaussianBNNPosterior',
    'FullRankGaussianBNNPosterior',
    'DoubleGaussianBNNPosterior',
]

class BaseGaussianBNNPosterior(ABC):
    """Abstract base class to represent the Gaussian BNN posterior

    Gaussian posteriors or mixtures thereof with various forms of the
    covariance matrix inherit from this class. Subclasses implement `sample`
    and `get_hpd_interval`, and set `batch_size` (plus `rank` or `tril_idx`
    where needed) before using the shared sampling helpers defined here.

    """
    def __init__(self, Y_dim, device, Y_mean=None, Y_std=None):
        """
        Parameters
        ----------
        Y_dim : int
            number of parameters to predict
        device : torch.device object
        Y_mean : list, optional
            per-parameter mean used for whitening, of length `Y_dim`.
            Default: None, meaning no shift (all zeros)
        Y_std : list, optional
            per-parameter std used for whitening, of length `Y_dim`.
            Default: None, meaning no scaling (all ones)

        """
        self.Y_dim = Y_dim
        # torch.Tensor(None) raises, so the documented None defaults fall back
        # to the identity (un)whitening transform.
        if Y_mean is None:
            Y_mean = [0.0]*Y_dim
        if Y_std is None:
            Y_std = [1.0]*Y_dim
        self.Y_mean = torch.Tensor(Y_mean).reshape(1, -1) # [1, Y_dim]
        self.Y_std = torch.Tensor(Y_std).reshape(1, -1) # [1, Y_dim]
        self.device = device
        self.sigmoid = torch.nn.Sigmoid()
        self.logsigmoid = torch.nn.LogSigmoid()

    def seed_samples(self, sample_seed):
        """Seed the sampling for reproducibility

        Parameters
        ----------
        sample_seed : int or None
            seed for numpy, `random`, and torch (CPU and CUDA). If None, the
            random states are left untouched.

        """
        if sample_seed is None:
            # torch.manual_seed(None) raises; None means "do not reseed".
            return
        np.random.seed(sample_seed)
        random.seed(sample_seed)
        torch.manual_seed(sample_seed)
        torch.cuda.manual_seed(sample_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    @abstractmethod
    def sample(self, n_samples, sample_seed=None):
        """Sample from the Gaussian posterior. Must be overridden by subclasses.

        Parameters
        ----------
        n_samples : int
            how many samples to obtain
        sample_seed : int
            seed for the samples. Default: None

        Returns
        -------
        np.array of shape `[self.batch_size, n_samples, self.Y_dim]`
            samples

        """
        raise NotImplementedError

    @abstractmethod
    def get_hpd_interval(self):
        """Get the highest posterior density (HPD) interval

        """
        raise NotImplementedError

    def transform_back_mu(self, tensor):
        """Transform back, i.e. unwhiten, the tensor of central values

        Parameters
        ----------
        tensor : torch.Tensor of shape `[batch_size, Y_dim]`

        Returns
        -------
        torch.tensor of shape `[batch_size, Y_dim]`
            the original tensor

        """
        unwhitened = self.unwhiten_back(tensor.unsqueeze(1)) # [batch_size, 1, Y_dim]
        # Squeeze only the inserted axis so a batch of one keeps its batch dim.
        return unwhitened.squeeze(1)

    def transform_back_logvar(self, logvar):
        """Transform back, i.e. unwhiten, the tensor of predicted log of the diagonal entries of the cov mat

        Whitening divides each parameter by its std, so each variance is scaled
        back by std**2, i.e. its log gains 2*log(std).

        Parameters
        ----------
        logvar : torch.Tensor of shape `[batch_size, Y_dim]`

        Returns
        -------
        torch.tensor of shape `[batch_size, Y_dim]`
            the log variances in the original (unwhitened) units

        """
        Y_std = self.Y_std.to(logvar.device) # [1, Y_dim]
        return logvar + 2.0*torch.log(Y_std)

    def transform_back_cov_mat(self, cov_mat):
        """Transform back, i.e. unwhiten, the tensor of predicted covariance matrix

        Parameters
        ----------
        cov_mat : torch.Tensor of shape `[batch_size, Y_dim, Y_dim]`

        Returns
        -------
        torch.tensor of shape `[batch_size, Y_dim, Y_dim]`
            the original tensor, with entry (i, j) scaled by std_i*std_j

        """
        Y_std = self.Y_std.to(cov_mat.device) # [1, Y_dim]
        return cov_mat*Y_std.unsqueeze(-1)*Y_std.unsqueeze(0)

    def unwhiten_back(self, sample):
        """Scale and shift back to the unwhitened state

        Parameters
        ----------
        sample : torch.Tensor
            network prediction of shape `[batch_size, n_samples, self.Y_dim]`

        Returns
        -------
        torch.Tensor
            the unwhitened sample

        """
        Y_std = self.Y_std.to(sample.device).unsqueeze(1) # [1, 1, Y_dim]
        Y_mean = self.Y_mean.to(sample.device).unsqueeze(1) # [1, 1, Y_dim]
        return sample*Y_std + Y_mean

    def sample_low_rank(self, n_samples, mu, logvar, F):
        """Sample from a single Gaussian posterior with a full but low-rank plus diagonal covariance matrix

        The covariance is diag(exp(logvar)) + F F^T. Requires `self.batch_size`
        and `self.rank` to be set.

        Parameters
        ----------
        n_samples : int
            how many samples to obtain
        mu : torch.Tensor of shape `[self.batch_size, self.Y_dim]`
            network prediction of the mu (mean parameter) of the BNN posterior
        logvar : torch.Tensor of shape `[self.batch_size, self.Y_dim]`
            network prediction of the log of the diagonal elements of the covariance matrix
        F : torch.Tensor of shape `[self.batch_size, self.Y_dim, self.rank]`
            network prediction of the low rank portion of the covariance matrix

        Returns
        -------
        np.array of shape `[self.batch_size, n_samples, self.Y_dim]`
            samples

        """
        batch_size, Y_dim = self.batch_size, self.Y_dim
        # Tile the per-lens parameters; rows are sample-major, i.e.
        # [all lenses for sample 0, all lenses for sample 1, ...].
        F = F.repeat(n_samples, 1, 1) # [batch_size*n_samples, Y_dim, rank]
        mu = mu.repeat(n_samples, 1) # [batch_size*n_samples, Y_dim]
        logvar = logvar.repeat(n_samples, 1) # [batch_size*n_samples, Y_dim]
        eps_low_rank = torch.randn(batch_size*n_samples, self.rank, 1, device=mu.device)
        eps_diag = torch.randn(batch_size*n_samples, Y_dim, device=mu.device)
        std_diag = torch.exp(0.5*logvar) # [batch_size*n_samples, Y_dim]
        # squeeze(-1) only: a bare squeeze() would also drop size-1 data dims.
        samples = torch.bmm(F, eps_low_rank).squeeze(-1) + mu + std_diag*eps_diag
        samples = samples.reshape(n_samples, batch_size, Y_dim).transpose(0, 1)
        samples = self.unwhiten_back(samples)
        return samples.detach().cpu().numpy()

    def sample_full_rank(self, n_samples, mu, tril_elements, as_numpy=True):
        """Sample from a single Gaussian posterior with a full-rank covariance matrix

        Requires `self.batch_size` and `self.tril_idx` to be set.

        Parameters
        ----------
        n_samples : int
            how many samples to obtain
        mu : torch.Tensor of shape `[self.batch_size, self.Y_dim]`
            network prediction of the mu (mean parameter) of the BNN posterior
        tril_elements : torch.Tensor of shape `[self.batch_size, tril_len]`
            network prediction of lower-triangular matrix in the log-Cholesky decomposition of the precision matrix
        as_numpy : bool
            if True (default) return a numpy array, otherwise a CPU torch.Tensor

        Returns
        -------
        np.array (or torch.Tensor) of shape `[self.batch_size, n_samples, self.Y_dim]`
            samples

        """
        Y_dim = self.Y_dim
        samples = torch.zeros([self.batch_size, n_samples, Y_dim])
        # Diagonal mask built once, on the same device as `tril`.
        diag_mask = torch.eye(Y_dim, dtype=torch.bool, device=self.device)
        for b in range(self.batch_size):
            tril = torch.zeros([Y_dim, Y_dim], device=self.device, dtype=None)
            tril[self.tril_idx[0], self.tril_idx[1]] = tril_elements[b, :]
            # Log-Cholesky parameterization: the diagonal is predicted in log space
            log_diag_tril = torch.diagonal(tril, offset=0, dim1=0, dim2=1)
            tril[diag_mask] = torch.exp(log_diag_tril)
            prec_mat = torch.mm(tril, tril.T) # [Y_dim, Y_dim]
            mvn = torch.distributions.multivariate_normal.MultivariateNormal(loc=mu[b, :], precision_matrix=prec_mat)
            samples[b, :, :] = mvn.sample([n_samples,])
        samples = self.unwhiten_back(samples)
        if as_numpy:
            return samples.cpu().numpy()
        return samples
class DiagonalGaussianBNNPosterior(BaseGaussianBNNPosterior):
    """Single Gaussian BNN posterior with a diagonal covariance matrix

    See `BaseGaussianBNNPosterior.__init__` docstring for the parameter description.

    """
    def __init__(self, Y_dim, device, Y_mean=None, Y_std=None):
        super(DiagonalGaussianBNNPosterior, self).__init__(Y_dim, device, Y_mean, Y_std)
        # Network output layout: [mu (Y_dim) | logvar (Y_dim)]
        self.out_dim = self.Y_dim*2

    def set_sliced_pred(self, pred):
        """Slice the network prediction into the posterior parameters

        Parameters
        ----------
        pred : torch.Tensor of shape `[batch_size, self.out_dim]`

        """
        d = self.Y_dim # for readability
        self.batch_size = pred.shape[0]
        self.mu = pred[:, :d]
        self.logvar = pred[:, d:2*d]
        self.cov_diag = torch.exp(self.logvar)

    def sample(self, n_samples, sample_seed=None):
        """Sample from a Gaussian posterior with diagonal covariance matrix

        Parameters
        ----------
        n_samples : int
            how many samples to obtain
        sample_seed : int
            seed for the samples. Default: None (random state left untouched)

        Returns
        -------
        np.array of shape `[self.batch_size, n_samples, self.Y_dim]`
            samples

        """
        if sample_seed is not None:
            self.seed_samples(sample_seed)
        # Noise on the same device as the predictions
        eps = torch.randn(self.batch_size, n_samples, self.Y_dim, device=self.mu.device)
        samples = eps*torch.exp(0.5*self.logvar.unsqueeze(1)) + self.mu.unsqueeze(1)
        samples = self.unwhiten_back(samples)
        return samples.detach().cpu().numpy()

    def get_hpd_interval(self):
        """Not implemented for this posterior."""
        raise NotImplementedError
class LowRankGaussianBNNPosterior(BaseGaussianBNNPosterior):
    """Single Gaussian BNN posterior with a low-rank plus diagonal covariance matrix

    The covariance is diag(exp(logvar)) + F F^T, with F of shape
    `[Y_dim, rank]`. See `BaseGaussianBNNPosterior.__init__` docstring for the
    other parameters; `rank` defaults to 2.

    """
    def __init__(self, Y_dim, device, Y_mean=None, Y_std=None, rank=2):
        super(LowRankGaussianBNNPosterior, self).__init__(Y_dim, device, Y_mean, Y_std)
        self.rank = rank
        # Network output layout: [mu (Y_dim) | logvar (Y_dim) | F (Y_dim*rank)]
        self.out_dim = self.Y_dim*(2 + self.rank)

    def set_sliced_pred(self, pred):
        """Slice the network prediction into the posterior parameters

        Parameters
        ----------
        pred : torch.Tensor of shape `[batch_size, self.out_dim]`

        """
        d = self.Y_dim # for readability
        self.batch_size = pred.shape[0]
        self.mu = pred[:, :d]
        self.logvar = pred[:, d:2*d]
        self.F = pred[:, 2*d:(2 + self.rank)*d].reshape([self.batch_size, d, self.rank])
        F_F_tran = torch.bmm(self.F, torch.transpose(self.F, 1, 2)) # [n_lenses, d, d]
        var_diag = torch.exp(self.logvar) # [n_lenses, d]
        self.cov_diag = var_diag + torch.diagonal(F_F_tran, dim1=1, dim2=2) # [n_lenses, d]
        # The diagonal term is exp(logvar), consistent with cov_diag above
        self.cov_mat = torch.diag_embed(var_diag) + F_F_tran

    def sample(self, n_samples, sample_seed=None):
        """Sample from the low-rank plus diagonal Gaussian posterior

        Parameters
        ----------
        n_samples : int
            how many samples to obtain
        sample_seed : int
            seed for the samples. Default: None (random state left untouched)

        Returns
        -------
        np.array of shape `[self.batch_size, n_samples, self.Y_dim]`
            samples

        """
        if sample_seed is not None:
            self.seed_samples(sample_seed)
        return self.sample_low_rank(n_samples, self.mu, self.logvar, self.F)

    def get_hpd_interval(self):
        """Not implemented for this posterior."""
        raise NotImplementedError
class DoubleLowRankGaussianBNNPosterior(BaseGaussianBNNPosterior):
    """Mixture of two Gaussians, each with a low-rank plus diagonal covariance matrix

    Each component's covariance is diag(exp(logvar)) + F F^T, with F of shape
    `[Y_dim, rank]`. See `BaseGaussianBNNPosterior.__init__` docstring for the
    other parameters; `rank` defaults to 2.

    """
    def __init__(self, Y_dim, device, Y_mean=None, Y_std=None, rank=2):
        super(DoubleLowRankGaussianBNNPosterior, self).__init__(Y_dim, device, Y_mean, Y_std)
        self.rank = rank
        # Per component: [mu (Y_dim) | logvar (Y_dim) | F (Y_dim*rank)];
        # two components, then one logit for the second component's weight.
        self.out_dim = 2*self.Y_dim*(2 + self.rank) + 1

    @staticmethod
    def _low_rank_cov(logvar, F):
        """Return (cov_diag, cov_mat) for the covariance diag(exp(logvar)) + F F^T."""
        F_F_tran = torch.bmm(F, torch.transpose(F, 1, 2)) # [n_lenses, d, d]
        var_diag = torch.exp(logvar) # [n_lenses, d]
        cov_diag = var_diag + torch.diagonal(F_F_tran, dim1=1, dim2=2) # [n_lenses, d]
        cov_mat = torch.diag_embed(var_diag) + F_F_tran # [n_lenses, d, d]
        return cov_diag, cov_mat

    def set_sliced_pred(self, pred):
        """Slice the network prediction into the parameters of both components

        Parameters
        ----------
        pred : torch.Tensor of shape `[batch_size, self.out_dim]`

        """
        d = self.Y_dim # for readability
        r = self.rank
        width = (2 + r)*d # number of columns describing one component
        self.batch_size = pred.shape[0]
        # First gaussian
        self.mu = pred[:, :d]
        self.logvar = pred[:, d:2*d]
        self.F = pred[:, 2*d:width].reshape([self.batch_size, d, r])
        self.cov_diag, self.cov_mat = self._low_rank_cov(self.logvar, self.F)
        # Second gaussian
        self.mu2 = pred[:, width:width + d]
        self.logvar2 = pred[:, width + d:width + 2*d]
        self.F2 = pred[:, width + 2*d:2*width].reshape([self.batch_size, d, r])
        self.cov_diag2, self.cov_mat2 = self._low_rank_cov(self.logvar2, self.F2)
        # Weight of the second component, constrained to (0, 0.5)
        self.w2 = 0.5*self.sigmoid(pred[:, -1].reshape(-1, 1))

    def sample(self, n_samples, sample_seed=None):
        """Sample from a mixture of two Gaussians, each with a full but constrained as low-rank plus diagonal covariance

        Parameters
        ----------
        n_samples : int
            how many samples to obtain
        sample_seed : int
            seed for the samples. Default: None (random state left untouched)

        Returns
        -------
        np.array of shape `[self.batch_size, n_samples, self.Y_dim]`
            samples

        """
        if sample_seed is not None:
            self.seed_samples(sample_seed)
        # Draw the component assignment first, then the second and first
        # components, keeping the random-number order of the original code.
        unif2 = torch.rand(self.batch_size, n_samples)
        second_gaussian = (self.w2.detach().cpu() > unif2).numpy() # [batch_size, n_samples]
        samples2 = self.sample_low_rank(n_samples, self.mu2, self.logvar2, self.F2)
        samples1 = self.sample_low_rank(n_samples, self.mu, self.logvar, self.F)
        samples = np.where(second_gaussian[..., np.newaxis], samples2, samples1)
        return samples.astype(np.float32, copy=False)

    def get_hpd_interval(self):
        """Not implemented for this posterior."""
        raise NotImplementedError
class FullRankGaussianBNNPosterior(BaseGaussianBNNPosterior):
    """Single Gaussian BNN posterior with a full-rank covariance matrix

    The precision matrix is parameterized by the log-Cholesky decomposition
    of its lower-triangular factor. See `BaseGaussianBNNPosterior.__init__`
    docstring for the parameter description.

    """
    def __init__(self, Y_dim, device, Y_mean=None, Y_std=None):
        super(FullRankGaussianBNNPosterior, self).__init__(Y_dim, device, Y_mean, Y_std)
        self.tril_idx = torch.tril_indices(self.Y_dim, self.Y_dim, offset=0, device=device) # lower-triangular indices
        self.tril_len = len(self.tril_idx[0])
        # Network output layout: [mu (Y_dim) | tril elements (Y_dim*(Y_dim + 1)/2)]
        self.out_dim = self.Y_dim + self.Y_dim*(self.Y_dim + 1)//2

    def set_sliced_pred(self, pred):
        """Slice the network prediction into the posterior parameters

        Parameters
        ----------
        pred : torch.Tensor of shape `[batch_size, self.out_dim]`

        """
        d = self.Y_dim # for readability
        self.batch_size = pred.shape[0]
        self.mu = pred[:, :d]
        self.tril_elements = pred[:, d:self.out_dim]

    def sample(self, n_samples, sample_seed=None):
        """Sample from the full-rank Gaussian posterior

        Parameters
        ----------
        n_samples : int
            how many samples to obtain
        sample_seed : int
            seed for the samples. Default: None (random state left untouched)

        Returns
        -------
        np.array of shape `[self.batch_size, n_samples, self.Y_dim]`
            samples

        """
        if sample_seed is not None:
            self.seed_samples(sample_seed)
        return self.sample_full_rank(n_samples, self.mu, self.tril_elements)

    def get_hpd_interval(self):
        """Not implemented for this posterior."""
        raise NotImplementedError
class DoubleGaussianBNNPosterior(BaseGaussianBNNPosterior):
    """Mixture of two Gaussians, each with a full-rank covariance matrix

    Each component's precision matrix is parameterized by the log-Cholesky
    decomposition of its lower-triangular factor. See
    `BaseGaussianBNNPosterior.__init__` docstring for the parameter description.

    """
    def __init__(self, Y_dim, device, Y_mean=None, Y_std=None):
        super(DoubleGaussianBNNPosterior, self).__init__(Y_dim, device, Y_mean, Y_std)
        self.tril_idx = torch.tril_indices(self.Y_dim, self.Y_dim, offset=0, device=device) # lower-triangular indices
        self.tril_len = len(self.tril_idx[0])
        # Two components of [mu (Y_dim) | tril (Y_dim*(Y_dim + 1)/2)] plus one
        # weight logit: 2*(Y_dim + Y_dim*(Y_dim + 1)/2) + 1 = Y_dim**2 + 3*Y_dim + 1
        self.out_dim = self.Y_dim**2 + 3*self.Y_dim + 1

    def set_sliced_pred(self, pred):
        """Slice the network prediction into the parameters of both components

        Parameters
        ----------
        pred : torch.Tensor of shape `[batch_size, self.out_dim]`

        """
        d = self.Y_dim # for readability
        self.batch_size = pred.shape[0]
        # First gaussian
        self.mu = pred[:, :d]
        self.tril_elements = pred[:, d:d+self.tril_len]
        # Second gaussian
        self.mu2 = pred[:, d+self.tril_len:2*d+self.tril_len]
        self.tril_elements2 = pred[:, 2*d+self.tril_len:-1]
        # Weight of the second component, constrained to (0, 0.5)
        self.w2 = 0.5*self.sigmoid(pred[:, -1].reshape(-1, 1))

    def sample(self, n_samples, sample_seed=None):
        """Sample from a mixture of two Gaussians, each with a full-rank covariance

        Parameters
        ----------
        n_samples : int
            how many samples to obtain
        sample_seed : int
            seed for the samples. Default: None (random state left untouched)

        Returns
        -------
        np.array of shape `[self.batch_size, n_samples, self.Y_dim]`
            samples

        """
        if sample_seed is not None:
            self.seed_samples(sample_seed)
        # Draw the component assignment first, then the second and first
        # components, keeping the random-number order of the original code.
        unif2 = torch.rand(self.batch_size, n_samples)
        second_gaussian = (self.w2.detach().cpu() > unif2) # [batch_size, n_samples]
        samples2 = self.sample_full_rank(n_samples, self.mu2, self.tril_elements2, as_numpy=False)
        samples1 = self.sample_full_rank(n_samples, self.mu, self.tril_elements, as_numpy=False)
        samples = torch.where(second_gaussian.unsqueeze(-1), samples2.cpu(), samples1.cpu())
        return samples.detach().numpy()

    def get_hpd_interval(self):
        """Not implemented for this posterior."""
        raise NotImplementedError