"""Helper functions for h0rton scripts (module ``h0rton.script_utils``)."""

import os
import sys
import argparse
import random
from addict import Dict
import numpy as np
import torch

# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    'parse_inference_args',
    'seed_everything',
    'HiddenPrints',
    'get_batch_size',
    'infer_bnn',
]


def _first_batch(loader):
    """Return the first item yielded by ``loader``, or raise ValueError if it is empty."""
    try:
        return next(iter(loader))
    except StopIteration:
        raise ValueError("test_loader yielded no batches; nothing to infer on.") from None


def infer_bnn(net, bnn_post, param_logL, test_loader,
              batch_size, n_dropout, n_samples_per_dropout,
              device, global_seed=123, mode='default'):
    """Infer with MC dropout.

    Parameters
    ----------
    net : torch.nn.Module
        Trained network, called as ``net(X)`` on a batch of test inputs.
    bnn_post : object
        BNN posterior exposing ``set_sliced_pred(pred)`` and
        ``sample(n, sample_seed=...)``; its samples serve as initial walker
        positions.
    param_logL : object
        Exposes ``Y_dim``, ``out_dim`` and
        ``remove_params_from_pred(pred, return_as_tensor=False)``.
    test_loader : iterable of (X, Y) pairs
        Every batch is used for the batchnorm warm-up passes; only the first
        batch is used for the MC dropout predictions.
    batch_size : int
        Leading dimension of the returned arrays; should equal the number of
        rows in the first batch of ``test_loader``.
    n_dropout : int
        Number of MC dropout forward passes.
    n_samples_per_dropout : int
        Number of posterior samples drawn per forward pass.
    device : torch.device or str
        Device the inputs are moved to before the forward pass.
    global_seed : int
        Base seed; pass ``d`` samples with ``sample_seed=global_seed + d``.
    mode : str
        ``'default_with_truth_mean'`` replaces the first ``Y_dim`` columns of
        each prediction with the truth values ``Y``; any other value leaves the
        prediction as is.

    Returns
    -------
    init_pos : np.ndarray
        Shape ``[batch_size, n_dropout, n_samples_per_dropout, Y_dim]``.
    mcmc_pred : np.ndarray
        Shape ``[batch_size, n_dropout, param_logL.out_dim]``.

    Raises
    ------
    ValueError
        If ``test_loader`` yields no batches.
    """
    Y_dim = param_logL.Y_dim
    init_pos = np.empty([batch_size, n_dropout, n_samples_per_dropout, Y_dim])
    mcmc_pred = np.empty([batch_size, n_dropout, param_logL.out_dim])
    use_truth_mean = (mode == 'default_with_truth_mean')
    with torch.no_grad():
        net.train()
        # Send some forward passes through the test data without backprop so
        # that (in train mode) batchnorm layers adjust their running stats.
        # (This is often not necessary. Beware if using for just 1 lens.)
        for _pass in range(5):
            for X_, _unused_Y in test_loader:
                net(X_.to(device))
        # Obtain MC dropout samples
        net.eval()
        # NOTE(review): in eval mode, standard torch.nn.Dropout layers are
        # inactive, so the n_dropout passes only differ if the model's dropout
        # ignores the train/eval flag -- confirm against the network definition.
        for d in range(n_dropout):
            X_, Y_ = _first_batch(test_loader)
            pred = net(X_.to(device))
            mcmc_pred_d = pred.cpu().numpy()
            if use_truth_mean:
                # Replace BNN posterior's primary gaussian mean with truth values
                mcmc_pred_d[:, :Y_dim] = Y_[:, :Y_dim].cpu().numpy()
            # Leave only the MCMC parameters in pred
            mcmc_pred_d = param_logL.remove_params_from_pred(mcmc_pred_d,
                                                             return_as_tensor=False)
            # Populate pred that will define the MCMC penalty function
            mcmc_pred[:, d, :] = mcmc_pred_d
            # Instantiate posterior to generate BNN samples, which will serve
            # as initial positions for walkers (just the lens model params,
            # no D_dt)
            bnn_post.set_sliced_pred(mcmc_pred_d)
            init_pos[:, d, :, :] = bnn_post.sample(n_samples_per_dropout,
                                                   sample_seed=global_seed + d)
    return init_pos, mcmc_pred


def get_batch_size(cfg_lens_indices, cfg_n_test, args_lens_indices_path):
    """Figure out how many consecutive lenses BNN will predict on

    Parameters
    ----------
    cfg_lens_indices : list or None
        lens indices specified in the config file
    cfg_n_test : int
        number of test lenses specified in the config file; used only when
        no specific lens indices are given
    args_lens_indices_path : os.path instance or str or None
        path to the text file containing lens indices (one integer per line,
        blank lines ignored), from the command line

    Returns
    -------
    batch_size : int
        ``max(lens_range) + 1``, i.e. enough consecutive lenses to cover
        every requested index
    n_test : int
        number of lenses to test on
    lens_range : range or list of int
        the lens indices to test on

    Raises
    ------
    ValueError
        If lens indices are given in both the config and the command line,
        or if the resulting set of lens indices is empty.
    """
    if cfg_lens_indices is not None and args_lens_indices_path is not None:
        raise ValueError("Specific lens indices were specified in both the"
                         " test config file and the command-line argument.")
    if cfg_lens_indices is None and args_lens_indices_path is None:
        # Test on all n_test lenses in the test set
        n_test = cfg_n_test
        lens_range = range(n_test)
    else:
        if cfg_lens_indices is not None:
            # Test on the lens indices specified in the test config file
            lens_range = cfg_lens_indices
        else:
            # Test on the lens indices in a text file at the specified path;
            # skip blank lines (e.g. a trailing empty line) instead of
            # crashing on int('').
            with open(args_lens_indices_path, "r") as f:
                lens_range = [int(line.strip()) for line in f if line.strip()]
        n_test = len(lens_range)
        print(f"Performing H0 inference on {n_test} specified lenses...")
    if len(lens_range) == 0:
        raise ValueError("No lens indices to perform H0 inference on.")
    batch_size = max(lens_range) + 1
    return batch_size, n_test, lens_range


def parse_inference_args(argv=None):
    """Parse command-line arguments for the inference script.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. When ``None`` (the default, as when invoked as a
        setuptools entry point), argparse reads ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        With ``test_config_file_path`` (str) and ``lens_indices_path``
        (str or None).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('test_config_file_path',
                        help='path to the user-defined test config file')
    parser.add_argument('--lens_indices_path', default=None,
                        dest='lens_indices_path', type=str,
                        help='path to a text file with specific lens indices to test on (Default: None)')
    # parse_args never returns None (invalid input raises SystemExit), so no
    # fallback for a missing result is needed.
    return parser.parse_args(argv)
def seed_everything(global_seed):
    """Seed every random number generator used by the scripts, for reproducibility.

    Parameters
    ----------
    global_seed : int
        Seed applied to Python's `random`, `np.random`, and torch (CPU and
        CUDA generators). cuDNN is also switched to deterministic,
        non-benchmarking mode.
    """
    # Python and NumPy generators
    random.seed(global_seed)
    np.random.seed(global_seed)
    # torch generators (CPU first, then the current CUDA device)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(global_seed)
    # Prefer reproducible cuDNN kernels over autotuned ones
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
class HiddenPrints:
    """Hide standard output.

    Context manager: inside the ``with`` block ``sys.stdout`` points at
    ``os.devnull``. On exit the previous ``sys.stdout`` is restored and the
    devnull handle opened by ``__enter__`` is closed. Exceptions raised in the
    block are not suppressed.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore first, then close the handle we opened ourselves rather
        # than whatever sys.stdout currently is: if the block rebound
        # sys.stdout, that stream must not be closed, and devnull must not leak.
        sys.stdout = self._original_stdout
        self._devnull.close()