h0rton.train_utils.checkpoint_utils

Module Contents

Functions

save_state_dict(model, optimizer, lr_scheduler, train_loss, val_loss, checkpoint_dir, model_architecture, epoch_idx)
Save the state dict of the current training to disk.
load_state_dict(checkpoint_path, model, optimizer, n_epochs, device, lr_scheduler=None)
Load the state dict of a past training run.
load_state_dict_test(checkpoint_path, model, n_epochs, device)
Load the model state dict of a past training run, without optimizer or scheduler state.

h0rton.train_utils.checkpoint_utils.save_state_dict(model, optimizer, lr_scheduler, train_loss, val_loss, checkpoint_dir, model_architecture, epoch_idx)

Save the state dict of the current training to disk

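Parameters

model : torch model
trained model to save
optimizer : torch.optim object
optimizer whose state is saved with the model
lr_scheduler : torch.optim.lr_scheduler object
learning-rate scheduler whose state is saved with the model
train_loss
training loss to store in the checkpoint
val_loss
validation loss to store in the checkpoint
checkpoint_dir : str or os.path object
directory into which to save the model
model_architecture : str
type of architecture
epoch_idx : int
epoch index

Returns

str or os.path object
path to the saved model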
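To use the module you would call save_state_dict itself. As a rough illustration of the usual PyTorch checkpointing pattern behind a function with this signature, the sketch below bundles the model, optimizer, and scheduler state dicts together with the epoch index and both losses into one dict and writes it with torch.save. The function name save_checkpoint_sketch, the dict keys, and the file-naming scheme are all hypothetical and are not taken from h0rton's source.

```python
import os
import tempfile

import torch


def save_checkpoint_sketch(model, optimizer, lr_scheduler, train_loss, val_loss,
                           checkpoint_dir, model_architecture, epoch_idx):
    """Hypothetical stand-in for save_state_dict; not h0rton's actual code."""
    # Bundle all resumable state into one dict (the key names are assumptions).
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "epoch": epoch_idx,
        "train_loss": train_loss,
        "val_loss": val_loss,
    }
    # Hypothetical file-naming scheme: architecture name plus epoch index.
    model_fname = "{:s}_epoch={:d}.mdl".format(model_architecture, epoch_idx)
    model_path = os.path.join(checkpoint_dir, model_fname)
    torch.save(state, model_path)
    return model_path


# Toy training objects, only to exercise the function.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
checkpoint_dir = tempfile.mkdtemp()
path = save_checkpoint_sketch(model, optimizer, lr_scheduler, 0.52, 0.61,
                              checkpoint_dir, "toy_arch", 3)
```

Storing state dicts rather than pickling whole objects keeps the checkpoint independent of the exact class definitions, which is the approach the PyTorch documentation recommends.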

h0rton.train_utils.checkpoint_utils.load_state_dict(checkpoint_path, model, optimizer, n_epochs, device, lr_scheduler=None)

Load the state dict of a past training run

Parameters

checkpoint_path : str or os.path object
path of the state dict to load
model : torch model
model into which to load the saved weights
optimizer : torch.optim object
optimizer into which to load the saved state
n_epochs : int
total number of epochs to train
device : torch.device object
device on which to load the model
lr_scheduler : torch.optim.lr_scheduler object, optional
learning-rate scheduler into which to load the saved state (default: None)

Returns

str or os.path object
path to the saved model
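Resuming training generally means restoring the model weights, moving the model to the target device, and restoring the optimizer (and, if given, the scheduler) so that momentum buffers and learning-rate schedules continue where they stopped. The sketch below shows that pattern under stated assumptions: load_checkpoint_sketch is a hypothetical name, the checkpoint keys are assumed, n_epochs is used only for a progress message, and the returned tuple is illustrative rather than what h0rton's load_state_dict returns.

```python
import os
import tempfile

import torch


def load_checkpoint_sketch(checkpoint_path, model, optimizer, n_epochs, device,
                           lr_scheduler=None):
    """Hypothetical stand-in for load_state_dict; not h0rton's actual code."""
    # map_location remaps tensors saved on one device (e.g. a GPU) onto `device`.
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state["model"])
    model.to(device)
    optimizer.load_state_dict(state["optimizer"])
    # The scheduler is optional, mirroring lr_scheduler=None in the signature.
    if lr_scheduler is not None:
        lr_scheduler.load_state_dict(state["lr_scheduler"])
    epoch = state["epoch"]
    # Assumption: n_epochs is only used to report progress here.
    print("Resuming after epoch [{}/{}], val loss {:.4f}".format(
        epoch + 1, n_epochs, state["val_loss"]))
    return epoch, state["train_loss"], state["val_loss"]


# Write a toy checkpoint with the assumed key layout, then restore it into fresh objects.
src = torch.nn.Linear(4, 2)
src_opt = torch.optim.SGD(src.parameters(), lr=0.1, momentum=0.9)
path = os.path.join(tempfile.mkdtemp(), "toy.mdl")
torch.save({"model": src.state_dict(), "optimizer": src_opt.state_dict(),
            "epoch": 4, "train_loss": 0.52, "val_loss": 0.61}, path)

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
epoch, train_loss, val_loss = load_checkpoint_sketch(
    path, model, optimizer, n_epochs=100, device=torch.device("cpu"))
```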

h0rton.train_utils.checkpoint_utils.load_state_dict_test(checkpoint_path, model, n_epochs, device)

Load the model state dict of a past training run. Unlike load_state_dict, this variant takes no optimizer or learning-rate scheduler.

Parameters

checkpoint_path : str or os.path object
path of the state dict to load
model : torch model
model into which to load the saved weights
n_epochs : int
total number of epochs to train
device : torch.device object
device on which to load the model

Returns

str or os.path object
path to the saved model
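Because this variant takes no optimizer or scheduler, it fits the case where only the trained weights are needed, for example to evaluate a model. The sketch below restores just the model weights and switches the model to inference mode. The name load_for_eval_sketch, the checkpoint keys, and the call to model.eval() are assumptions for illustration; the source does not say whether load_state_dict_test does this.

```python
import os
import tempfile

import torch


def load_for_eval_sketch(checkpoint_path, model, n_epochs, device):
    """Hypothetical stand-in for load_state_dict_test; not h0rton's actual code."""
    state = torch.load(checkpoint_path, map_location=device)
    # Only the model weights are restored; any optimizer/scheduler state is ignored.
    model.load_state_dict(state["model"])
    model.to(device)
    # Assumption: a test-time loader puts the model in inference mode
    # (dropout disabled, batch norm uses its running statistics).
    model.eval()
    print("Loaded weights from epoch [{}/{}]".format(state["epoch"] + 1, n_epochs))
    return model


def make_net():
    # Toy network with dropout, so train vs. eval mode makes a difference.
    return torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Dropout(0.5),
                               torch.nn.Linear(8, 1))


src = make_net()
path = os.path.join(tempfile.mkdtemp(), "toy.mdl")
torch.save({"model": src.state_dict(), "epoch": 9}, path)

model = load_for_eval_sketch(path, make_net(), n_epochs=10, device=torch.device("cpu"))
```

In eval mode dropout is a no-op, so repeated forward passes on the same input give identical outputs.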