Source code for h0rton.script_utils

import os
import sys
import argparse
import random
from addict import Dict
import numpy as np
import torch

# Names exported by `from h0rton.script_utils import *`.
__all__ = ['parse_inference_args', 'seed_everything', 'HiddenPrints']

def parse_inference_args():
    """Parse the command-line arguments of the inference script.

    Returns
    -------
    argparse.Namespace
        Parsed arguments with the attributes `test_config_file_path`
        (str, required positional) and `lens_indices_path` (str, or None
        when the option is not given).

    """
    parser = argparse.ArgumentParser()
    parser.add_argument('test_config_file_path',
                        help='path to the user-defined test config file')
    parser.add_argument('--lens_indices_path',
                        default=None,
                        dest='lens_indices_path',
                        type=str,
                        help='path to a text file with specific lens indices to test on (Default: None)')
    # parse_args() either returns a Namespace or exits the process on bad
    # input; it never returns None, so no None-fallback is needed here.
    return parser.parse_args()
def seed_everything(global_seed):
    """Seed every random number generator in use, for reproducibility.

    Parameters
    ----------
    global_seed : int
        seed for `np.random`, `random`, and relevant `torch` backends

    """
    np.random.seed(global_seed)
    random.seed(global_seed)
    torch.manual_seed(global_seed)
    # Seed every visible GPU rather than only the current one, so multi-GPU
    # runs are reproducible too. Silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(global_seed)
    # Trade cuDNN autotuning speed for deterministic kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
class HiddenPrints:
    """Context manager that hides standard output.

    While the `with` block is active, `sys.stdout` points at `os.devnull`.
    On exit, the `sys.stdout` that was in place on entry is restored.
    Exceptions raised inside the block are not suppressed.

    """

    def __enter__(self):
        """Redirect `sys.stdout` to the null device and return this object."""
        self._original_stdout = sys.stdout
        # Keep our own handle so __exit__ closes exactly the file opened
        # here, even if the block rebinds sys.stdout in the meantime.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the original `sys.stdout` and close the null device."""
        sys.stdout = self._original_stdout
        self._devnull.close()