h0rton.train_utils

Package Contents

Functions

save_state_dict(model, optimizer, lr_scheduler, train_loss, val_loss, checkpoint_dir, model_architecture, epoch_idx) Save the state dict of the current training to disk
load_state_dict(checkpoint_path, model, optimizer, n_epochs, device, lr_scheduler=None) Load the state dict of the past training
load_state_dict_test(checkpoint_path, model, n_epochs, device) Load the state dict of the past training
get_1d_mapping_fig(name, mu, Y) Plots the marginal 1D mapping of the mean predictions
get_mae(pred_mu, true_mu, Y_cols) Get the total RMSE of predicted mu of the primary Gaussian wrt the transformed labels mu in a batch of validation data
interpret_pred(pred, Y_dim) Slice the network prediction into means and cov matrix elements
get_logdet(tril_elements, Y_dim) Returns the log determinant of the covariance matrix
h0rton.train_utils.save_state_dict(model, optimizer, lr_scheduler, train_loss, val_loss, checkpoint_dir, model_architecture, epoch_idx)[source]

Save the state dict of the current training to disk

model : torch model
trained model to save

optimizer : torch.optim object
lr_scheduler : torch.optim.lr_scheduler object
checkpoint_dir : str or os.path object
directory into which to save the model
model_architecture : str
type of architecture
epoch_idx : int
epoch index
str or os.path object
path to the saved model
h0rton.train_utils.load_state_dict(checkpoint_path, model, optimizer, n_epochs, device, lr_scheduler=None)[source]

Load the state dict of the past training

checkpoint_path : str or os.path object
path of the state dict to load
model : torch model
model into which to load the saved state
optimizer : torch.optim object
lr_scheduler : torch.optim.lr_scheduler object
n_epochs : int
total number of epochs to train
device : torch.device object
device on which to load the model
str or os.path object
path to the saved model
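
The save/load pair above can be illustrated with a minimal round-trip sketch in plain PyTorch. This is not h0rton's implementation: the helper names `save_checkpoint` and `load_checkpoint`, the dictionary keys, and the file naming scheme are all assumptions made for illustration, and the `n_epochs` argument is omitted because its use is not described here.

```python
import os
import tempfile

import torch


def save_checkpoint(model, optimizer, lr_scheduler, train_loss, val_loss,
                    checkpoint_dir, model_architecture, epoch_idx):
    """Hypothetical stand-in for save_state_dict; returns the checkpoint path."""
    state = {
        'epoch': epoch_idx,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }
    # Assumed naming scheme; the real package may name files differently.
    file_name = f'{model_architecture}_epoch={epoch_idx}.pt'
    path = os.path.join(checkpoint_dir, file_name)
    torch.save(state, path)
    return path


def load_checkpoint(checkpoint_path, model, optimizer, device, lr_scheduler=None):
    """Hypothetical stand-in for load_state_dict; restores states in place."""
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state['model'])
    model.to(device)
    optimizer.load_state_dict(state['optimizer'])
    if lr_scheduler is not None:
        lr_scheduler.load_state_dict(state['lr_scheduler'])
    return state['epoch'], state['train_loss'], state['val_loss']


# Round trip: save a small model, then restore it into a fresh instance.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = save_checkpoint(model, optimizer, scheduler, 0.5, 0.6,
                                tmp_dir, 'linear', 3)
    restored = torch.nn.Linear(4, 2)
    restored_opt = torch.optim.SGD(restored.parameters(), lr=0.1)
    epoch, train_loss, val_loss = load_checkpoint(
        ckpt_path, restored, restored_opt, torch.device('cpu'))
```

Storing the optimizer and scheduler state next to the model weights is what allows training to resume from the saved epoch rather than restart with a fresh learning-rate schedule.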
h0rton.train_utils.load_state_dict_test(checkpoint_path, model, n_epochs, device)[source]

Load the state dict of the past training

checkpoint_path : str or os.path object
path of the state dict to load
model : torch model
model into which to load the saved state
n_epochs : int
total number of epochs to train
device : torch.device object
device on which to load the model
str or os.path object
path to the saved model
h0rton.train_utils.get_1d_mapping_fig(name, mu, Y)[source]

Plots the marginal 1D mapping of the mean predictions

name : str
name of the parameter
mu : np.array of shape [batch_size,]
network prediction of the Gaussian mean
Y : np.array of shape [batch_size,]
truth label
matplotlib.FigureCanvas object
plot of network predictions against truth
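
A plot of this kind can be sketched as a scatter of predicted means against truth labels with a one-to-one reference line. The helper name `plot_1d_mapping`, the axis labels, and the styling are assumptions; this sketch also returns a matplotlib `Figure` rather than a canvas object.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_1d_mapping(name, mu, Y):
    """Hypothetical sketch: scatter predicted means against truth labels."""
    fig, ax = plt.subplots()
    ax.scatter(Y, mu, s=8, alpha=0.6)
    # Dashed one-to-one line: a perfect predictor would fall exactly on it.
    lo = min(Y.min(), mu.min())
    hi = max(Y.max(), mu.max())
    ax.plot([lo, hi], [lo, hi], linestyle='--', color='k')
    ax.set_xlabel(f'True {name}')
    ax.set_ylabel(f'Predicted {name}')
    return fig


# Synthetic data: predictions equal the truth plus a little noise.
rng = np.random.default_rng(0)
Y = rng.normal(size=100)
mu = Y + 0.1 * rng.normal(size=100)
fig = plot_1d_mapping('gamma', mu, Y)
```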
h0rton.train_utils.get_mae(pred_mu, true_mu, Y_cols)[source]

Get the total RMSE of predicted mu of the primary Gaussian wrt the transformed labels mu in a batch of validation data

pred_mu : np.array of shape [batch_size, Y_dim]
predicted means of the primary Gaussian
true_mu : np.array of shape [batch_size, Y_dim]
true (label) Gaussian means
Y_cols : np.array of shape [Y_dim,]
the column names
dict
total mean of the RMSE for that batch
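
The function name refers to MAE while its description mentions RMSE, so the exact metric is not clear from this page. The sketch below assumes a per-column mean absolute error returned as a dictionary keyed by column name; the helper name `mae_per_column` and the return layout are assumptions.

```python
import numpy as np


def mae_per_column(pred_mu, true_mu, Y_cols):
    """Hypothetical sketch: mean absolute error for each predicted parameter."""
    abs_err = np.abs(pred_mu - true_mu)  # shape [batch_size, Y_dim]
    per_col = abs_err.mean(axis=0)       # average over the batch, shape [Y_dim,]
    return {str(col): float(err) for col, err in zip(Y_cols, per_col)}


pred_mu = np.array([[1.0, 2.0], [3.0, 4.0]])
true_mu = np.array([[1.5, 2.0], [2.0, 5.0]])
mae = mae_per_column(pred_mu, true_mu, np.array(['a', 'b']))
# mae == {'a': 0.75, 'b': 0.5}
```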
h0rton.train_utils.interpret_pred(pred, Y_dim)[source]

Slice the network prediction into means and cov matrix elements

pred : np.array of shape [batch_size, out_dim]
Y_dim : int
number of parameters to predict

Currently hardcoded for DoubleGaussianNLL. (Update: no longer used; slicing function replaced by the BNNPosterior class.)

dict
pred sliced into parameters of the Gaussians to predict
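
As an illustration of the slicing, the sketch below assumes a two-component Gaussian mixture output laid out as [mu, tril_elements, mu2, tril_elements2, alpha], where each lower-triangular block has Y_dim * (Y_dim + 1) / 2 entries. The ordering, the dictionary keys, and the helper name `slice_double_gaussian` are assumptions, not the package's confirmed layout.

```python
import numpy as np


def slice_double_gaussian(pred, Y_dim):
    """Hypothetical sketch: split a flat network output into Gaussian parameters."""
    tril_len = Y_dim * (Y_dim + 1) // 2  # entries in one lower triangle
    d = Y_dim + tril_len                 # size of one Gaussian's block
    expected = 2 * d + 1                 # two Gaussians plus one mixture weight
    if pred.shape[1] != expected:
        raise ValueError(f'expected out_dim={expected}, got {pred.shape[1]}')
    return dict(
        mu=pred[:, :Y_dim],
        tril_elements=pred[:, Y_dim:d],
        mu2=pred[:, d:d + Y_dim],
        tril_elements2=pred[:, d + Y_dim:2 * d],
        alpha=pred[:, -1],
    )


# Y_dim = 2 gives d = 2 + 3 = 5 and out_dim = 11.
pred = np.arange(22, dtype=float).reshape(2, 11)
params = slice_double_gaussian(pred, Y_dim=2)
```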
h0rton.train_utils.get_logdet(tril_elements, Y_dim)[source]

Returns the log determinant of the covariance matrix
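
One way to compute such a log determinant, assuming `tril_elements` holds the row-major lower triangle of a Cholesky factor L with covariance Σ = L Lᵀ, is log det Σ = 2 Σᵢ log Lᵢᵢ. The package may parameterize the matrix differently (for example through the precision matrix or a log-transformed diagonal), so both the parameterization and the helper name `logdet_from_tril` are assumptions.

```python
import numpy as np


def logdet_from_tril(tril_elements, Y_dim):
    """Hypothetical sketch: log det of L @ L.T from flattened lower triangles."""
    batch_size = tril_elements.shape[0]
    L = np.zeros((batch_size, Y_dim, Y_dim))
    rows, cols = np.tril_indices(Y_dim)  # row-major lower-triangle indices
    L[:, rows, cols] = tril_elements
    diag = np.diagonal(L, axis1=1, axis2=2)
    # det(L @ L.T) = prod(diag)**2, so the log determinant is 2 * sum(log(diag)).
    return 2.0 * np.log(diag).sum(axis=1)


# L = [[2, 0], [1, 3]] gives covariance [[4, 2], [2, 10]] with determinant 36.
tril_elements = np.array([[2.0, 1.0, 3.0]])
logdet = logdet_from_tril(tril_elements, Y_dim=2)
```

Working from the triangular factor avoids forming the full covariance matrix and calling a general determinant routine, which tends to be more stable numerically.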