# Source code for test_train_val_config

import numpy as np
import unittest
import copy
from h0rton.configs import TrainValConfig

class TestTrainValConfig(unittest.TestCase):
    """A suite of tests for TrainValConfig.

    Every test starts from a deep copy of the class-level ``train_val_dict``
    so that mutations made by one test never leak into another.

    """

    @classmethod
    def setUpClass(cls):
        """Build the baseline config dictionary shared by all tests.

        The ``data`` section is deliberately left empty; each test fills in
        the baobab config path(s) it needs on its own copy.

        """
        cls.train_val_dict = dict(
            data=dict(),
            monitoring=dict(
                n_plotting=20,
            ),
            model=dict(
                likelihood_class='DoubleGaussianNLL',
            ),
            optim=dict(
                batch_size=100,
            ),
        )

    def test_train_val_config_constructor(self):
        """Test the instantiation of TrainValConfig from a dictionary with
        minimum required keys.

        """
        train_val_dict = copy.deepcopy(self.train_val_dict)
        train_val_dict['data']['train_baobab_cfg_path'] = 'some_path'
        train_val_dict['data']['val_baobab_cfg_path'] = 'some_other_path'
        train_val_cfg = TrainValConfig(train_val_dict)
        # Construction not raising is the main check; the assertion makes the
        # test's intent explicit instead of leaving an unused local.
        self.assertIsInstance(train_val_cfg, TrainValConfig)

    def test_train_val_absent(self):
        """Test if an error is raised when either the train or the val
        baobab config path is not passed in.

        """
        # Each case supplies exactly one of the two paths, so the other one
        # is the missing key expected to trigger the ValueError.
        for present_key in ('val_baobab_cfg_path', 'train_baobab_cfg_path'):
            with self.subTest(present_key=present_key):
                train_val_dict = copy.deepcopy(self.train_val_dict)
                train_val_dict['data'][present_key] = 'some_path'
                with self.assertRaises(ValueError):
                    TrainValConfig(train_val_dict)
# Run the whole suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()