h0rton.train_utils.logging_utils

Module Contents

Functions

get_logdet(tril_elements, Y_dim) Returns the log determinant of the covariance matrix
get_1d_mapping_fig(name, mu, Y) Plots the marginal 1D mapping of the mean predictions
get_mae(pred_mu, true_mu, Y_cols) Get the total RMSE of the predicted mu of the primary Gaussian with respect to the transformed label mu in a batch of validation data
interpret_pred(pred, Y_dim) Slice the network prediction into means and cov matrix elements
h0rton.train_utils.logging_utils.get_logdet(tril_elements, Y_dim)

Returns the log determinant of the covariance matrix
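The docstring does not say how `tril_elements` is parameterized. As a minimal sketch, assume it holds the lower-triangular entries, in row-major order, of a Cholesky factor L of the covariance matrix, so that Sigma = L L^T and log det Sigma = 2 * sum_i log|L_ii|. The helper name `logdet_from_cholesky_tril` is hypothetical. h0rton's own parameterization (for example, a factor of the precision matrix, or diagonal entries stored in log space) may differ, which would flip the sign or remove the log.

```python
import numpy as np


def logdet_from_cholesky_tril(tril_elements, Y_dim):
    """Hypothetical sketch: log det(L @ L.T) from row-major lower-triangular entries of L.

    tril_elements : array of shape [batch_size, Y_dim*(Y_dim+1)//2]
    Returns an array of shape [batch_size,].
    """
    tril_elements = np.atleast_2d(np.asarray(tril_elements, dtype=float))
    batch_size = tril_elements.shape[0]
    # np.tril_indices enumerates the lower triangle in row-major order:
    # (0,0), (1,0), (1,1), (2,0), ...
    rows, cols = np.tril_indices(Y_dim)
    L = np.zeros((batch_size, Y_dim, Y_dim))
    L[:, rows, cols] = tril_elements  # scatter each row of entries into its matrix
    diag = np.diagonal(L, axis1=1, axis2=2)  # [batch_size, Y_dim]
    # det(L L^T) = prod(L_ii)**2, hence log det = 2 * sum(log|L_ii|)
    return 2.0 * np.sum(np.log(np.abs(diag)), axis=1)
```

Working from the triangular factor avoids forming and factorizing the full Y_dim x Y_dim matrix just to read off its determinant.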

h0rton.train_utils.logging_utils.get_1d_mapping_fig(name, mu, Y)

Plots the marginal 1D mapping of the mean predictions

Parameters
name : str
    name of the parameter
mu : np.array of shape [batch_size,]
    network prediction of the Gaussian mean
Y : np.array of shape [batch_size,]
    truth label
Returns
matplotlib.FigureCanvas object
    plot of network predictions against truth
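A minimal sketch of such a plot, assuming a scatter of predicted means against truth with a y = x reference line, rendered on matplotlib's Agg canvas without pyplot. The function name `mapping_fig_sketch`, the styling, and the axis labels are illustrative assumptions, not h0rton's implementation.

```python
import numpy as np
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg


def mapping_fig_sketch(name, mu, Y):
    """Hypothetical sketch: scatter predicted means against truth labels for one parameter."""
    mu = np.asarray(mu, dtype=float).ravel()  # [batch_size,]
    Y = np.asarray(Y, dtype=float).ravel()    # [batch_size,]
    fig = Figure(figsize=(4, 4))
    canvas = FigureCanvasAgg(fig)  # attach an Agg canvas directly, no pyplot state involved
    ax = fig.add_subplot(1, 1, 1)
    ax.scatter(Y, mu, s=8, alpha=0.6, label='prediction')
    # y = x reference line: a perfect predictor places every point on it
    lo = min(Y.min(), mu.min())
    hi = max(Y.max(), mu.max())
    ax.plot([lo, hi], [lo, hi], 'k--', linewidth=1, label='y = x')
    ax.set_xlabel('truth')
    ax.set_ylabel('prediction')
    ax.set_title(name)
    ax.legend(loc='upper left')
    canvas.draw()  # render so the pixel buffer is populated
    return canvas
```

After the call, `np.asarray(canvas.buffer_rgba())` yields an RGBA image array of shape (height, width, 4) that a training logger can store.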
h0rton.train_utils.logging_utils.get_mae(pred_mu, true_mu, Y_cols)

Get the total RMSE of the predicted mu of the primary Gaussian with respect to the transformed label mu in a batch of validation data

Parameters
pred_mu : np.array of shape [batch_size, Y_dim]
    predicted means of the primary Gaussian
true_mu : np.array of shape [batch_size, Y_dim]
    true (label) Gaussian means
Y_cols : np.array of shape [Y_dim,]
    the column names
Returns
dict
    total mean of the RMSE for that batch
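A minimal sketch of a batch error summary in the spirit of this docstring, which describes an RMSE even though the function is named `get_mae`. It assumes the returned dict maps each entry of `Y_cols` to that parameter's RMSE over the batch, plus a hypothetical `'total'` key holding the mean across parameters. The name `batch_rmse_dict` and this dict layout are assumptions, not h0rton's code.

```python
import numpy as np


def batch_rmse_dict(pred_mu, true_mu, Y_cols):
    """Hypothetical sketch: per-parameter RMSE over a batch, plus their mean."""
    pred_mu = np.asarray(pred_mu, dtype=float)  # [batch_size, Y_dim]
    true_mu = np.asarray(true_mu, dtype=float)  # [batch_size, Y_dim]
    if len(Y_cols) != pred_mu.shape[1]:
        raise ValueError("Y_cols must have one name per column of pred_mu")
    # RMSE of each parameter over the batch axis -> shape [Y_dim,]
    rmse = np.sqrt(np.mean((pred_mu - true_mu) ** 2, axis=0))
    out = {str(col): float(val) for col, val in zip(Y_cols, rmse)}
    out['total'] = float(np.mean(rmse))  # hypothetical key for the mean over parameters
    return out
```

Replacing the squared error with `np.abs(pred_mu - true_mu)` and dropping the square root would give a mean absolute error instead.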
h0rton.train_utils.logging_utils.interpret_pred(pred, Y_dim)

Slice the network prediction into means and cov matrix elements

Parameters
pred : np.array of shape [batch_size, out_dim]
    network prediction
Y_dim : int
    number of parameters to predict

Hardcoded for DoubleGaussianNLL. This function is no longer used: its slicing logic has been replaced by the BNNPosterior class.

Returns
dict
    pred sliced into parameters of the Gaussians to predict
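For illustration, a sketch of the slicing, assuming a DoubleGaussianNLL output layout of [mu, tril_elements, mu2, tril_elements2, w2] along the last axis. Each tril block is assumed to hold Y_dim*(Y_dim+1)//2 entries and w2 a single weight term, so out_dim = 2*(Y_dim + Y_dim*(Y_dim+1)//2) + 1. The layout, the key names, and `slice_double_gaussian_pred` are all assumptions; as noted above, current code delegates this step to the BNNPosterior class.

```python
import numpy as np


def slice_double_gaussian_pred(pred, Y_dim):
    """Hypothetical sketch: split [batch_size, out_dim] predictions into named blocks."""
    pred = np.atleast_2d(np.asarray(pred))
    tril_len = Y_dim * (Y_dim + 1) // 2
    expected = 2 * (Y_dim + tril_len) + 1
    if pred.shape[1] != expected:
        raise ValueError(f"expected out_dim={expected}, got {pred.shape[1]}")
    # Block widths in the assumed order; cumulative sums give the slice boundaries
    keys = ['mu', 'tril_elements', 'mu2', 'tril_elements2', 'w2']
    offsets = np.cumsum([0, Y_dim, tril_len, Y_dim, tril_len, 1])
    return {k: pred[:, offsets[i]:offsets[i + 1]] for i, k in enumerate(keys)}
```

Checking `out_dim` up front turns a silent mis-slice into an explicit error when the network head and the assumed layout disagree.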