h0rton.train_utils.checkpoint_utils
Module Contents
Functions
- save_state_dict(model, optimizer, lr_scheduler, train_loss, val_loss, checkpoint_dir, model_architecture, epoch_idx)
  - Save the state dict of the current training to disk
- load_state_dict(checkpoint_path, model, optimizer, n_epochs, device, lr_scheduler=None)
  - Load the state dict of a past training
- load_state_dict_test(checkpoint_path, model, n_epochs, device)
  - Load the state dict of a past training for testing, without an optimizer or learning-rate scheduler
h0rton.train_utils.checkpoint_utils.save_state_dict(model, optimizer, lr_scheduler, train_loss, val_loss, checkpoint_dir, model_architecture, epoch_idx)

Save the state dict of the current training to disk.
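Parameters
- model : torch model
  - trained model to save
- optimizer : torch.optim object
  - optimizer whose state is saved
- lr_scheduler : torch.optim.lr_scheduler object
  - learning-rate scheduler whose state is saved
- train_loss
  - training loss
- val_loss
  - validation loss
- checkpoint_dir : str or os.path object
  - directory into which to save the model
- model_architecture : str
  - type of architecture
- epoch_idx : int
  - epoch index

Returns
- str or os.path object
  - path to the saved model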
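The contract described above (gather the model, optimizer and scheduler state together with the epoch index and losses, write a single file under checkpoint_dir, return its path) can be sketched as follows. This is not h0rton's implementation: save_checkpoint and StubComponent are hypothetical names, the dictionary keys and the file-naming scheme are assumptions, and pickle stands in for torch.save so the sketch runs without PyTorch. StubComponent only mimics the state_dict()/load_state_dict() pair that torch.nn.Module, torch.optim.Optimizer and the learning-rate schedulers expose.

```python
import os
import pickle
import tempfile
from datetime import datetime


class StubComponent:
    """Hypothetical stand-in for a model, optimizer, or LR scheduler.

    Real PyTorch objects expose the same state_dict()/load_state_dict() pair.
    """

    def __init__(self, **state):
        self._state = dict(state)

    def state_dict(self):
        return dict(self._state)

    def load_state_dict(self, state):
        self._state = dict(state)


def save_checkpoint(model, optimizer, lr_scheduler, train_loss, val_loss,
                    checkpoint_dir, model_architecture, epoch_idx):
    """Hypothetical sketch: bundle all training state into one file, return its path."""
    # Assumed layout of the checkpoint dictionary.
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch_idx,
        'train_loss': train_loss,
        'val_loss': val_loss,
    }
    # Assumed file-naming scheme: architecture, epoch index, timestamp.
    stamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    fname = '{:s}_epoch={:d}_{:s}.pkl'.format(model_architecture, epoch_idx, stamp)
    path = os.path.join(checkpoint_dir, fname)
    # pickle stands in for torch.save so the sketch runs without PyTorch.
    with open(path, 'wb') as f:
        pickle.dump(state, f)
    return path


# Usage with hypothetical values.
ckpt_dir = tempfile.mkdtemp()
path = save_checkpoint(StubComponent(weight=1.0), StubComponent(lr=1e-3),
                       StubComponent(last_epoch=4), train_loss=0.52,
                       val_loss=0.61, checkpoint_dir=ckpt_dir,
                       model_architecture='resnet34', epoch_idx=4)
print(path)
```

Keeping every piece of state in one dictionary means a single file is enough to resume training later, and embedding the epoch index in the file name makes checkpoints easy to tell apart on disk.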
h0rton.train_utils.checkpoint_utils.load_state_dict(checkpoint_path, model, optimizer, n_epochs, device, lr_scheduler=None)

Load the state dict of a past training.
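Parameters
- checkpoint_path : str or os.path object
  - path of the state dict to load
- model : torch model
  - model into which to load the saved weights
- optimizer : torch.optim object
  - optimizer into which to load the saved optimizer state
- n_epochs : int
  - total number of epochs to train
- device : torch.device object
  - device on which to load the model
- lr_scheduler : torch.optim.lr_scheduler object, optional
  - learning-rate scheduler into which to load the saved scheduler state; defaults to None

Returns
- str or os.path object
  - path to the saved model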
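Resuming training from such a checkpoint might look like the sketch below. It is hypothetical: load_checkpoint and StubComponent are illustrative names, the checkpoint keys ('model', 'optimizer', 'lr_scheduler', 'epoch', 'train_loss', 'val_loss') are assumed, and pickle stands in for torch.load. The sketch also returns the stored epoch index so the caller knows where to continue; that return value is an assumption for illustration, not the documented one.

```python
import os
import pickle
import tempfile


class StubComponent:
    """Hypothetical stand-in exposing the state_dict()/load_state_dict() pair."""

    def __init__(self, **state):
        self._state = dict(state)

    def state_dict(self):
        return dict(self._state)

    def load_state_dict(self, state):
        self._state = dict(state)


def load_checkpoint(checkpoint_path, model, optimizer, n_epochs, device,
                    lr_scheduler=None):
    """Hypothetical sketch: restore training state from a checkpoint file."""
    # pickle stands in for torch.load(checkpoint_path, map_location=device).
    with open(checkpoint_path, 'rb') as f:
        state = pickle.load(f)
    model.load_state_dict(state['model'])
    if hasattr(model, 'to'):  # real torch modules; the stub has no .to()
        model = model.to(device)
    optimizer.load_state_dict(state['optimizer'])
    # The scheduler is optional, mirroring lr_scheduler=None in the signature.
    if lr_scheduler is not None:
        lr_scheduler.load_state_dict(state['lr_scheduler'])
    epoch = state['epoch']
    print('Resumed at epoch [{}/{}]: train loss {:.4f}, val loss {:.4f}'.format(
        epoch + 1, n_epochs, state['train_loss'], state['val_loss']))
    # Assumption: return the stored epoch index so the caller can resume after it.
    return epoch


# Usage: write a hypothetical checkpoint, then restore it into fresh objects.
ckpt_path = os.path.join(tempfile.mkdtemp(), 'demo.pkl')
with open(ckpt_path, 'wb') as f:
    pickle.dump({'model': {'weight': 2.5}, 'optimizer': {'lr': 1e-4},
                 'lr_scheduler': {'last_epoch': 9}, 'epoch': 9,
                 'train_loss': 0.40, 'val_loss': 0.47}, f)

model, optimizer, scheduler = StubComponent(), StubComponent(), StubComponent()
last_epoch = load_checkpoint(ckpt_path, model, optimizer, n_epochs=100,
                             device='cpu', lr_scheduler=scheduler)
start_epoch = last_epoch + 1  # continue training from the next epoch
```

With lr_scheduler left at its default of None, the sketch restores only the model and optimizer.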
h0rton.train_utils.checkpoint_utils.load_state_dict_test(checkpoint_path, model, n_epochs, device)

Load the state dict of a past training for testing, without an optimizer or learning-rate scheduler.
Parameters
- checkpoint_path : str or os.path object
  - path of the state dict to load
- model : torch model
  - model into which to load the saved weights
- n_epochs : int
  - total number of epochs to train
- device : torch.device object
  - device on which to load the model

Returns
- str or os.path object
  - path to the saved model
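For testing, only the model weights are needed, consistent with optimizer and lr_scheduler being absent from this signature. A hypothetical weights-only loader is sketched below. load_for_evaluation and StubModel are illustrative names, pickle stands in for torch.load, and the checkpoint keys are assumed. The final call to eval(), which takes a torch module out of training mode, is an assumption of the sketch rather than documented behavior.

```python
import os
import pickle
import tempfile


class StubModel:
    """Hypothetical stand-in for a torch.nn.Module."""

    def __init__(self):
        self._state = {}
        self.training = True

    def state_dict(self):
        return dict(self._state)

    def load_state_dict(self, state):
        self._state = dict(state)

    def eval(self):
        # torch.nn.Module.eval() likewise clears the training flag and returns self.
        self.training = False
        return self


def load_for_evaluation(checkpoint_path, model, n_epochs, device):
    """Hypothetical sketch: restore only the model weights for testing.

    Any optimizer or scheduler entries stored in the checkpoint are ignored.
    """
    with open(checkpoint_path, 'rb') as f:  # stand-in for torch.load
        state = pickle.load(f)
    model.load_state_dict(state['model'])
    if hasattr(model, 'to'):  # real torch modules; the stub has no .to()
        model = model.to(device)
    print('Loaded weights from epoch [{}/{}]'.format(state['epoch'] + 1, n_epochs))
    # Assumption: switch to evaluation mode before inference.
    return model.eval()


# Usage: a hypothetical checkpoint that also carries optimizer/scheduler state.
ckpt_path = os.path.join(tempfile.mkdtemp(), 'demo.pkl')
with open(ckpt_path, 'wb') as f:
    pickle.dump({'model': {'weight': 3.0}, 'optimizer': {'lr': 1e-4},
                 'lr_scheduler': {'last_epoch': 19}, 'epoch': 19,
                 'train_loss': 0.30, 'val_loss': 0.35}, f)

net = load_for_evaluation(ckpt_path, StubModel(), n_epochs=20, device='cpu')
```

Skipping the optimizer and scheduler keeps the test-time loader independent of how training was configured.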