Source code for h0rton.losses.gaussian_nll_native

from abc import ABC, abstractmethod
import numpy as np
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.lowrank_multivariate_normal import LowRankMultivariateNormal
__all__ = ['BaseGaussianNLLNative', 'DiagonalGaussianNLLNative', 'LowRankGaussianNLLNative', 'DoubleLowRankGaussianNLLNative', 'FullRankGaussianNLLNative', 'DoubleGaussianNLLNative']

# Natural log of 2*pi. Not referenced by any code visible in this module
# (`nll_diagonal` recomputes the same quantity with `np.log(2.0*np.pi)`).
log_2_pi = 1.8378770664093453
# Natural log of 2. Used by the two-Gaussian mixture NLLs, where the weight on
# the second Gaussian is half the sigmoid of the predicted logit.
log_2 = 0.6931471805599453

class BaseGaussianNLLNative(ABC):
    """Abstract base class to represent the Gaussian negative log likelihood (NLL).

    Gaussian NLLs or mixtures thereof with various forms of the covariance
    matrix inherit from this class. Subclasses implement `slice` and
    `__call__`; subclasses that use `nll_full_rank` or `nll_mixture` must also
    set `self.tril_idx` (the lower-triangular indices of a
    `[Y_dim, Y_dim]` matrix).

    """
    def __init__(self, Y_dim, device):
        """
        Parameters
        ----------
        Y_dim : int
            number of parameters to predict
        device : torch.device object

        """
        self.Y_dim = Y_dim
        self.device = device
        self.sigmoid = torch.nn.Sigmoid()
        self.logsigmoid = torch.nn.LogSigmoid()

    @abstractmethod
    def slice(self, pred):
        """Slice the raw network prediction into meaningful Gaussian parameters

        Parameters
        ----------
        pred : torch.Tensor
            the raw network prediction, sliced along dim 1

        """
        return NotImplemented

    @abstractmethod
    def __call__(self, pred, target):
        """Evaluate the NLL. Must be overridden by subclasses.

        Parameters
        ----------
        pred : torch.Tensor
            raw network output for the predictions
        target : torch.Tensor
            Y labels

        """
        return NotImplemented

    def nll_diagonal(self, target, mu, logvar):
        """Evaluate the NLL for a single Gaussian with diagonal covariance matrix

        Parameters
        ----------
        target : torch.Tensor of shape [batch_size, Y_dim]
            Y labels
        mu : torch.Tensor of shape [batch_size, Y_dim]
            network prediction of the mu (mean parameter) of the BNN posterior
        logvar : torch.Tensor of shape [batch_size, Y_dim]
            network prediction of the log of the diagonal elements of the
            covariance matrix

        Returns
        -------
        torch.Tensor (scalar)
            NLL summed over dim 1 and averaged over the batch

        """
        precision = torch.exp(-logvar)
        # Loss kernel
        loss = precision * (target - mu)**2.0 + logvar
        # Restore prefactors
        loss += np.log(2.0*np.pi)
        loss *= 0.5
        return torch.mean(torch.sum(loss, dim=1), dim=0)

    def nll_low_rank(self, target, mu, logvar, F, reduce=True):
        """Evaluate the NLL for a single Gaussian with a low-rank plus diagonal
        covariance matrix

        Parameters
        ----------
        target : torch.Tensor of shape [batch_size, Y_dim]
            Y labels
        mu : torch.Tensor of shape [batch_size, Y_dim]
            network prediction of the mu (mean parameter) of the BNN posterior
        logvar : torch.Tensor of shape [batch_size, Y_dim]
            network prediction of the log of the diagonal elements of the
            covariance matrix
        F : torch.Tensor of shape [batch_size, rank*Y_dim]
            network prediction of the low rank portion of the covariance
            matrix. The rank is inferred from the width of `F`. A tensor
            already shaped `[batch_size, Y_dim, rank]` is accepted as well.
        reduce : bool
            whether to take the mean across the batch

        Returns
        -------
        torch.Tensor
            scalar mean NLL if `reduce`, else NLL values of shape [batch_size,]

        """
        batch_size, _ = target.shape
        # Infer the rank instead of hardcoding it. The rank is computed
        # explicitly (rather than reshaping with -1) so that an empty batch
        # still reshapes unambiguously.
        if F.dim() == 2:
            rank = F.shape[1] // self.Y_dim
        else:
            rank = F.shape[-1]
        F = F.reshape([batch_size, self.Y_dim, rank])
        lr_mvn = LowRankMultivariateNormal(loc=mu,
                                           cov_factor=F,
                                           cov_diag=torch.exp(logvar))
        loss = -lr_mvn.log_prob(target)
        if reduce:
            return torch.mean(loss, dim=0)  # scalar
        return loss  # [batch_size,]

    def _mixture_nll(self, alpha, nll1, nll2):
        """Combine two per-example component NLLs into the mean mixture NLL

        Parameters
        ----------
        alpha : torch.Tensor with batch_size elements
            logit of twice the weight on the second Gaussian
        nll1 : torch.Tensor of shape [batch_size,]
            NLL of the first Gaussian
        nll2 : torch.Tensor of shape [batch_size,]
            NLL of the second Gaussian

        Returns
        -------
        torch.Tensor (scalar)
            mixture NLL averaged over the batch

        """
        alpha = alpha.reshape(-1)
        # w2 = 0.5*sigmoid(alpha) and w1 = 1 - w2, so w1 lies in [0.5, 1].
        # Evaluating log(w1) with log1p on the sigmoid avoids exp(-alpha),
        # which overflows to inf (and yields NaN) for strongly negative alpha.
        log_w1 = torch.log1p(-0.5*self.sigmoid(alpha))
        log_w2 = self.logsigmoid(alpha) - log_2
        stacked = torch.stack([log_w1 - nll1, log_w2 - nll2], dim=1)
        return torch.mean(-torch.logsumexp(stacked, dim=1))

    def nll_mixture_low_rank(self, target, mu, logvar, F, mu2, logvar2, F2, alpha):
        """Evaluate the NLL for a mixture of two Gaussians, each with a
        low-rank plus diagonal covariance matrix

        Parameters
        ----------
        target : torch.Tensor of shape [batch_size, Y_dim]
            Y labels
        mu : torch.Tensor of shape [batch_size, Y_dim]
            network prediction of the mu (mean parameter) of the BNN posterior
            for the first Gaussian
        logvar : torch.Tensor of shape [batch_size, Y_dim]
            network prediction of the log of the diagonal elements of the
            covariance matrix for the first Gaussian
        F : torch.Tensor of shape [batch_size, rank*Y_dim]
            network prediction of the low rank portion of the covariance
            matrix for the first Gaussian
        mu2 : torch.Tensor of shape [batch_size, Y_dim]
            network prediction of the mu (mean parameter) of the BNN posterior
            for the second Gaussian
        logvar2 : torch.Tensor of shape [batch_size, Y_dim]
            network prediction of the log of the diagonal elements of the
            covariance matrix for the second Gaussian
        F2 : torch.Tensor of shape [batch_size, rank*Y_dim]
            network prediction of the low rank portion of the covariance
            matrix for the second Gaussian
        alpha : torch.Tensor of shape [batch_size, 1]
            network prediction of the logit of twice the weight on the second
            Gaussian

        Note
        ----
        The weight on the second Gaussian is required to be less than 0.5, to
        make the two Gaussians well-defined.

        Returns
        -------
        torch.Tensor (scalar)
            mixture NLL averaged over the batch

        """
        nll1 = self.nll_low_rank(target, mu, logvar, F=F, reduce=False)  # [batch_size,]
        nll2 = self.nll_low_rank(target, mu2, logvar2, F=F2, reduce=False)  # [batch_size,]
        return self._mixture_nll(alpha, nll1, nll2)

    def nll_full_rank(self, target, mu, tril_elements, reduce=True):
        """Evaluate the NLL for a single Gaussian with a full-rank covariance matrix

        Parameters
        ----------
        target : torch.Tensor of shape [batch_size, Y_dim]
            Y labels
        mu : torch.Tensor of shape [batch_size, Y_dim]
            network prediction of the mu (mean parameter) of the BNN posterior
        tril_elements : torch.Tensor of shape [batch_size, Y_dim*(Y_dim + 1)//2]
            network prediction of the lower-triangular factor of the precision
            matrix, ordered as `self.tril_idx`; the diagonal entries are
            exponentiated before use
        reduce : bool
            whether to take the mean across the batch

        Returns
        -------
        torch.Tensor
            scalar mean NLL if `reduce`, else NLL values of shape [batch_size,]

        """
        batch_size, _ = target.shape
        # Allocate in the dtype of the prediction so that non-default dtypes
        # (e.g. float64) are not mismatched against a float32 buffer.
        tril = torch.zeros([batch_size, self.Y_dim, self.Y_dim],
                           device=self.device,
                           dtype=tril_elements.dtype)
        tril[:, self.tril_idx[0], self.tril_idx[1]] = tril_elements
        log_diag_tril = torch.diagonal(tril, offset=0, dim1=1, dim2=2)  # [batch_size, Y_dim]
        # Exponentiate the diagonal so the factor has a strictly positive diagonal
        diag_mask = torch.eye(self.Y_dim, dtype=torch.bool, device=self.device)
        tril[:, diag_mask] = torch.exp(log_diag_tril)
        prec_mat = torch.bmm(tril, torch.transpose(tril, 1, 2))
        mvn = MultivariateNormal(loc=mu, precision_matrix=prec_mat)
        loss = -mvn.log_prob(target)
        if reduce:
            return torch.mean(loss, dim=0)  # scalar
        return loss  # [batch_size,]

    def nll_mixture(self, target, mu, tril_elements, mu2, tril_elements2, alpha):
        """Evaluate the NLL for a mixture of two Gaussians, each with a
        full-rank covariance matrix

        Parameters
        ----------
        target : torch.Tensor of shape [batch_size, Y_dim]
            Y labels
        mu : torch.Tensor of shape [batch_size, Y_dim]
            network prediction of the mu (mean parameter) of the BNN posterior
            for the first Gaussian
        tril_elements : torch.Tensor of shape [batch_size, self.tril_len]
            network prediction of the elements in the precision matrix for the
            first Gaussian
        mu2 : torch.Tensor of shape [batch_size, Y_dim]
            network prediction of the mu (mean parameter) of the BNN posterior
            for the second Gaussian
        tril_elements2 : torch.Tensor of shape [batch_size, self.tril_len]
            network prediction of the elements in the precision matrix for the
            second Gaussian
        alpha : torch.Tensor of shape [batch_size, 1]
            network prediction of the logit of twice the weight on the second
            Gaussian

        Note
        ----
        The weight on the second Gaussian is required to be less than 0.5, to
        make the two Gaussians well-defined.

        Returns
        -------
        torch.Tensor (scalar)
            mixture NLL averaged over the batch

        """
        nll1 = self.nll_full_rank(target, mu, tril_elements, reduce=False)  # [batch_size,]
        nll2 = self.nll_full_rank(target, mu2, tril_elements2, reduce=False)  # [batch_size,]
        return self._mixture_nll(alpha, nll1, nll2)
class DiagonalGaussianNLLNative(BaseGaussianNLLNative):
    """NLL of a single Gaussian whose covariance matrix is diagonal.

    See `BaseGaussianNLLNative.__init__` docstring for the parameter
    description.

    """
    posterior_name = 'DiagonalGaussianBNNPosterior'

    def __init__(self, Y_dim, device):
        super().__init__(Y_dim, device)
        # One mean and one log-variance per target dimension
        self.out_dim = 2*Y_dim

    def __call__(self, pred, target):
        """Slice `pred` into (mu, logvar) and evaluate the diagonal NLL."""
        gaussian_params = self.slice(pred)
        return self.nll_diagonal(target, *gaussian_params)

    def slice(self, pred):
        """Split `pred` along dim 1 into two chunks of width `Y_dim`."""
        width = self.Y_dim
        chunk_sizes = [width, width]
        return torch.split(pred, chunk_sizes, dim=1)
class LowRankGaussianNLLNative(BaseGaussianNLLNative):
    """NLL of a single Gaussian whose covariance matrix is constrained to be
    low-rank plus diagonal.

    The output layout here assumes rank 2. See
    `BaseGaussianNLLNative.__init__` docstring for the parameter description.

    """
    posterior_name = 'LowRankGaussianBNNPosterior'

    def __init__(self, Y_dim, device):
        super().__init__(Y_dim, device)
        # Y_dim for mu + Y_dim for logvar + 2*Y_dim for the rank-2 factor
        self.out_dim = 4*Y_dim

    def __call__(self, pred, target):
        """Slice `pred` into (mu, logvar, F) and evaluate the batch-mean NLL."""
        gaussian_params = self.slice(pred)
        return self.nll_low_rank(target, *gaussian_params, reduce=True)

    def slice(self, pred):
        """Split `pred` along dim 1 into chunks of width Y_dim, Y_dim, 2*Y_dim."""
        width = self.Y_dim
        chunk_sizes = [width, width, 2*width]
        return torch.split(pred, chunk_sizes, dim=1)
class FullRankGaussianNLLNative(BaseGaussianNLLNative):
    """NLL of a single Gaussian with a full-rank covariance matrix.

    See `BaseGaussianNLLNative.__init__` docstring for the parameter
    description.

    """
    posterior_name = 'FullRankGaussianBNNPosterior'

    def __init__(self, Y_dim, device):
        super().__init__(Y_dim, device)
        dim = self.Y_dim
        # Row and column indices of the lower triangle, diagonal included
        self.tril_idx = torch.tril_indices(dim, dim, offset=0, device=device)
        self.tril_len = len(self.tril_idx[0])
        # Y_dim for mu plus the lower-triangular entries
        self.out_dim = dim + dim*(dim + 1)//2

    def __call__(self, pred, target):
        """Slice `pred` into (mu, tril_elements) and evaluate the batch-mean NLL."""
        gaussian_params = self.slice(pred)
        return self.nll_full_rank(target, *gaussian_params, reduce=True)

    def slice(self, pred):
        """Split `pred` along dim 1 into chunks of width Y_dim and tril_len."""
        chunk_sizes = [self.Y_dim, self.tril_len]
        return torch.split(pred, chunk_sizes, dim=1)
class DoubleLowRankGaussianNLLNative(BaseGaussianNLLNative):
    """NLL of a mixture of two Gaussians, each with a covariance matrix
    constrained to be low-rank plus diagonal.

    The output layout here assumes rank 2. See
    `BaseGaussianNLLNative.__init__` docstring for the parameter description.

    """
    posterior_name = 'DoubleLowRankGaussianBNNPosterior'

    def __init__(self, Y_dim, device):
        super().__init__(Y_dim, device)
        # Two Gaussians at 4*Y_dim each, plus one mixture-weight logit
        self.out_dim = 8*Y_dim + 1

    def __call__(self, pred, target):
        """Slice `pred` into the mixture parameters and evaluate the NLL."""
        mixture_params = self.slice(pred)
        return self.nll_mixture_low_rank(target, *mixture_params)

    def slice(self, pred):
        """Split `pred` along dim 1 into mu, logvar, F, mu2, logvar2, F2, alpha."""
        width = self.Y_dim
        per_gaussian = [width, width, 2*width]
        chunk_sizes = per_gaussian + per_gaussian + [1]
        return torch.split(pred, chunk_sizes, dim=1)
class DoubleGaussianNLLNative(BaseGaussianNLLNative):
    """NLL of a mixture of two Gaussians, each with a full-rank covariance
    matrix.

    See `BaseGaussianNLLNative.__init__` docstring for the parameter
    description.

    """
    posterior_name = 'DoubleGaussianBNNPosterior'

    def __init__(self, Y_dim, device):
        super().__init__(Y_dim, device)
        dim = self.Y_dim
        # Row and column indices of the lower triangle, diagonal included
        self.tril_idx = torch.tril_indices(dim, dim, offset=0, device=device)
        self.tril_len = len(self.tril_idx[0])
        # Two (mu, lower-triangle) sets plus one mixture-weight logit
        self.out_dim = dim**2 + 3*dim + 1

    def __call__(self, pred, target):
        """Slice `pred` into the mixture parameters and evaluate the NLL."""
        mixture_params = self.slice(pred)
        return self.nll_mixture(target, *mixture_params)

    def slice(self, pred):
        """Split `pred` along dim 1 into mu, tril_elements, mu2, tril_elements2, alpha."""
        per_gaussian = [self.Y_dim, self.tril_len]
        chunk_sizes = per_gaussian + per_gaussian + [1]
        return torch.split(pred, chunk_sizes, dim=1)