Source code for h0rton.trainval_data.xy_data

import os
import glob
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from baobab import BaobabConfig
from baobab.data_augmentation.noise_torch import NoiseModelTorch
from baobab.sim_utils import add_g1g2_columns
from .data_utils import whiten_pixels, rescale_01, plus_1_log, whiten_Y_cols

# Public API of this module: only the dataset class is exported.
__all__ = ['XYData']

class XYData(Dataset):  # torch.utils.data.Dataset
    """Represents the XYData used to train or validate the BNN

    Each item is an (image, label) pair. Labels are read once from
    `metadata.csv` in the Baobab output directory and whitened with the
    training-set mean and std. Images are loaded lazily from disk in
    `__getitem__`.

    """
    def __init__(self, is_train, Y_cols, float_type, define_src_pos_wrt_lens,
                 rescale_pixels, log_pixels, add_pixel_noise, eff_exposure_time,
                 train_Y_mean=None, train_Y_std=None,
                 train_baobab_cfg_path=None, val_baobab_cfg_path=None,
                 for_cosmology=False, rescale_pixels_type='whiten_pixels'):
        """
        Parameters
        ----------
        is_train : bool
            whether this is the training set. Selects which Baobab config is
            read, and whether the whitening statistics are computed from this
            dataset's labels (True) or taken from `train_Y_mean` and
            `train_Y_std` (False).
        Y_cols : list of str
            names of the metadata columns used as target labels
        float_type : str
            name of the tensor type. Arrays are cast to `np.float64` if it
            contains 'Double' and to `np.float32` otherwise.
        define_src_pos_wrt_lens : bool
            whether to redefine `src_light_center_x/y` as offsets from
            `lens_mass_center_x/y`
        rescale_pixels : bool
            whether to apply the pixel rescaling chosen by
            `rescale_pixels_type`
        log_pixels : bool
            whether to apply `plus_1_log` to the pixels (applied before the
            rescaling, if both are enabled)
        add_pixel_noise : bool
            whether to add noise to each band on the fly in `__getitem__`
        eff_exposure_time : dict
            effective exposure time keyed by bandpass. Presumably in the same
            units as the `exposure_time` entry of the Baobab single-band
            kwargs, since the two are divided -- confirm against the config.
        train_Y_mean : array-like, optional
            mean of the training labels. Required if `is_train` is False,
            recomputed from this dataset's labels if `is_train` is True.
        train_Y_std : array-like, optional
            std of the training labels. Required if `is_train` is False,
            recomputed from this dataset's labels if `is_train` is True.
        train_baobab_cfg_path : str or os.path object, optional
            path to the Baobab config of the training set, read if `is_train`
            is True
        val_baobab_cfg_path : str or os.path object, optional
            path to the Baobab config of the validation set, read if
            `is_train` is False
        for_cosmology : bool
            whether the dataset will be used in cosmological inference
            (in which case the unwhitened metadata is stored as `self.Y_df`)
        rescale_pixels_type : str
            'rescale_01' selects `rescale_01`; any other value selects
            `whiten_pixels`. Only used if `rescale_pixels` is True.

        Raises
        ------
        ValueError
            if `is_train` is False and `train_Y_mean` or `train_Y_std` is
            missing
        FileNotFoundError
            if the dataset directory contains no `.npy` image

        """
        self.is_train = is_train
        if self.is_train:
            self.baobab_cfg = BaobabConfig.from_file(train_baobab_cfg_path)
        else:
            self.baobab_cfg = BaobabConfig.from_file(val_baobab_cfg_path)
        self.dataset_dir = self.baobab_cfg.out_dir
        # A validation set must be whitened with the training statistics.
        if not self.is_train and (train_Y_mean is None or train_Y_std is None):
            raise ValueError("Mean and std of training set must be provided for whitening.")
        self.train_Y_mean = train_Y_mean
        self.train_Y_std = train_Y_std
        self.Y_cols = Y_cols
        self.float_type = float_type
        self.float_type_numpy = np.float64 if 'Double' in float_type else np.float32
        self.define_src_pos_wrt_lens = define_src_pos_wrt_lens
        self.rescale_pixels = rescale_pixels
        self.log_pixels = log_pixels
        self.add_pixel_noise = add_pixel_noise
        self.eff_exposure_time = eff_exposure_time
        self.bandpass_list = self.baobab_cfg.survey_info.bandpass_list
        self.for_cosmology = for_cosmology
        #################
        # Target labels #
        #################
        metadata_path = os.path.join(self.dataset_dir, 'metadata.csv')
        Y_df = pd.read_csv(metadata_path, index_col=False)
        if 'external_shear_gamma1' not in Y_df.columns:
            # assumes gamma_ext, psi_ext were sampled
            Y_df = add_g1g2_columns(Y_df)
        # Define source light position as offset from lens mass
        if self.define_src_pos_wrt_lens:
            Y_df['src_light_center_x'] -= Y_df['lens_mass_center_x']
            Y_df['src_light_center_y'] -= Y_df['lens_mass_center_y']
        train_Y_to_whiten = Y_df[self.Y_cols].values
        if self.is_train:
            # keepdims=True keeps the statistics 2D (one row, one entry per column)
            self.train_Y_mean = np.mean(train_Y_to_whiten, axis=0, keepdims=True)
            self.train_Y_std = np.std(train_Y_to_whiten, axis=0, keepdims=True)
        # Store the unwhitened metadata (copied before the whitening below)
        if self.for_cosmology:
            self.Y_df = Y_df.copy()
        # Number of predictive columns
        self.Y_dim = len(self.Y_cols)
        # Whiten the columns
        whiten_Y_cols(Y_df, self.train_Y_mean, self.train_Y_std, self.Y_cols)
        # Convert into array the columns required for training
        self.img_filenames = Y_df['img_filename'].values
        self.Y_array = Y_df[self.Y_cols].values.astype(self.float_type_numpy)
        # Free memory
        if not self.for_cosmology:
            del Y_df
        ################
        # Input images #
        ################
        # Set some metadata from one example image
        npy_paths = glob.glob(os.path.join(self.dataset_dir, '*.npy'))
        if not npy_paths:
            # Fail with the directory name rather than a bare IndexError
            raise FileNotFoundError(
                "No .npy image found in dataset directory {!r}.".format(self.dataset_dir))
        img = np.load(npy_paths[0])
        self.X_dim = img.shape[0]
        # Rescale pixels, stack filters, and shift/scale pixels on the fly
        if rescale_pixels_type == 'rescale_01':
            rescale = transforms.Lambda(rescale_01)
        else:
            rescale = transforms.Lambda(whiten_pixels)
        log = transforms.Lambda(plus_1_log)
        transforms_list = []
        if self.log_pixels:
            transforms_list.append(log)
        if self.rescale_pixels:
            transforms_list.append(rescale)
        # An empty Compose returns its input unchanged. Unlike a local lambda,
        # it can be pickled, which a multi-worker DataLoader needs when worker
        # processes are spawned.
        self.X_transform = transforms.Compose(transforms_list)
        # Noise-related kwargs
        self.noise_kwargs = {}
        self.noiseless_exposure_time = {}
        self.noise_model = {}
        self.exposure_time_factor = np.ones([len(self.bandpass_list), 1, 1])  # for broadcasting
        for i, bp in enumerate(self.bandpass_list):
            survey_object = self.baobab_cfg.survey_object_dict[bp]
            # Dictionary of SingleBand kwargs
            self.noise_kwargs[bp] = survey_object.kwargs_single_band()
            # Factor of effective exptime relative to exptime of the noiseless images
            self.exposure_time_factor[i, :, :] = self.eff_exposure_time[bp]/self.noise_kwargs[bp]['exposure_time']
            if self.add_pixel_noise:
                self.noise_kwargs[bp].update(exposure_time=self.eff_exposure_time[bp])
                # Dictionary of noise models
                self.noise_model[bp] = NoiseModelTorch(**self.noise_kwargs[bp])

    def __getitem__(self, index):
        """Load one image from disk and pair it with its whitened label

        Parameters
        ----------
        index : int
            position of the example in `self.img_filenames` and `self.Y_array`

        Returns
        -------
        tuple
            `(img, Y_row)`, both `torch.Tensor`. `img` is the image scaled by
            `self.exposure_time_factor`, with noise added per band if
            `self.add_pixel_noise` is set, then passed through
            `self.X_transform`. `Y_row` is row `index` of `self.Y_array`.

        """
        # Image X
        img_filename = self.img_filenames[index]
        img_path = os.path.join(self.dataset_dir, img_filename)
        img = np.load(img_path)
        # Scale each band from the noiseless to the effective exposure time
        img *= self.exposure_time_factor
        # np array type must match with default tensor type
        img = torch.as_tensor(img.astype(self.float_type_numpy))
        if self.add_pixel_noise:
            for i, bp in enumerate(self.bandpass_list):
                img[i, :, :] += self.noise_model[bp].get_noise_map(img[i, :, :])
        img = self.X_transform(img)
        # Label Y
        Y_row = self.Y_array[index, :]
        Y_row = torch.as_tensor(Y_row)
        return img, Y_row

    def __len__(self):
        """Return the number of examples, i.e. the number of rows in `self.Y_array`"""
        return self.Y_array.shape[0]