import os
import sys
import argparse
import random
from addict import Dict
import numpy as np
import torch
# Public names exported by ``from <this module> import *``.
__all__ = ['parse_inference_args', 'seed_everything', 'HiddenPrints']
def parse_inference_args(argv=None):
    """Parse command-line arguments for inference.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. When ``None`` (the default), argparse reads
        ``sys.argv[1:]``, which is what happens when the function is called
        from a shell command or a setuptools console-script entry point.

    Returns
    -------
    argparse.Namespace
        Namespace with ``test_config_file_path`` (str) and
        ``lens_indices_path`` (str, or ``None`` when not given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('test_config_file_path',
                        help='path to the user-defined test config file')
    parser.add_argument('--lens_indices_path', default=None, dest='lens_indices_path', type=str,
                        help='path to a text file with specific lens indices to test on (Default: None)')
    # No None-check is needed: parse_args either returns a Namespace or exits
    # via SystemExit on invalid input.
    return parser.parse_args(argv)
def seed_everything(global_seed):
    """Seed every random number generator used here, for reproducibility.

    Parameters
    ----------
    global_seed : int
        Seed for ``np.random``, ``random``, and the relevant ``torch``
        generators (CPU and every visible CUDA device).

    Notes
    -----
    Also puts cuDNN into deterministic mode and disables its autotuner
    (``benchmark``), which may otherwise select non-deterministic kernels.
    """
    np.random.seed(global_seed)
    random.seed(global_seed)
    torch.manual_seed(global_seed)
    # torch.cuda.manual_seed only seeds the *current* GPU; manual_seed_all
    # covers all of them. It is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(global_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
class HiddenPrints:
    """Context manager that silences standard output.

    While active, ``sys.stdout`` is redirected to ``os.devnull``. On exit the
    previous stream is restored, even if the body raised, and the devnull
    handle is closed. Exceptions raised in the body are not suppressed.

    Usage::

        with HiddenPrints():
            noisy_function()
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own handle so __exit__ closes exactly the file opened here,
        # even if the body rebinds sys.stdout to something else.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            # Restore first so stdout is usable even if close() fails.
            sys.stdout = self._original_stdout
        finally:
            self._devnull.close()