# Source code for h0rton.train_utils.checkpoint_utils

import os
import numpy as np
import random
import datetime
import torch
# Names exported by ``from <this module> import *``.
__all__ = ['save_state_dict', 'load_state_dict', 'load_state_dict_test']


def save_state_dict(model, optimizer, lr_scheduler, train_loss, val_loss, checkpoint_dir, model_architecture, epoch_idx):
    """Save the state dict of the current training to disk

    Parameters
    ----------
    model : torch model
        trained model to save
    optimizer : torch.optim object
        optimizer whose state is stored alongside the model
    lr_scheduler : torch.optim.lr_scheduler object
        learning-rate scheduler whose state is stored alongside the model
    train_loss : number
        training loss to record in the checkpoint
    val_loss : number
        validation loss to record in the checkpoint
    checkpoint_dir : str or os.path object
        directory into which to save the model; it is created if missing
    model_architecture : str
        type of architecture, used as the prefix of the file name
    epoch_idx : int
        epoch index, stored in the checkpoint and embedded in the file name

    Returns
    -------
    str or os.path object
        path to the saved model

    """
    state = dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        lr_scheduler=lr_scheduler.state_dict(),
        epoch=epoch_idx,
        train_loss=train_loss,
        val_loss=val_loss,
    )
    # The epoch index is substituted first; the remaining %-directives are
    # then expanded by strftime.
    # NOTE: the ':' in the timestamp is kept so that file names stay
    # compatible with existing checkpoints, but ':' is not a valid file name
    # character on Windows.
    time_stamp = datetime.datetime.now().strftime("epoch={:d}_%m-%d-%Y_%H:%M".format(epoch_idx))
    model_fname = '{:s}_{:s}.mdl'.format(model_architecture, time_stamp)
    # torch.save does not create missing parent directories, so make sure the
    # target directory exists. An empty checkpoint_dir means "current
    # directory" for os.path.join and needs no creation.
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    model_path = os.path.join(checkpoint_dir, model_fname)
    torch.save(state, model_path)
    return model_path
def load_state_dict(checkpoint_path, model, optimizer, n_epochs, device, lr_scheduler=None):
    """Load the state dict of the past training

    Parameters
    ----------
    checkpoint_path : str or os.path object
        path of the state dict to load
    model : torch model
        model into which the saved weights are loaded
    optimizer : torch.optim object
        optimizer into which the saved optimizer state is loaded
    n_epochs : int
        total number of epochs to train, only used in the printed summary
    device : torch.device object
        device on which to load the model
    lr_scheduler : torch.optim.lr_scheduler object, optional
        if given, the saved scheduler state is loaded into it

    Returns
    -------
    tuple
        `(epoch, model, optimizer, train_loss, val_loss)` where `epoch`,
        `train_loss` and `val_loss` are the values stored in the checkpoint

    """
    # map_location lets a checkpoint saved on one device (e.g. a GPU) be
    # restored on another (e.g. a CPU-only machine).
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state['model'])
    model.to(device)
    optimizer.load_state_dict(state['optimizer'])
    if lr_scheduler is not None:
        lr_scheduler.load_state_dict(state['lr_scheduler'])
    epoch = state['epoch']
    train_loss = state['train_loss']
    val_loss = state['val_loss']
    # Plain '{}' rather than '{:s}': path objects reject the 's' format spec.
    print("Loaded weights at {}".format(checkpoint_path))
    print("Epoch [{}/{}]: TRAIN Loss: {:.4f}".format(epoch+1, n_epochs, train_loss))
    print("Epoch [{}/{}]: VALID Loss: {:.4f}".format(epoch+1, n_epochs, val_loss))
    return epoch, model, optimizer, train_loss, val_loss
def load_state_dict_test(checkpoint_path, model, n_epochs, device):
    """Load the model weights of a past training, e.g. for evaluation

    Unlike `load_state_dict`, no optimizer or scheduler state is restored.

    Parameters
    ----------
    checkpoint_path : str or os.path object
        path of the state dict to load
    model : torch model
        model into which the saved weights are loaded
    n_epochs : int
        total number of epochs to train, only used in the printed summary
    device : torch.device object
        device on which to load the model

    Returns
    -------
    tuple
        `(model, epoch)` where `epoch` is the value stored in the checkpoint

    """
    # map_location lets a checkpoint saved on one device (e.g. a GPU) be
    # restored on another (e.g. a CPU-only machine).
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state['model'])
    model.to(device)
    epoch = state['epoch']
    train_loss = state['train_loss']
    val_loss = state['val_loss']
    # Plain '{}' rather than '{:s}': path objects reject the 's' format spec.
    print("Loaded weights at {}".format(checkpoint_path))
    print("Epoch [{}/{}]: TRAIN Loss: {:.4f}".format(epoch+1, n_epochs, train_loss))
    print("Epoch [{}/{}]: VALID Loss: {:.4f}".format(epoch+1, n_epochs, val_loss))
    return model, epoch