"""Class managing the model inference
"""
import os
import os.path as osp
import shutil
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
import scipy.stats
import torch
import torchvision.transforms as transforms
from torch.utils.data.sampler import WeightedRandomSampler
from torch_geometric.data import DataLoader
from n2j.trainval_data.graphs.cosmodc2_graph import CosmoDC2Graph
import n2j.models as models
import n2j.inference.infer_utils as iutils
import matplotlib.pyplot as plt
import corner
from n2j.trainval_data.utils.transform_utils import (ComposeXYLocal,
Metadata,
Standardizer,
Slicer,
MagErrorSimulatorTorch,
Rejector,
get_bands_in_x,
get_idx)
import n2j.inference.summary_stats_baseline as ssb
import n2j.inference.calibration as calib
class InferenceManager:
def __init__(self, device_type, checkpoint_dir, out_dir, seed=123):
"""Inference tool
Parameters
----------
device_type : str
checkpoint_dir : os.path or str
training checkpoint_dir (same as one used to instantiate `Trainer`)
out_dir : os.path or str
output directory for inference results
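        Examples
        --------
        A minimal sketch of instantiation (paths are illustrative):

        >>> mgr = InferenceManager('cuda', checkpoint_dir='checkpoints',
        ...                        out_dir='inference_results')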
"""
self.device_type = device_type
self.device = torch.device(self.device_type)
self.seed = seed
self.seed_everything()
self.checkpoint_dir = checkpoint_dir
os.makedirs(self.checkpoint_dir, exist_ok=True)
self.out_dir = out_dir
os.makedirs(self.out_dir, exist_ok=True)
self._include_los = slice(None) # do not exclude los from inference
    def seed_everything(self):
"""Seed the training and sampling for reproducibility
"""
np.random.seed(self.seed)
random.seed(self.seed)
torch.manual_seed(self.seed)
torch.cuda.manual_seed(self.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
    def load_dataset(self, data_kwargs, is_train, batch_size,
                     sub_features=None, sub_target=None, sub_target_local=None,
                     rebin=False, num_workers=2,
                     noise_kwargs=None, detection_kwargs=None):
"""Load dataset and dataloader for training or validation
Note
----
Should be called for training data first, to set the normalizing
stats used for both training and validation!
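        Examples
        --------
        Hypothetical call order (kwargs are placeholders); the training
        split must be loaded first:

        >>> mgr.load_dataset(data_kwargs, is_train=True, batch_size=64)
        >>> mgr.load_dataset(data_kwargs, is_train=False, batch_size=100)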
"""
        # Guard against mutable default arguments
        if noise_kwargs is None:
            noise_kwargs = {'mag': {'override_kwargs': None, 'depth': 5}}
        if detection_kwargs is None:
            detection_kwargs = {}
        self.num_workers = num_workers
if is_train:
self.batch_size = batch_size
else:
self.val_batch_size = batch_size
# X metadata
features = data_kwargs['features']
self.sub_features = sub_features if sub_features else features
self.X_dim = len(self.sub_features)
# Global y metadata
target = ['final_kappa', 'final_gamma1', 'final_gamma2']
self.sub_target = sub_target if sub_target else target
self.Y_dim = len(self.sub_target)
        # Local y metadata
target_local = ['halo_mass', 'stellar_mass', 'redshift']
self.sub_target_local = sub_target_local if sub_target_local else target_local
self.Y_local_dim = len(self.sub_target_local)
print("Instantiating CosmoDC2Graph...")
print(data_kwargs)
dataset = CosmoDC2Graph(num_workers=self.num_workers, **data_kwargs)
############
# Training #
############
if is_train:
self.train_dataset = dataset
print("Reading train stats...")
if osp.exists(osp.join(self.checkpoint_dir, 'stats.pt')):
stats = torch.load(osp.join(self.checkpoint_dir, 'stats.pt'))
else:
stats = self.train_dataset.data_stats
torch.save(stats, osp.join(self.checkpoint_dir, 'stats.pt'))
print("Configuring transforms...")
# Transforming X
idx = get_idx(features, self.sub_features)
self.X_mean = stats['X_mean'][:, idx]
self.X_std = stats['X_std'][:, idx]
slicing = Slicer(idx)
mag_idx, which_bands = get_bands_in_x(self.sub_features)
print(f"Mag errors added to {which_bands}")
magerr = MagErrorSimulatorTorch(mag_idx=mag_idx,
which_bands=which_bands,
**noise_kwargs['mag'])
magcut = Rejector(self.sub_features, **detection_kwargs)
norming = Standardizer(self.X_mean, self.X_std)
editing_X_meta = Metadata(self.sub_features, ['ra_true', 'dec_true'])
norming_X_meta = Standardizer(stats['X_meta_mean'],
stats['X_meta_std'])
# Transforming local Y
idx_Y_local = get_idx(target_local, self.sub_target_local)
self.Y_local_mean = stats['Y_local_mean'][:, idx_Y_local]
self.Y_local_std = stats['Y_local_std'][:, idx_Y_local]
slicing_Y_local = Slicer(idx_Y_local)
norming_Y_local = Standardizer(self.Y_local_mean,
self.Y_local_std)
# TODO: normalization is based on pre-magcut population
self.transform_X_Y_local = ComposeXYLocal([slicing, magerr],
[slicing_Y_local],
[magcut],
[norming],
[norming_Y_local],
[editing_X_meta, norming_X_meta])
# Transforming global Y
idx_Y = get_idx(target, self.sub_target)
self.Y_mean = stats['Y_mean'][:, idx_Y]
self.Y_std = stats['Y_std'][:, idx_Y]
slicing_Y = Slicer(idx_Y)
norming_Y = Standardizer(self.Y_mean, self.Y_std)
self.transform_Y = transforms.Compose([slicing_Y, norming_Y])
self.train_dataset.transform_X_Y_local = self.transform_X_Y_local
self.train_dataset.transform_Y = self.transform_Y
# Loading option 1: Subsample from a distribution
if data_kwargs['subsample_pdf_func'] is not None:
print("Configuring subsampling...")
self.class_weight = None
train_subset = torch.utils.data.Subset(self.train_dataset,
stats['subsample_idx'])
self.train_dataset = train_subset
self.train_loader = DataLoader(self.train_dataset,
batch_size=batch_size,
shuffle=False, # no need here
num_workers=self.num_workers,
drop_last=True)
else:
# Loading option 2: Over/undersample according to inverse frequency
if rebin:
self.class_weight = stats['class_weight']
sampler = WeightedRandomSampler(stats['y_weight'],
num_samples=len(self.train_dataset))
self.train_loader = DataLoader(self.train_dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=self.num_workers,
drop_last=True)
# Loading option 3: No special sampling, just shuffle
else:
self.class_weight = None
self.train_loader = DataLoader(self.train_dataset,
batch_size=batch_size,
                                               shuffle=True,
num_workers=self.num_workers,
drop_last=True)
print(f"Train dataset size: {len(self.train_dataset)}")
###################
# Validation/Test #
###################
else:
self.test_dataset = dataset
# Compute or retrieve stats necessary for resampling
# before setting any kind of transforms
# Note: stats_test.pt is in inference out_dir, not checkpoint_dir
if data_kwargs['subsample_pdf_func'] is not None:
stats_test_path = osp.join(self.out_dir, 'stats_test.pt')
if osp.exists(stats_test_path):
stats_test = torch.load(stats_test_path)
else:
stats_test = self.test_dataset.data_stats_valtest
torch.save(stats_test, stats_test_path)
self.test_dataset.transform_X_Y_local = self.transform_X_Y_local
self.test_dataset.transform_Y = self.transform_Y
self.set_valtest_loading(stats_test['subsample_idx'])
print(f"Test dataset size: {len(self.test_dataset)}")
    def set_valtest_loading(self, sub_idx):
        """Set the loading options for the val/test set. Should be called
        whenever the test dataset changes, to update the dataloader.
        Parameters
        ----------
        sub_idx : array-like of int
            Indices of the test dataset to load
"""
self.class_weight = None
test_subset = torch.utils.data.Subset(self.test_dataset,
sub_idx)
self.test_dataset = test_subset
self.test_loader = DataLoader(self.test_dataset,
batch_size=self.val_batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=False)
    def load_state(self, state_path):
"""Load the state dict of the past training
Parameters
----------
state_path : str or os.path object
path of the state dict to load
"""
state = torch.load(state_path,
map_location=torch.device(self.device_type))
self.model.load_state_dict(state['model'])
self.model.to(self.device)
self.epoch = state['epoch']
train_loss = state['train_loss']
val_loss = state['val_loss']
print("Loaded weights at {:s}".format(state_path))
print("Epoch [{}]: TRAIN Loss: {:.4f}".format(self.epoch, train_loss))
print("Epoch [{}]: VALID Loss: {:.4f}".format(self.epoch, val_loss))
self.last_saved_val_loss = val_loss
@property
    def include_los(self):
"""Indices to include in inference. Useful when there are faulty
examples in the test set you want to exclude.
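        Examples
        --------
        Hypothetical use, dropping sightline 2 from a 4-sightline test set:

        >>> mgr.include_los = [0, 1, 3]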
"""
return self._include_los
@include_los.setter
def include_los(self, value):
if value is None:
# Do nothing
return
value = list(value)
self._include_los = value
self.set_valtest_loading(value)
        max_guess = max(value)
        excluded = np.setdiff1d(np.arange(max_guess + 1), value)
        print(f"Assuming there were {max_guess+1} sightlines in the test set,"
              f" now excluding indices: {excluded}")
    @property
    def n_test(self):
        """Number of sightlines in the test set"""
        return len(self.test_dataset)
    @property
    def bnn_kappa_path(self):
        """Path to the stored BNN kappa samples"""
        return osp.join(self.out_dir, 'k_bnn.npy')
    def get_bnn_kappa(self, n_samples=50, n_mc_dropout=20, flatten=True):
        """Get the samples from the BNN
        Parameters
        ----------
        n_samples : int
            number of samples per MC iterate
        n_mc_dropout : int
            number of MC iterates
        flatten : bool, optional
            whether to flatten the MC and sampling dimensions into one
        Returns
        -------
        np.ndarray
            BNN samples of shape `[n_test, self.Y_dim, n_samples*n_mc_dropout]`
            if `flatten`, else `[n_test, self.Y_dim, n_mc_dropout, n_samples]`
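        Examples
        --------
        Illustrative shapes for a single-target (kappa) model:

        >>> samples = mgr.get_bnn_kappa(n_samples=50, n_mc_dropout=20)
        >>> samples.shape  # (n_test, 1, 1000) with flatten=True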
"""
if osp.exists(self.bnn_kappa_path):
samples = np.load(self.bnn_kappa_path)
if flatten:
samples = samples.reshape([self.n_test, self.Y_dim, -1])
return samples
# Fetch precomputed Y_mean, Y_std to de-standardize samples
Y_mean = self.Y_mean.to(self.device)
Y_std = self.Y_std.to(self.device)
self.model.eval()
with torch.no_grad():
            samples = np.empty([self.n_test, n_mc_dropout, n_samples, self.Y_dim])
            offset = 0  # running start index; robust to a smaller last batch
            for batch in self.test_loader:
                batch = batch.to(self.device)
                B = batch.y.shape[0]  # this batch size
                for mc_iter in range(n_mc_dropout):
                    x, u = self.model(batch)
                    # Get pred samples for this MC iterate
                    self.model.global_nll.set_trained_pred(u)
                    mc_samples = self.model.global_nll.sample(Y_mean,
                                                              Y_std,
                                                              n_samples)
                    samples[offset: offset + B, mc_iter, :, :] = mc_samples
                offset += B
# Transpose dims to get [n_test, Y_dim, n_mc_dropout, n_samples]
samples = samples.transpose(0, 3, 1, 2)
np.save(self.bnn_kappa_path, samples)
if flatten:
samples = samples.reshape([self.n_test, self.Y_dim, -1])
return samples
    @property
    def true_train_kappa_path(self):
        return osp.join(self.out_dir, 'k_train.npy')
    @property
    def train_summary_stats_path(self):
        return osp.join(self.out_dir, 'summary_stats_train.npy')
    @property
    def true_test_kappa_path(self):
        return osp.join(self.out_dir, 'k_test.npy')
    @property
    def test_summary_stats_path(self):
        return osp.join(self.out_dir, 'summary_stats_test.npy')
    @property
    def matching_dir(self):
        return osp.join(self.out_dir, 'matching')
    @property
    def log_p_k_given_omega_int_path(self):
        return osp.join(self.out_dir, 'log_p_k_given_omega_int.npy')
    @property
    def reweighted_grid_dir(self):
        return osp.join(self.out_dir, 'reweighted_grid')
    @property
    def reweighted_per_sample_dir(self):
        return osp.join(self.out_dir, 'reweighted_per_sample')
    @property
    def reweighted_bnn_kappa_grid_path(self):
        return osp.join(self.reweighted_grid_dir,
                        'k_bnn_reweighted_grid.npy')
    @property
    def reweighted_bnn_kappa_per_sample_path(self):
        return osp.join(self.reweighted_per_sample_dir,
                        'k_bnn_reweighted_per_sample.npy')
    def delete_previous(self):
"""Delete previously stored files related to the test set and
inference results, while leaving any training-set related caches,
which take longer to generate.
"""
files = [self.true_test_kappa_path, self.test_summary_stats_path]
files += [self.bnn_kappa_path, self.log_p_k_given_omega_int_path]
files += [self.reweighted_bnn_kappa_grid_path]
files += [self.reweighted_bnn_kappa_per_sample_path]
for f in files:
if osp.exists(f):
print(f"Deleting {f}...")
os.remove(f)
dirs = [self.matching_dir]
dirs += [self.reweighted_grid_dir, self.reweighted_per_sample_dir]
for d in dirs:
if osp.exists(d):
print(f"Deleting {d} and all its contents...")
shutil.rmtree(d)
    def get_true_kappa(self, is_train,
compute_summary=True, save=True):
"""Fetch true kappa (for train/val/test)
Parameters
----------
is_train : bool
Whether to get true kappas for train (test otherwise)
compute_summary : bool, optional
Whether to compute summary stats in the loop
save : bool, optional
Whether to store the kappa to disk
Returns
-------
np.ndarray
true kappas of shape `[n_data, Y_dim]`
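        Examples
        --------
        Hypothetical calls, once the loaders have been set:

        >>> k_train = mgr.get_true_kappa(is_train=True)
        >>> k_test = mgr.get_true_kappa(is_train=False)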
"""
# Decide which dataset we're collecting kappa labels for
if is_train:
path = self.true_train_kappa_path
ss_path = self.train_summary_stats_path
else:
path = self.true_test_kappa_path
ss_path = self.test_summary_stats_path
        if osp.exists(path):
            # Return the cache unless the summary stats still need computing
            if not compute_summary or osp.exists(ss_path):
                true_kappa = np.load(path)
                return true_kappa
if is_train:
n_data = len(self.train_dataset)
loader = self.train_loader
else:
n_data = self.n_test
loader = self.test_loader
print(f"Saving {path}...")
# Fetch precomputed Y_mean, Y_std to de-standardize samples
Y_mean = self.Y_mean.to(self.device)
Y_std = self.Y_std.to(self.device)
if compute_summary:
pos_indices = get_idx(self.sub_features,
['ra_true', 'dec_true'])
ss_obj = ssb.SummaryStats(n_data, pos_indices)
# Init empty array
true_kappa = np.empty([n_data, self.Y_dim])
        with torch.no_grad():
            # Populate `true_kappa` by batches
            offset = 0  # running start index; robust to a smaller last batch
            for i, batch in enumerate(loader):
                # Update summary stats using CPU batch
                if compute_summary:
                    ss_obj.update(batch, i)
                batch = batch.to(self.device)
                B = batch.y.shape[0]  # this batch size
                true_kappa[offset: offset + B, :] = (batch.y*Y_std + Y_mean).cpu().numpy()
                offset += B
if save:
np.save(path, true_kappa)
if compute_summary:
ss_obj.export_stats(ss_path)
return true_kappa
    def get_summary_stats(self, thresholds, interim_pdf_func=None,
                          match=True, min_matches=1000,
                          k_max=np.inf):
        """Save accepted samples from summary statistics matching
        Parameters
        ----------
        thresholds : dict
            Matching thresholds for summary stats.
            Keys should be one or both of 'N' and 'N_inv_dist'.
        interim_pdf_func : callable, optional
            Function that evaluates the PDF of the interim prior
        match : bool, optional
            Whether to run the matching before building the overview table
        min_matches : int, optional
            Minimum number of matched samples per sightline
        k_max : float, optional
            Maximum kappa value of accepted samples
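        Examples
        --------
        Illustrative thresholds (values assumed, not package defaults):

        >>> overview = mgr.get_summary_stats({'N': 1, 'N_inv_dist': 1.0})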
"""
train_k = self.get_true_kappa(is_train=True,
compute_summary=True)
test_k = self.get_true_kappa(is_train=False,
compute_summary=True)
pos_indices = get_idx(self.sub_features,
['ra_true', 'dec_true'])
train_ss_obj = ssb.SummaryStats(len(self.train_dataset),
pos_indices)
train_ss_obj.set_stats(self.train_summary_stats_path)
test_ss_obj = ssb.SummaryStats(len(self.test_dataset),
pos_indices)
test_ss_obj.set_stats(self.test_summary_stats_path)
self.matcher = ssb.Matcher(train_ss_obj, test_ss_obj,
train_k,
self.matching_dir,
test_k)
if match:
self.matcher.match_summary_stats(thresholds, interim_pdf_func,
min_matches=min_matches,
k_max=k_max)
overview = self.matcher.get_overview_table()
return overview
    def get_log_p_k_given_omega_int(self, n_samples, n_mc_dropout,
interim_pdf_func):
"""Compute log(p_k|Omega_int) for BNN samples p_k
Parameters
----------
n_samples : int
Number of BNN samples per MC iterate per sightline
n_mc_dropout : int
Number of MC dropout iterates per sightline
interim_pdf_func : callable
Function that evaluates the PDF of the interim prior
Returns
-------
np.ndarray
            Log probabilities log p(k|Omega_int), of
            shape `[n_test, n_mc_dropout*n_samples]`
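        Examples
        --------
        A hypothetical Gaussian interim prior:

        >>> import scipy.stats
        >>> interim_pdf_func = scipy.stats.norm(loc=0.0, scale=0.04).pdf
        >>> log_p = mgr.get_log_p_k_given_omega_int(50, 20, interim_pdf_func)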
"""
if osp.exists(self.log_p_k_given_omega_int_path):
return np.load(self.log_p_k_given_omega_int_path)
# k_train = self.get_true_kappa(is_train=True).squeeze(1)
k_bnn = self.get_bnn_kappa(n_samples=n_samples,
n_mc_dropout=n_mc_dropout).squeeze(1)
log_p_k_given_omega_int = iutils.get_log_p_k_given_omega_int_analytic(k_train=None, #k_train,
k_bnn=k_bnn,
interim_pdf_func=interim_pdf_func)
np.save(self.log_p_k_given_omega_int_path, log_p_k_given_omega_int)
return log_p_k_given_omega_int
    def get_log_p_k_given_omega_int_loop(self, interim_pdf_func, bnn=False,
ss_name='N', k_max=np.inf):
"""Compute log(p_k|Omega_int) for BNN or summary stats samples p_k.
Useful when the number of samples differs across sightlines, so
the computation is not trivially vectorizable.
Parameters
----------
interim_pdf_func : callable
Function that evaluates the PDF of the interim prior
bnn : bool, optional
Whether the samples are BNN's.
If False, understood to be summary stats matched samples.
        ss_name : str, optional
            Summary stats name. Only used if `bnn` is False.
            Default: 'N'
        k_max : float, optional
            Maximum kappa value of accepted samples
"""
if bnn:
raise NotImplementedError("Use the vectorized version for BNN!")
path = osp.join(self.matching_dir,
f'log_p_k_given_omega_int_{ss_name}_list.npy')
if osp.exists(path):
return np.load(path, allow_pickle=True)
log_p_k_given_omega_int_list = []
for i in range(self.n_test):
samples_i = self.matcher.get_samples(idx=i, ss_name=ss_name,
threshold=None)
samples_i = samples_i[samples_i < k_max]
samples_i = samples_i.reshape([1, -1]) # artificial n_test of 1
# TODO: use get_log_p_k_given_omega_int_per_los
log_p_i = iutils.get_log_p_k_given_omega_int_analytic(k_train=None,
k_bnn=samples_i,
interim_pdf_func=interim_pdf_func)
# log_p_i ~ [1, len(samples_i)] so squeeze
log_p_k_given_omega_int_list.append(log_p_i.squeeze())
np.save(path, log_p_k_given_omega_int_list, allow_pickle=True)
return log_p_k_given_omega_int_list
    def get_log_p_k_given_omega_int_per_los(self, i, samples_i, interim_pdf_func,
ss_name='N'):
"""Compute log(p_k|Omega_int) for BNN or summary stats samples p_k.
Useful when the number of samples differs across sightlines, so
the computation is not trivially vectorizable.
Parameters
----------
i : int
ID of sightline
samples_i : np.ndarray
Matched posterior samples for this sightline
interim_pdf_func : callable
Function that evaluates the PDF of the interim prior
        ss_name : str, optional
            Summary stats name. Default: 'N'
"""
samples_i = samples_i.reshape([1, -1]) # artificial n_test of 1
log_p_i = iutils.get_log_p_k_given_omega_int_analytic(k_train=None,
k_bnn=samples_i,
interim_pdf_func=interim_pdf_func)
return log_p_i.squeeze() # log_p_i ~ [1, len(samples_i)] so squeeze
    def run_mcmc_for_omega_post(self, n_samples, n_mc_dropout,
mcmc_kwargs, interim_pdf_func,
bounds_lower=-np.inf, bounds_upper=np.inf):
"""Run EMCEE to obtain the posterior on test hyperparams, omega
Parameters
----------
n_samples : int
Number of BNN samples per MC iterate per sightline
n_mc_dropout : int
Number of MC dropout iterates
mcmc_kwargs : dict
            Config going into `infer_utils.run_mcmc`
        interim_pdf_func : callable
            Function that evaluates the PDF of the interim prior
bounds_lower : np.ndarray or float, optional
Lower bound for target quantities
bounds_upper : np.ndarray or float, optional
Upper bound for target quantities
"""
k_bnn = self.get_bnn_kappa(n_samples=n_samples,
n_mc_dropout=n_mc_dropout)
log_p_k_given_omega_int = self.get_log_p_k_given_omega_int(n_samples,
n_mc_dropout,
interim_pdf_func)
iutils.get_omega_post(k_bnn, log_p_k_given_omega_int, mcmc_kwargs,
bounds_lower, bounds_upper)
    def run_mcmc_for_omega_post_summary_stats(self, ss_name,
mcmc_kwargs,
interim_pdf_func,
bounds_lower=-np.inf,
bounds_upper=np.inf):
"""Run EMCEE to obtain the posterior on test hyperparams, omega
using the matched summary statistics samples, rather than BNN
posterior samples
Parameters
----------
ss_name : str
What kind of summary stats to query (one of 'N', 'N_inv_dist')
mcmc_kwargs : dict
            Config going into `infer_utils.run_mcmc`
        interim_pdf_func : callable
            Function that evaluates the PDF of the interim prior
bounds_lower : np.ndarray or float, optional
Lower bound for target quantities
bounds_upper : np.ndarray or float, optional
Upper bound for target quantities
"""
log_p_k_given_omega_int_list = []
samples = []
for i in range(self.n_test):
samples_orig_i = self.matcher.get_samples(idx=i, ss_name=ss_name,
threshold=None)
# Fit Gaussian on matched posterior samples
norm_i = scipy.stats.norm(loc=np.median(samples_orig_i),
scale=scipy.stats.median_abs_deviation(samples_orig_i, scale='normal'))
rng = np.random.RandomState(i)
samples_i = norm_i.rvs(20000, random_state=rng)
            log_p_i = self.get_log_p_k_given_omega_int_per_los(i, samples_i,
                                                               interim_pdf_func,
                                                               ss_name=ss_name)
log_p_k_given_omega_int_list.append(log_p_i)
samples.append(samples_i)
iutils.get_omega_post_loop(samples, log_p_k_given_omega_int_list, mcmc_kwargs,
bounds_lower, bounds_upper)
    def get_kappa_log_weights(self, idx, n_samples=None, n_mc_dropout=None,
interim_pdf_func=None, grid=None):
"""Get log weights for reweighted kappa posterior per sample
Parameters
----------
idx : int
Index of sightline in test set
n_samples : int
Number of samples per dropout, for getting kappa samples.
(May be overridden with what was used previously, if
kappa samples were already drawn and stored)
n_mc_dropout : int
Number of dropout iterates, for getting kappa samples.
(May be overridden with what was used previously, if
kappa samples were already drawn and stored)
interim_pdf_func : callable
Function that returns the density of the interim prior
grid : None, optional
            Unused but kept for consistency with `get_kappa_log_weights_grid`
Returns
-------
np.ndarray
log weights for each of the BNN samples for this sightline
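        Examples
        --------
        Hypothetical reweighted histogram for sightline 0, mirroring
        `visualize_kappa_post` (`mgr`, `k_bnn`, `interim_pdf_func` assumed
        defined as elsewhere in this module):

        >>> log_w = mgr.get_kappa_log_weights(0, 50, 20, interim_pdf_func)
        >>> plt.hist(k_bnn[0].squeeze(), weights=np.exp(log_w), density=True)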
"""
os.makedirs(self.reweighted_per_sample_dir, exist_ok=True)
path = osp.join(self.reweighted_per_sample_dir,
f'log_weights_{idx}.npy')
k_bnn = self.get_bnn_kappa(n_samples=n_samples,
n_mc_dropout=n_mc_dropout)
log_p_k_given_omega_int = self.get_log_p_k_given_omega_int(n_samples,
n_mc_dropout,
interim_pdf_func)
# omega_post_samples = iutils.get_mcmc_samples(chain_path, chain_kwargs)
log_weights = iutils.get_kappa_log_weights(k_bnn[idx, :],
log_p_k_given_omega_int[idx, :])
np.save(path, log_weights)
return log_weights
    def get_kappa_log_weights_grid(self, idx,
grid=None,
n_samples=None,
n_mc_dropout=None,
interim_pdf_func=None):
"""Get log weights for reweighted kappa posterior, analytically
on a grid
Parameters
----------
idx : int
Index of sightline in test set
grid : np.ndarray, optional
Grid of kappa values at which to evaluate log weights
(May be overridden with what was used previously, if
kappa samples were already drawn and stored)
n_samples : int, optional
Number of samples per dropout, for getting kappa samples.
(May be overridden with what was used previously, if
kappa samples were already drawn and stored)
n_mc_dropout : int, optional
Number of dropout iterates, for getting kappa samples.
(May be overridden with what was used previously, if
kappa samples were already drawn and stored)
interim_pdf_func : callable, optional
Function that returns the density of the interim prior
        Note
        ----
        The log does not add numerical stability here, since the
        probabilities are divided directly; it is kept only for
        consistency with `get_kappa_log_weights`.
Returns
-------
        np.ndarray
            stacked array of shape `[2, len(grid)]`: the kappa grid and the
            log weights evaluated on it
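        Examples
        --------
        Illustrative grid (bounds assumed):

        >>> grid = np.linspace(-0.2, 0.2, 1000)
        >>> grid, log_w = mgr.get_kappa_log_weights_grid(
        ...     0, grid, 50, 20, interim_pdf_func)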
"""
os.makedirs(self.reweighted_grid_dir, exist_ok=True)
path = osp.join(self.reweighted_grid_dir,
f'log_weights_{idx}.npy')
if osp.exists(path):
return np.load(path)
# Get unflattened, i.e. [n_test, 1, n_mc_dropout, n_samples]
k_bnn = self.get_bnn_kappa(n_samples=n_samples,
n_mc_dropout=n_mc_dropout,
flatten=False)
k_bnn = k_bnn[idx, 0, :, :] # [n_mc_dropout, n_samples]
n_mc_dropout, n_samples = k_bnn.shape
numer = np.zeros(grid.shape) # init numerator
# Fit a normal for each MC dropout
for d in range(n_mc_dropout):
samples_d = k_bnn[d, :]
norm_d = scipy.stats.norm(loc=samples_d.mean(),
scale=samples_d.std())
bnn_prob_d = norm_d.pdf(grid)
numer += (bnn_prob_d - numer)/(d+1) # running mean
# Useful for debugging
np.save(osp.join(self.reweighted_grid_dir,
f'grid_bnn_gmm_{idx}.npy'),
numer)
denom = interim_pdf_func(grid)
log_weights = np.log(numer/denom)
log_weights_grid = np.stack([grid, log_weights], axis=0)
np.save(path, log_weights_grid)
return log_weights_grid
    def get_reweighted_bnn_kappa(self, n_resamples, grid_kappa_kwargs,
                                 k_max=None):
"""Get the reweighted BNN kappa samples, reweighted either on a
grid or per sample
Parameters
----------
n_resamples : int
Number of resamples from the reweighted distribution
        grid_kappa_kwargs : dict
            Kwargs passed to `get_kappa_log_weights_grid` and
            `get_kappa_log_weights`
        k_max : float, optional
            Maximum kappa value of BNN samples kept in the per-sample
            reweighting
Returns
-------
tuple
Two arrays of shape [n_test, 1, n_resamples], first of which
is resamples using the grid reweighting and second of which
is resamples using the per-sample reweighting
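        Examples
        --------
        Illustrative kwargs (values assumed):

        >>> kwargs = dict(grid=np.linspace(-0.2, 0.2, 1000),
        ...               n_samples=50, n_mc_dropout=20,
        ...               interim_pdf_func=interim_pdf_func)
        >>> k_grid, k_per_sample = mgr.get_reweighted_bnn_kappa(1000, kwargs)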
"""
if osp.exists(self.reweighted_bnn_kappa_grid_path):
if osp.exists(self.reweighted_bnn_kappa_per_sample_path):
print("Reading existing reweighted BNN kappa...")
grid = np.load(self.reweighted_bnn_kappa_grid_path)
per_sample = np.load(self.reweighted_bnn_kappa_per_sample_path)
return grid, per_sample
n_test = len(self.test_dataset)
k_bnn = self.get_bnn_kappa(n_samples=grid_kappa_kwargs['n_samples'],
n_mc_dropout=grid_kappa_kwargs['n_mc_dropout'])
# Init reweighted arrays
k_reweighted_grid = np.empty([n_test, 1, n_resamples])
k_reweighted_per_sample = np.empty([n_test, 1, n_resamples])
for idx in tqdm(range(n_test), desc='evaluating, resampling'):
# On a grid
grid, log_p = self.get_kappa_log_weights_grid(idx,
**grid_kappa_kwargs)
per_grid = iutils.resample_from_pdf(grid, log_p, n_resamples)
k_reweighted_grid[idx, 0, :] = per_grid
# Per sample
log_p_sample = self.get_kappa_log_weights(idx, **grid_kappa_kwargs)
k_bnn_i = k_bnn[idx].squeeze() # [n_total_samples,]
probs_i = np.exp(log_p_sample) # [n_total_samples,]
if k_max is not None:
mask = k_bnn_i < k_max
k_bnn_i = k_bnn_i[mask]
probs_i = probs_i[mask]
plot_path = osp.join(self.reweighted_per_sample_dir, f'kde_{idx}.png')
per_sample = iutils.resample_from_samples(k_bnn_i,
probs_i,
n_resamples,
plot_path)
k_reweighted_per_sample[idx, 0, :] = per_sample
# Grid resamples for all sightlines
np.save(self.reweighted_bnn_kappa_grid_path,
k_reweighted_grid)
# Per-sample resamples for all sightlines
np.save(self.reweighted_bnn_kappa_per_sample_path,
k_reweighted_per_sample)
return k_reweighted_grid, k_reweighted_per_sample
    def get_omega_samples(self, chain_path, chain_kwargs, log_idx=None):
        """Load the MCMC chain and optionally exponentiate columns sampled
        in log space
        Parameters
        ----------
        chain_path : str
            Path to the stored MCMC chain
        chain_kwargs : dict
            Config going into `infer_utils.get_mcmc_samples`
        log_idx : int or array-like, optional
            Column index of parameters sampled in log space, to be
            exponentiated back
        """
        omega_post_samples = iutils.get_mcmc_samples(chain_path, chain_kwargs)
if log_idx is not None:
omega_post_samples[:, log_idx] = np.exp(omega_post_samples[:, log_idx])
return omega_post_samples
    def visualize_omega_post(self, chain_path, chain_kwargs,
                             corner_kwargs, log_idx=None):
        """Plot a corner plot of the omega posterior samples"""
# MCMC samples ~ [n_omega, 2]
omega_post_samples = self.get_omega_samples(chain_path, chain_kwargs,
log_idx=log_idx)
print(f"Plotting {omega_post_samples.shape[0]} samples...")
fig = corner.corner(omega_post_samples,
**corner_kwargs)
fig.savefig(osp.join(self.out_dir, 'omega_post.pdf'))
    def visualize_kappa_post(self, idx, n_samples, n_mc_dropout,
                             interim_pdf_func, grid=None):
        """Plot the original and reweighted BNN posteriors for one sightline"""
log_weights = self.get_kappa_log_weights(idx,
n_samples,
n_mc_dropout,
interim_pdf_func) # [n_samples]
grid, log_w_grid = self.get_kappa_log_weights_grid(idx,
grid,
n_samples,
n_mc_dropout,
interim_pdf_func)
w_grid = np.exp(log_w_grid)
        k_bnn = self.get_bnn_kappa(n_samples=n_samples,
                                   n_mc_dropout=n_mc_dropout)  # [n_test, Y_dim, n_samples*n_mc_dropout]
true_k = self.get_true_kappa(is_train=False)
fig, ax = plt.subplots()
# Original posterior
        bins = np.histogram_bin_edges(k_bnn[idx].squeeze(), bins='scott')
ax.hist(k_bnn[idx].squeeze(),
histtype='step',
bins=bins,
density=True,
color='#8ca252',
label='original')
# Reweighted posterior, per sample
ax.hist(k_bnn[idx].squeeze(),
histtype='step',
bins=25,
density=True,
weights=np.exp(log_weights),
color='#d6616b',
label='reweighted per sample')
        # Reweighted posterior, analytical (assumes the cached resamples
        # exist; `None` args would fail on a fresh computation)
        reweighted_k_bnn, _ = self.get_reweighted_bnn_kappa(None, None)
reweighted_k_bnn = reweighted_k_bnn[idx, 0, :]
bin_vals, bin_edges = np.histogram(reweighted_k_bnn, bins='scott',
density=True)
norm_factor = np.max(bin_vals)/np.max(w_grid)
ax.plot(grid, norm_factor*w_grid,
color='#d6616b',
label='reweighted on grid')
# Truth
ax.axvline(true_k[idx].squeeze(), color='k', label='truth')
ax.set_xlabel(r'$\kappa$')
ax.legend()
    @property
    def pre_reweighting_metrics_path(self):
        return osp.join(self.out_dir, 'pre_metrics.csv')
    @property
    def pre_reweighting_metrics(self):
        return pd.read_csv(self.pre_reweighting_metrics_path,
                           index_col=False)
    @property
    def post_reweighting_metrics_path(self):
        return osp.join(self.out_dir, 'post_metrics.csv')
    @property
    def post_reweighting_metrics(self):
        return pd.read_csv(self.post_reweighting_metrics_path,
                           index_col=False)
    @property
    def post_reweighting_metrics_grid_path(self):
        return osp.join(self.out_dir, 'post_metrics_grid.csv')
    @property
    def post_reweighting_metrics_grid(self):
        return pd.read_csv(self.post_reweighting_metrics_grid_path,
                           index_col=False)
    def compute_metrics(self):
        """Evaluate metrics for model selection, based on per-sample
        reweighting for fair comparison to summary stats metrics
        """
        # TODO: move to a separate helper module for pre, post, post-grid
columns = ['minus_1sig', 'med', 'plus_1sig']
columns += ['log_p', 'mad', 'mae']
# mae = median absolute errors, robust measure of accuracy
# mad = median absolute deviation, robust measure of precision
# Metrics on pre-reweighting BNN posteriors
k_bnn_pre = self.get_bnn_kappa()
        # Metrics on post-reweighting BNN posteriors
        k_bnn_post_grid, k_bnn_post = self.get_reweighted_bnn_kappa(None, None)
        # Collect per-sightline rows in lists and build the DataFrames at the
        # end, since `pd.DataFrame.append` was removed in pandas 2.0
        pre_rows, post_rows, post_grid_rows = [], [], []
# True kappa
k_test = self.get_true_kappa(is_train=False).squeeze()
n_test = len(k_test)
for i in range(n_test):
# Init rows to append
pre_stats = dict()
post_stats = dict()
post_stats_grid = dict()
# Slice samples for this sightline
pre_samples = k_bnn_pre[i, 0, :]
post_samples = k_bnn_post[i, 0, :]
post_samples_grid = k_bnn_post_grid[i, 0, :]
# Evaluate log p at truth, using KDE fit on samples
# with and without 1/prior weights
true_k = k_test[i]
log_w = self.get_kappa_log_weights(i) # per-sample log weights
grid, log_w_grid = self.get_kappa_log_weights_grid(i) # grid log weights
pre_kde = iutils.fit_kde_on_weighted_samples(pre_samples)
post_kde = iutils.fit_kde_on_weighted_samples(pre_samples,
np.exp(log_w))
post_kde_grid = iutils.fit_kde_on_weighted_samples(grid,
np.exp(log_w_grid))
pre_log_p = pre_kde.logpdf(true_k).item()
post_log_p = post_kde.logpdf(true_k).item()
post_log_p_grid = post_kde_grid.logpdf(true_k).item()
pre_stats.update(log_p=pre_log_p)
post_stats.update(log_p=post_log_p)
post_stats_grid.update(log_p=post_log_p_grid)
# Compute descriptive stats
lower, med, upper = np.quantile(pre_samples,
[0.5-0.34, 0.5, 0.5+0.34])
pre_stats.update(minus_1sig=med - lower,
med=med,
plus_1sig=upper - med,
mae=np.median(np.abs(pre_samples - true_k)),
mad=scipy.stats.median_abs_deviation(pre_samples, scale='normal'))
lower, med, upper = np.quantile(post_samples,
[0.5-0.34, 0.5, 0.5+0.34])
post_stats.update(minus_1sig=med - lower,
med=med,
plus_1sig=upper - med,
mae=np.median(np.abs(post_samples - true_k)),
mad=scipy.stats.median_abs_deviation(post_samples, scale='normal'))
lower, med, upper = np.quantile(post_samples_grid,
[0.5-0.34, 0.5, 0.5+0.34])
post_stats_grid.update(minus_1sig=med - lower,
med=med,
plus_1sig=upper - med,
mae=np.median(np.abs(post_samples_grid - true_k)),
mad=scipy.stats.median_abs_deviation(post_samples_grid, scale='normal'))
            pre_rows.append(pre_stats)
            post_rows.append(post_stats)
            post_grid_rows.append(post_stats_grid)
        pre_metrics = pd.DataFrame(pre_rows, columns=columns)
        post_metrics = pd.DataFrame(post_rows, columns=columns)
        post_metrics_grid = pd.DataFrame(post_grid_rows, columns=columns)
        # Append the mean and median over the entire test set, each computed
        # from the per-sightline rows only
        for df in (pre_metrics, post_metrics, post_metrics_grid):
            df.loc['mean'] = df.iloc[:n_test].mean()
            df.loc['median'] = df.iloc[:n_test].median()
# Save as CSV
pre_metrics.to_csv(self.pre_reweighting_metrics_path,
index=False)
post_metrics.to_csv(self.post_reweighting_metrics_path,
index=False)
post_metrics_grid.to_csv(self.post_reweighting_metrics_grid_path,
index=False)
    def get_calibration_plot(self, k_bnn):
"""Plot calibration (should be run on the validation set)
Parameters
----------
k_bnn : np.ndarray
Reweighted BNN samples, of shape [n_test, Y_dim, n_samples]
"""
k_bnn = np.transpose(k_bnn, [2, 0, 1]) # [n_samples, n_test, Y_dim=1]
y_mean = np.mean(k_bnn, axis=0)
k_val = self.get_true_kappa(is_train=False)
train_cov = self.Y_std.cpu().numpy()
fig = calib.plot_calibration(post_samples=k_bnn,
y_mean=y_mean,
y_truth=k_val,
cov=train_cov,
show_plot=False,
ls='--',
color_map=['tab:gray', '#880519'],
legend=['Perfect calibration',
'Dropout'])
fig.savefig(osp.join(self.out_dir, 'calibration.pdf'),
bbox_inches='tight', pad_inches=0, dpi=200)
# TODO: add docstring
# TODO: implement initialization from PSO
# TODO: implement method `visualize_kappa_post_all` comparing before vs after
# for all sightlines in test set
# TODO: implement method `visualize_learned_prior` stacking predictions
# for all sightlines in prior
# TODO: add markdown to notebook
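
# Illustrative end-to-end order of operations (a sketch with placeholder
# kwargs, not driver code shipped with the package; `mgr.model` is assumed
# to be attached by the training-side code before `load_state`):
#   mgr = InferenceManager('cuda', 'checkpoints', 'inference_results')
#   mgr.load_dataset(train_data_kwargs, is_train=True, batch_size=64)
#   mgr.load_dataset(test_data_kwargs, is_train=False, batch_size=100)
#   mgr.load_state('checkpoints/best_model.mdl')
#   k_bnn = mgr.get_bnn_kappa(n_samples=50, n_mc_dropout=20)
#   log_p = mgr.get_log_p_k_given_omega_int(50, 20, interim_pdf_func)
#   k_grid, k_per = mgr.get_reweighted_bnn_kappa(1000, grid_kappa_kwargs)
#   mgr.compute_metrics()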