Source code for h0rton.configs.train_val_config

import os, sys
import importlib
import warnings
import json
import glob
import numpy as np
import pandas as pd
from addict import Dict
from baobab.sim_utils import add_g1g2_columns

[docs]class TrainValConfig: """Nested dictionary representing the configuration for H0rton training, h0_inference, visualization, and analysis """ def __init__(self, user_cfg): """ Parameters ---------- user_cfg : dict or Dict user-defined configuration """ self.__dict__ = Dict(user_cfg) self.validate_user_definition() self.preset_default() self.set_monitoring_cfg() @classmethod
[docs] def from_file(cls, user_cfg_path): """Alternative constructor that accepts the path to the user-defined configuration python file Parameters ---------- user_cfg_path : str or os.path object path to the user-defined configuration python file """ dirname, filename = os.path.split(os.path.abspath(user_cfg_path)) module_name, ext = os.path.splitext(filename) sys.path.append(dirname) if ext == '.py': #user_cfg_file = map(__import__, module_name) #user_cfg = getattr(user_cfg_file, 'cfg') user_cfg_script = importlib.import_module(module_name) user_cfg = getattr(user_cfg_script, 'cfg') return cls(user_cfg) elif ext == '.json': with open(user_cfg_path, 'r') as f: user_cfg_str = f.read() user_cfg = Dict(json.loads(user_cfg_str)) return cls(user_cfg) else: raise NotImplementedError("This extension is not supported.")
[docs] def validate_user_definition(self): """Check to see if the user-defined config is valid """ import h0rton.losses if not hasattr(h0rton.losses, self.model.likelihood_class): raise TypeError("Likelihood class supplied in cfg doesn't exist.")
[docs] def preset_default(self): """Preset default config values """ if 'train_baobab_cfg_path' not in self.data: raise ValueError("Must provide training data directory.") if 'val_baobab_cfg_path' not in self.data: raise ValueError("Must provide validation data directory.") # FIXME: doesn't check for contents of baobab config file, just the file names if self.data.train_baobab_cfg_path == self.data.val_baobab_cfg_path: warnings.warn("You're training and validating on the same dataset.", UserWarning, stacklevel=2) if 'float_type' not in self.data: self.data.float_type = 'FloatTensor' warnings.warn("Float type not provided. Defaulting to float32...")
[docs] def set_monitoring_cfg(self): """Set general metadata relevant to network architecture and optimization """ # Data to plot during monitoring if self.monitoring.n_plotting > 100: warnings.warn("Only plotting allowed max of 100 datapoints during training") self.monitoring.n_plotting = 100 if self.monitoring.n_plotting > self.optim.batch_size: raise ValueError("monitoring.n_plotting must be smaller than optim.batch_size")