Source code for h0rton.h0_inference.plotting_utils

import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import corner
from lenstronomy.Plots import chain_plot
from h0rton.h0_inference import h0_utils
from scipy.stats import norm#, median_absolute_deviation
from lenstronomy.LensModel.lens_model_extensions import LensModelExtensions
from lenstronomy.LensModel.Solver.lens_equation_solver import LensEquationSolver
#import lenstronomy.Util.util as util
import lenstronomy.Util.simulation_util as sim_util
from lenstronomy.Data.imaging_data import ImageData
from lenstronomy.Plots import plot_util
#import scipy.ndimage as ndimage

# Public API of h0rton.h0_inference.plotting_utils for star-imports.
# lognormal, plot_weighted_D_dt_histogram, plot_mcmc_chain and
# lens_model_plot_custom are defined in this module but not listed here.
__all__ = [
    "plot_weighted_h0_histogram",
    "plot_h0_histogram",
    "plot_D_dt_histogram",
    "plot_mcmc_corner",
    "gaussian",
    "plot_forward_modeling_comparisons",
]

def gaussian(x, mean, standard_deviation, amplitude):
    """Evaluate an (unnormalized) Gaussian curve.

    Parameters
    ----------
    x : float or np.array
        points at which to evaluate the curve
    mean : float
        center of the curve
    standard_deviation : float
        standard deviation (sigma) of the curve
    amplitude : float
        peak height, i.e. the value at ``x == mean``. Passing
        ``1/(standard_deviation*sqrt(2*pi))`` yields the normal pdf.

    Returns
    -------
    float or np.array
        curve values at `x`
    """
    # The factor 0.5 makes `standard_deviation` the true sigma. Without it the
    # curve's width is sigma/sqrt(2), so curve_fit widths were off by sqrt(2)
    # and a pdf-normalized amplitude did not produce a normal pdf.
    return amplitude * np.exp(-0.5 * ((x - mean) / standard_deviation) ** 2)
def lognormal(x, mu, sig):
    """Evaluate the log-normal probability density function.

    Parameters
    ----------
    x : float or array-like
        points at which to evaluate the pdf
    mu : float or array-like
        mean of ``log(x)``
    sig : float or array-like
        standard deviation of ``log(x)``

    Returns
    -------
    float or np.array
        pdf values; exactly 0 for ``x <= 0`` (outside the support) and nan
        where `x` is nan
    """
    x = np.asarray(x, dtype=float)
    outside_support = x <= 0.0
    # Substitute a harmless positive value before taking the log so that
    # non-positive x neither warns nor yields nan; those entries are zeroed below.
    safe_x = np.where(outside_support, 1.0, x)
    pdf = np.exp(-0.5 * ((np.log(safe_x) - mu) / sig) ** 2) / (safe_x * sig * np.sqrt(2.0 * np.pi))
    # `[()]` turns a 0-d result back into a NumPy scalar for scalar input.
    return np.where(outside_support, 0.0, pdf)[()]
def plot_weighted_h0_histogram(all_samples, all_weights, lens_i=0, true_h0=None, include_fit_gaussian=True, save_dir='.'):
    """Plot a weighted H0 histogram with its normal approximation and the truth H0.

    Parameters
    ----------
    all_samples : np.array
        H0 samples
    all_weights : np.array
        H0 weights corresponding to `all_samples`, possibly including nan values;
        cleaning and summary statistics are delegated to
        ``h0_utils.get_normal_stats_naive``
    lens_i : int
        lens index, used in the plot title and output file name
    true_h0 : float or None
        if given (and `save_dir` is not None), drawn as a dashed red line
    include_fit_gaussian : bool
        not used by this function; kept for interface compatibility
    save_dir : str or None
        directory where ``h0_histogram_{lens_i:04d}.png`` is written. If None,
        no labels are added, nothing is saved, and the figure is left open.

    Returns
    -------
    dict
        the statistics returned by ``h0_utils.get_normal_stats_naive``; its
        'samples', 'weights', 'mean' and 'std' entries are used here
    """
    stats = h0_utils.get_normal_stats_naive(all_samples, all_weights)
    center, width = stats['mean'], stats['std']
    plt.hist(stats['samples'], bins=290, range=[10.0, 300.0], weights=stats['weights'],
             density=True, alpha=0.5, color='tab:blue', edgecolor='k')
    h0_grid = np.linspace(10, 300, 1000)
    # Normal pdf with the weighted mean/std, drawn over the full histogram range
    plt.plot(h0_grid, norm.pdf(h0_grid, center, width), color='k',
             label=f'fit: mu={center:0.1f}, sig={width:0.1f}')
    if save_dir is None:
        return stats
    if true_h0 is not None:
        plt.axvline(x=true_h0, linestyle='--', color='red', label='truth')
    plt.xlabel('H0 (km/Mpc/s)')
    plt.ylabel('density')
    plt.title(f'H0 posterior for lens {lens_i:04d}')
    plt.legend()
    plt.savefig(os.path.join(save_dir, f'h0_histogram_{lens_i:04d}.png'))
    plt.close()
    return stats
def plot_weighted_D_dt_histogram(all_samples, all_weights, lens_i=0, true_D_dt=None, save_dir='.'):
    """Plot the weighted histogram of D_dt samples, overlaid with a log-normal fit and truth D_dt

    Parameters
    ----------
    all_samples : np.array
        D_dt samples (Mpc)
    all_weights : np.array or list
        weights corresponding to `all_samples`, possibly including nan or
        infinite values. Samples whose weight is not finite are dropped and the
        remaining weights are normalized to sum to one. The caller's array is
        not modified.
    lens_i : int
        lens index, used in the plot title and output file name
    true_D_dt : float or None
        if given (and `save_dir` is not None), drawn as a dashed red line
    save_dir : str or None
        directory where ``D_dt_histogram_{lens_i:04d}.png`` is written. If None,
        nothing is labeled or saved and the figure is left open.

    Returns
    -------
    mu : float
    sigma : float
        log-normal parameters from ``h0_utils.get_lognormal_stats_naive``
    """
    # Normalize a float copy instead of writing back into `all_weights`: this
    # leaves the caller's array untouched and avoids truncation of int weights.
    weights = np.asarray(all_weights, dtype=float)
    is_valid = np.isfinite(weights)  # False for nan as well as +/-inf
    samples = np.asarray(all_samples)[is_valid]
    weights = weights[is_valid] / np.sum(weights[is_valid])
    _, bin_borders, _ = plt.hist(samples, weights=weights, bins=200, alpha=0.5, density=True,
                                 edgecolor='k', color='tab:blue', range=[0.0, 15000.0])
    bin_centers = bin_borders[:-1] + np.diff(bin_borders) / 2
    # Log-normal summary statistics of the weighted samples
    lognorm_stats = h0_utils.get_lognormal_stats_naive(samples, weights)
    mu = lognorm_stats['mu']
    sigma = lognorm_stats['sigma']
    mode = lognorm_stats['mode']
    std = lognorm_stats['std']
    x_interval_for_fit = np.linspace(bin_centers[0], bin_centers[-1], 1000)
    # Overlay the log-normal pdf
    plt.plot(x_interval_for_fit, lognormal(x_interval_for_fit, mu, sigma), color='k',
             label='fit: mode={:0.1f}, std={:0.1f}'.format(mode, std))
    if save_dir is not None:
        if true_D_dt is not None:
            plt.axvline(x=true_D_dt, linestyle='--', color='red', label='truth')
        plt.xlabel('D_dt (Mpc)')
        plt.ylabel('density')
        plt.title('D_dt posterior for lens {0:04d}'.format(lens_i))
        plt.legend()
        save_path = os.path.join(save_dir, 'D_dt_histogram_{0:04d}.png'.format(lens_i))
        plt.savefig(save_path)
        plt.close()
    return mu, sigma
def plot_h0_histogram(samples, lens_i=0, true_h0=None, include_fit_gaussian=True, save_dir='.'):
    """Plot the histogram of H0 samples, overlaid with a Gaussian fit and truth H0

    Parameters
    ----------
    samples : np.array
        H0 samples (unweighted)
    lens_i : int
        lens index, used in the plot title and output file name
    true_h0 : float or None
        if given (and `save_dir` is not None), drawn as a dashed red line
    include_fit_gaussian : bool
        if True, fit `gaussian` to the normalized histogram (80 bins over
        [40, 100]) with `scipy.optimize.curve_fit`. If False, use the sample
        median as the center and the normal-consistent median absolute
        deviation as the width, and overlay the corresponding normal pdf.
    save_dir : str or None
        directory where ``h0_histogram_{lens_i:04d}.png`` is written. If None,
        nothing is labeled or saved and the figure is left open.

    Returns
    -------
    mean : float
    std : float
        center and (non-negative) width of the overlaid Gaussian

    Raises
    ------
    RuntimeError
        propagated from `curve_fit` if the fit does not converge
    """
    bin_heights, bin_borders, _ = plt.hist(samples, bins=80, alpha=0.5, density=True,
                                           edgecolor='k', color='tab:blue', range=[40.0, 100.0])
    bin_centers = bin_borders[:-1] + np.diff(bin_borders) / 2
    x_interval_for_fit = np.linspace(bin_centers[0], bin_centers[-1], 1000)
    if include_fit_gaussian:
        # Fit a Gaussian to the binned density, starting from the tallest bin
        best_guess_mean = bin_centers[np.argmax(bin_heights)]
        popt, _ = curve_fit(gaussian, bin_centers, bin_heights, p0=[best_guess_mean, 0.3, 3.0], maxfev=10000)
        mean = popt[0]
        # The model depends on the width only through its square, so the fit
        # can converge to a negative value; report its magnitude.
        std = np.abs(popt[1])
        fit_curve = gaussian(x_interval_for_fit, *popt)
    else:
        # Robust estimates. numpy has no `median_absolute_deviation`, so compute
        # the MAD directly and scale it by 1/Phi^-1(3/4) (~1.4826) so that it
        # estimates the standard deviation of normally distributed samples.
        samples = np.asarray(samples)
        mean = np.median(samples)
        std = np.median(np.abs(samples - mean)) / norm.ppf(0.75)
        fit_curve = norm.pdf(x_interval_for_fit, mean, std)
    plt.plot(x_interval_for_fit, fit_curve, color='k', label='fit: mu={:0.1f}, sig={:0.1f}'.format(mean, std))
    if save_dir is not None:
        if true_h0 is not None:
            plt.axvline(x=true_h0, linestyle='--', color='red', label='truth')
        plt.xlabel('H0 (km/Mpc/s)')
        plt.ylabel('density')
        plt.title('H0 posterior for lens {0:04d}'.format(lens_i))
        plt.legend()
        save_path = os.path.join(save_dir, 'h0_histogram_{0:04d}.png'.format(lens_i))
        plt.savefig(save_path)
        plt.close()
    return mean, std
def plot_D_dt_histogram(all_samples, lens_i=0, true_D_dt=None, save_dir='.'):
    """Plot a D_dt histogram overlaid with a log-normal approximation and the truth D_dt.

    Parameters
    ----------
    all_samples : np.array
        D_dt MCMC samples (Mpc)
    lens_i : int
        lens index, used in the plot title and output file name
    true_D_dt : float or None
        if given (and `save_dir` is not None), drawn as a dashed red line
    save_dir : str or None
        directory where ``D_dt_histogram_{lens_i:04d}.png`` is written. If None,
        nothing is labeled or saved and the figure is left open.

    Returns
    -------
    tuple
        ``(mu, sigma)`` from ``h0_utils.get_lognormal_stats``, i.e. the
        parameters of the overlaid log-normal curve
    """
    _, edges, _ = plt.hist(all_samples, bins=200, range=[0.0, 15000.0], density=True,
                           alpha=0.5, color='tab:blue', edgecolor='k')
    centers = edges[:-1] + np.diff(edges) / 2
    summary = h0_utils.get_lognormal_stats(all_samples)
    mu, sigma = summary['mu'], summary['sigma']
    mode, std = summary['mode'], summary['std']
    # Log-normal curve spanning the first to last bin center
    dt_grid = np.linspace(centers[0], centers[-1], 1000)
    plt.plot(dt_grid, lognormal(dt_grid, mu, sigma), color='k',
             label=f'fit: mode={mode:0.1f}, std={std:0.1f}')
    if save_dir is None:
        return mu, sigma
    if true_D_dt is not None:
        plt.axvline(x=true_D_dt, linestyle='--', color='red', label='truth')
    plt.xlabel(r'$D_{{\Delta t}}$ (Mpc)')
    plt.ylabel('density')
    plt.title(rf'$D_{{\Delta t}}$ posterior for lens {lens_i:04d}')
    plt.legend()
    plt.savefig(os.path.join(save_dir, f'D_dt_histogram_{lens_i:04d}.png'))
    plt.close()
    return mu, sigma
def plot_mcmc_chain(chain_list_mcmc, save_path):
    """Plot MCMC chains with lenstronomy's chain_plot and save the figure.

    Parameters
    ----------
    chain_list_mcmc : list
        chain list in the format accepted by
        ``lenstronomy.Plots.chain_plot.plot_chain_list``
    save_path : str
        path of the output image (saved at dpi=100)
    """
    fig, _ = chain_plot.plot_chain_list(chain_list_mcmc)
    fig.savefig(save_path, dpi=100)
    # Close the figure we created, not whichever figure happens to be current.
    plt.close(fig)
def plot_mcmc_corner(mcmc_samples, truth, col_labels, save_path):
    """Save a corner plot of MCMC samples with 68% and 95% contours.

    Parameters
    ----------
    mcmc_samples : np.array
        samples passed to ``corner.corner`` (presumably [n_samples, n_params] — confirm)
    truth : array-like
        true parameter values, drawn in red
    col_labels : list of str
        axis labels, one per parameter
    save_path : str
        path of the output image (saved at dpi=100)
    """
    fig = corner.corner(mcmc_samples,
                        truths=truth,
                        truth_color='r',
                        labels=col_labels,
                        smooth=1.0,
                        no_fill_contours=True,
                        plot_datapoints=False,
                        show_titles=True,
                        quiet=True,
                        plot_contours=True,
                        use_math_text=True,
                        contour_kwargs=dict(linestyles='solid'),
                        levels=[0.68, 0.95],)
    fig.savefig(save_path, dpi=100)
    # Close the figure we created, not whichever figure happens to be current.
    plt.close(fig)
def plot_forward_modeling_comparisons(model_plot_instance, out_dir):
    """Save lenstronomy's data-vs-model comparison figures to `out_dir`.

    The ModelPlot methods ``plot_main``, ``plot_separate`` and
    ``plot_subtract_from_data_all`` are called in that order, and their figures
    are saved as ``main_plot_lenstronomy.png``, ``separate_plot_lenstronomy.png``
    and ``subtract_plot_lenstronomy.png`` respectively.

    Parameters
    ----------
    model_plot_instance : lenstronomy.Plots.model_plot.ModelPlot object
    out_dir : str
        directory in which the plots will be saved

    Notes
    -----
    After the last figure is saved, *all* open matplotlib figures are closed,
    not only the ones created here.
    """
    plot_jobs = (
        ('plot_main', 'main_plot_lenstronomy.png'),
        ('plot_separate', 'separate_plot_lenstronomy.png'),
        ('plot_subtract_from_data_all', 'subtract_plot_lenstronomy.png'),
    )
    final_job = len(plot_jobs) - 1
    for job_idx, (method_name, file_name) in enumerate(plot_jobs):
        figure, _ = getattr(model_plot_instance, method_name)()
        figure.savefig(os.path.join(out_dir, file_name))
        # Earlier figures: close the current one; after the last, close everything.
        if job_idx == final_job:
            plt.close('all')
        else:
            plt.close()
# TODO define coordinate grid beforehand, e.g. kwargs_data
def lens_model_plot_custom(image, ax, lensModel, kwargs_lens, numPix=500, deltaPix=0.01, sourcePos_x=0, sourcePos_y=0, point_source=False, with_caustics=False):
    """
    Plot an image with optional critical curves, caustics and lensed point-source positions.

    :param image: 2D array shown with ``ax.matshow`` over a square extent of
        ``numPix * deltaPix`` (assumed to be numPix x numPix pixels — TODO confirm)
    :param ax: matplotlib axes to draw on
    :param lensModel: lenstronomy LensModel instance
    :param kwargs_lens: lens model keyword arguments passed to `lensModel`
    :param numPix: number of pixels per side of the coordinate grid
    :param deltaPix: pixel scale (coordinate units per pixel)
    :param sourcePos_x: source x position passed to the lens equation solver
    :param sourcePos_y: source y position passed to the lens equation solver
    :param point_source: if True, solve for the lensed image positions, mark them
        with diamonds labeled A, B, ... (marker size grows with log|magnification|),
        and mark the source position with a red star
    :param with_caustics: if truthy, overlay caustics (red) and critical curves (yellow)
    :return: `ax`
    """
    kwargs_data = sim_util.data_configure_simple(numPix, deltaPix)
    coords = ImageData(**kwargs_data)
    frame_size = numPix * deltaPix
    ax.matshow(image, origin='lower', extent=[0, frame_size, 0, frame_size])

    if with_caustics:
        lens_model_ext = LensModelExtensions(lensModel)
        ra_crit_list, dec_crit_list = lens_model_ext.critical_curve_tiling(
            kwargs_lens, compute_window=frame_size, start_scale=deltaPix, max_order=20)
        # Caustics are the critical curves mapped to the source plane
        ra_caustic_list, dec_caustic_list = lensModel.ray_shooting(ra_crit_list, dec_crit_list, kwargs_lens)
        plot_util.plot_line_set(ax, coords, ra_caustic_list, dec_caustic_list, color='tab:red')
        plot_util.plot_line_set(ax, coords, ra_crit_list, dec_crit_list, color='yellow')

    if point_source:
        solver = LensEquationSolver(lensModel)
        theta_x, theta_y = solver.image_position_from_source(
            sourcePos_x, sourcePos_y, kwargs_lens,
            min_distance=deltaPix, search_window=deltaPix * numPix)
        mag_images = lensModel.magnification(theta_x, theta_y, kwargs_lens)
        x_image, y_image = coords.map_coord2pix(theta_x, theta_y)
        for i, (x_pix, y_pix, mag) in enumerate(zip(x_image, y_image, mag_images)):
            # Pixel index -> extent coordinates of the pixel center
            x_ = (x_pix + 0.5) * deltaPix
            y_ = (y_pix + 0.5) * deltaPix
            # 4*(1 + ln|mu|) is negative for |mu| < 1/e and matplotlib draws a
            # negative size at its magnitude, which would give strongly
            # demagnified images *large* markers; clamp to a small positive floor.
            marker_size = max(4 * (1 + np.log(np.abs(mag))), 1.0)
            ax.plot(x_, y_, 'dk', markersize=marker_size, markerfacecolor='none')
            ax.text(x_ + 0.1, y_ + 0.1, chr(ord('A') + i), fontsize=15, color='white')
        x_source, y_source = coords.map_coord2pix(sourcePos_x, sourcePos_y)
        ax.plot((x_source + 0.5) * deltaPix, (y_source + 0.5) * deltaPix,
                marker='*', color='tab:red', markersize=7.5)

    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.autoscale(False)
    return ax