Source code for h0rton.train_utils.logging_utils

import numpy as np
from matplotlib.figure import Figure
__all__ = ['get_1d_mapping_fig', 'get_mae', 'interpret_pred', 'get_logdet']

[docs]def get_logdet(tril_elements, Y_dim): """Returns the log determinant of the covariance matrix """ batch_size = tril_elements.shape[0] tril = np.zeros([batch_size, Y_dim, Y_dim]) tril_idx = np.tril_indices(Y_dim) tril_len = len(tril_idx[0]) tril[:, tril_idx[0], tril_idx[1]] = tril_elements[:, :tril_len] # safeguarding measure log_diag_tril = np.diagonal(tril, offset=0, axis1=1, axis2=2) # [batch_size, Y_dim] return -np.sum(log_diag_tril, axis=1) # [batch_size,]
[docs]def get_1d_mapping_fig(name, mu, Y): """Plots the marginal 1D mapping of the mean predictions Parameters ---------- name : str name of the parameter mu : np.array of shape [batch_size,] network prediction of the Gaussian mean Y : np.array of shape [batch_size,] truth label which_normal_i : int which Gaussian (0 for first, 1 for second) Returns ------- matplotlib.FigureCanvas object plot of network predictions against truth """ my_dpi = 72.0 fig = Figure(figsize=(720/my_dpi, 360/my_dpi), dpi=my_dpi, tight_layout=True) ax = fig.gca() # Reference (perfect) mapping perfect = np.linspace(np.min(Y), np.max(Y), 20) ax.plot(perfect, np.zeros_like(perfect), linestyle='--', color='b', label="Perfect mapping") # For this param offset = mu - Y ax.scatter(Y, offset, color='tab:blue', marker='o') ax.set_title('offset vs. truth ({:s})'.format(name)) ax.set_ylabel('pred - truth') ax.set_xlabel('truth') return fig
def _sigmoid(self, x): return 1.0/(np.exp(-x) + 1.0)
[docs]def get_mae(pred_mu, true_mu, Y_cols): """Get the total RMSE of predicted mu of the primary Gaussian wrt the transformed labels mu in a batch of validation data Parameters ---------- pred_mu : np.array of shape `[batch_size, Y_dim]` predicted means of the primary Gaussian true_mu : np.array of shape `[batch_size, Y_dim]` true (label) Gaussian means Y_cols : np.array of shape `[Y_dim,]` the column names Returns ------- dict total mean of the RMSE for that batch """ abs_error = np.abs(pred_mu - true_mu) # [batch_size, Y_dim] mae_dict = {} # Total sq error, averaged across examples mae_dict['mae'] = np.median(np.sum(abs_error, axis=1)) # float # Parameter-wise sq error, averaged across examples mae_param = np.median(abs_error, axis=0) # [Y_dim,] param_dict = dict(zip(Y_cols, range(len(Y_cols)))) for k, i in param_dict.items(): mae_dict[k] = mae_param[i] return mae_dict
[docs]def interpret_pred(pred, Y_dim): """Slice the network prediction into means and cov matrix elements Parameters ---------- pred : np.array of shape `[batch_size, out_dim]` Y_dim : int number of parameters to predict Note ---- Currently hardcoded for `DoubleGaussianNLL`. (Update: no longer used; slicing function replaced by the BNNPosterior class.) Returns ------- dict pred sliced into parameters of the Gaussians to predict """ d = Y_dim # for readability mu = pred[:, :d] logvar = pred[:, d:2*d] F = pred[:, 2*d:4*d] mu2 = pred[:, 4*d:5*d] logvar2 = pred[:, 5*d:6*d] F2 = pred[:, 6*d:8*d] alpha = pred[:, -1] w2 = 0.5/(np.exp(-alpha) + 1.0) pred_dict = dict( mu=mu, logvar=logvar, F=F, mu2=mu2, logvar2=logvar2, F2=F2, w2=w2, ) return pred_dict