inference_manager#

Class managing the model inference

Module Contents#

Classes#

InferenceManager

class inference_manager.InferenceManager(device_type, checkpoint_dir, out_dir, seed=123)[source]#
property include_los[source]#

Indices to include in inference. Useful when there are faulty examples in the test set you want to exclude.

property n_test[source]#
property bnn_kappa_path[source]#
property true_train_kappa_path[source]#
property train_summary_stats_path[source]#
property true_test_kappa_path[source]#
property test_summary_stats_path[source]#
property matching_dir[source]#
property log_p_k_given_omega_int_path[source]#
property reweighted_grid_dir[source]#
property reweighted_per_sample_dir[source]#
property reweighted_bnn_kappa_grid_path[source]#
property reweighted_bnn_kappa_per_sample_path[source]#
property pre_reweighting_metrics_path[source]#
property pre_reweighting_metrics[source]#
property post_reweighting_metrics_path[source]#
property post_reweighting_metrics[source]#
property post_reweighting_metrics_grid_path[source]#
property post_reweighting_metrics_grid[source]#
seed_everything()[source]#

Seed the training and sampling for reproducibility
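The actual implementation is not shown here, but seeding for reproducibility typically means re-seeding every global RNG in use. A minimal sketch (standard-library `random` and NumPy only; the real class would also seed torch, noted in a comment):

```python
import random

import numpy as np


def seed_everything(seed=123):
    """Seed the global RNGs so repeated runs draw identical samples."""
    random.seed(seed)
    np.random.seed(seed)
    # The real class would additionally call torch.manual_seed(seed)
    # (and torch.cuda.manual_seed_all(seed) on GPU) here.


seed_everything(123)
first = np.random.rand(3)
seed_everything(123)
second = np.random.rand(3)  # identical to `first`
```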

load_dataset(data_kwargs, is_train, batch_size, sub_features=None, sub_target=None, sub_target_local=None, rebin=False, num_workers=2, noise_kwargs={'mag': {'override_kwargs': None, 'depth': 5}}, detection_kwargs={})[source]#

Load dataset and dataloader for training or validation

Note#

Should be called for training data first, to set the normalizing stats used for both training and validation!

set_valtest_loading(sub_idx)[source]#

Set the loading options for val/test set. Should be called whenever there are changes to the test dataset, to update the dataloader.

Parameters#

subsample_pdf_func : callable

Description

sub_idx : TYPE

Description

configure_model(model_name, model_kwargs={})[source]#
load_state(state_path)[source]#

Load the state dict of the past training

Parameters#

state_path : str or os.path object

path of the state dict to load

get_bnn_kappa(n_samples=50, n_mc_dropout=20, flatten=True)[source]#

Get the samples from the BNN

Parameters#

n_samples : int

number of samples per MC iterate

n_mc_dropout : int

number of MC iterates

Returns#

np.array of shape [n_test, self.Y_dim, n_samples*n_mc_dropout]
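The output shape follows from running `n_mc_dropout` dropout iterates and drawing `n_samples` per iterate, then concatenating along the sample axis. A sketch with mock Gaussian draws standing in for the BNN (all sizes hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)

n_test, Y_dim = 4, 1              # hypothetical test-set size and target dim
n_samples, n_mc_dropout = 50, 20  # the documented defaults

# Stand-in for the BNN: each MC-dropout iterate yields `n_samples`
# posterior draws per sightline; here mocked as Gaussian noise.
draws = [rng.normal(size=(n_test, Y_dim, n_samples))
         for _ in range(n_mc_dropout)]

# Concatenate iterates along the sample axis, matching the documented
# output shape [n_test, Y_dim, n_samples * n_mc_dropout].
k_bnn = np.concatenate(draws, axis=-1)
print(k_bnn.shape)  # (4, 1, 1000)
```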

delete_previous()[source]#

Delete previously stored files related to the test set and inference results, while leaving any training-set related caches, which take longer to generate.

get_true_kappa(is_train, compute_summary=True, save=True)[source]#

Fetch true kappa (for train/val/test)

Parameters#

is_train : bool

Whether to get true kappas for train (test otherwise)

compute_summary : bool, optional

Whether to compute summary stats in the loop

save : bool, optional

Whether to store the kappa to disk

Returns#

np.ndarray

true kappas of shape [n_data, Y_dim]

get_summary_stats(thresholds, interim_pdf_func=None, match=True, min_matches=1000, k_max=np.inf)[source]#

Save accepted samples from summary statistics matching

Parameters#

thresholds : dict

Matching thresholds for summary stats. Keys should be one or both of 'N' and 'N_inv_dist'.
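The matching idea can be sketched as accepting training-set sightlines whose summary statistic lies within a threshold of the observed value, widening the threshold until `min_matches` are found. Everything below (the statistic, its distribution, the widening step) is a hypothetical mock, not the library's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical training-set summary stats ('N', a number count) and kappas,
# with a loose mock correlation between the two.
train_N = rng.poisson(30, size=10000)
train_kappa = 0.01 * (train_N - 30) + rng.normal(0, 0.02, size=10000)


def match(obs_N, threshold, min_matches=1000):
    """Accept training kappas whose summary stat lies within `threshold`
    of the observed value; widen the threshold until enough match."""
    while True:
        accepted = train_kappa[np.abs(train_N - obs_N) <= threshold]
        if accepted.size >= min_matches:
            return accepted
        threshold += 1


samples = match(obs_N=35, threshold=1)
```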

get_log_p_k_given_omega_int(n_samples, n_mc_dropout, interim_pdf_func)[source]#

Compute log(p_k|Omega_int) for BNN samples p_k

Parameters#

n_samples : int

Number of BNN samples per MC iterate per sightline

n_mc_dropout : int

Number of MC dropout iterates per sightline

interim_pdf_func : callable

Function that evaluates the PDF of the interim prior

Returns#

np.ndarray

Probabilities log(p_k|Omega_int) of shape [n_test, n_mc_dropout*n_samples]
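Since the BNN samples for each sightline are stored flattened along one axis, this computation is an elementwise evaluation of the interim-prior density. A sketch with a hypothetical Gaussian interim prior:

```python
import numpy as np

rng = np.random.default_rng(0)

n_test, n_samples, n_mc_dropout = 3, 5, 2
# Flattened BNN samples per sightline: [n_test, n_mc_dropout * n_samples]
k_samples = rng.normal(0.0, 0.03, size=(n_test, n_mc_dropout * n_samples))


# Hypothetical interim prior: a Gaussian N(0, 0.04^2) over kappa
def interim_pdf_func(k):
    sigma = 0.04
    return np.exp(-0.5 * (k / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))


# log p(k | Omega_int), evaluated elementwise over all samples
log_p = np.log(interim_pdf_func(k_samples))
print(log_p.shape)  # (3, 10)
```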

get_log_p_k_given_omega_int_loop(interim_pdf_func, bnn=False, ss_name='N', k_max=np.inf)[source]#

Compute log(p_k|Omega_int) for BNN or summary stats samples p_k. Useful when the number of samples differs across sightlines, so the computation is not trivially vectorizable.

Parameters#

interim_pdf_func : callable

Function that evaluates the PDF of the interim prior

bnn : bool, optional

Whether the samples come from the BNN. If False, they are understood to be summary-stats-matched samples.

ss_name : str, optional

Summary stats name. Only used if bnn is False. Default: ‘N’

get_log_p_k_given_omega_int_per_los(i, samples_i, interim_pdf_func, ss_name='N')[source]#

Compute log(p_k|Omega_int) for BNN or summary stats samples p_k. Useful when the number of samples differs across sightlines, so the computation is not trivially vectorizable.

Parameters#

i : int

ID of sightline

samples_i : np.ndarray

Matched posterior samples for this sightline

interim_pdf_func : callable

Function that evaluates the PDF of the interim prior

ss_name : str, optional

Summary stats name. Only used if bnn is False. Default: ‘N’

run_mcmc_for_omega_post(n_samples, n_mc_dropout, mcmc_kwargs, interim_pdf_func, bounds_lower=-np.inf, bounds_upper=np.inf)[source]#

Run EMCEE to obtain the posterior on test hyperparams, omega

Parameters#

n_samples : int

Number of BNN samples per MC iterate per sightline

n_mc_dropout : int

Number of MC dropout iterates

mcmc_kwargs : dict

Config going into infer_utils.run_mcmc

bounds_lower : np.ndarray or float, optional

Lower bound for target quantities

bounds_upper : np.ndarray or float, optional

Upper bound for target quantities
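The hierarchical likelihood behind this call reweights each sightline's interim-prior samples by p(k | omega) / p(k | Omega_int) and averages, then the sampler explores omega. The actual code delegates to `infer_utils.run_mcmc` (EMCEE); the sketch below substitutes a tiny Metropolis sampler and a fully mocked Gaussian setup, so every number and name here is an assumption:

```python
import numpy as np

rng = np.random.default_rng(0)

# Mock: kappa samples per sightline drawn under a broad interim prior
n_test, n_samp = 20, 200
interim_sigma = 0.1
k_samples = rng.normal(0.0, interim_sigma, size=(n_test, n_samp))


def log_interim(k):
    return -0.5 * (k / interim_sigma) ** 2 - np.log(interim_sigma * np.sqrt(2 * np.pi))


def log_like(omega):
    """Hierarchical log-likelihood: for each sightline, average the
    importance ratio p(k | omega) / p(k | Omega_int) over its samples."""
    mu, sigma = omega
    if sigma <= 0:
        return -np.inf
    log_p = -0.5 * ((k_samples - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2 * np.pi))
    log_ratio = log_p - log_interim(k_samples)
    # log-mean-exp over samples, summed over sightlines
    m = log_ratio.max(axis=1, keepdims=True)
    return float(np.sum(m[:, 0] + np.log(np.mean(np.exp(log_ratio - m), axis=1))))


# Tiny Metropolis sampler standing in for infer_utils.run_mcmc / EMCEE
omega = np.array([0.0, 0.05])
lp = log_like(omega)
chain = []
for _ in range(500):
    prop = omega + rng.normal(0, 0.01, size=2)
    lp_prop = log_like(prop)
    if np.log(rng.uniform()) < lp_prop - lp:
        omega, lp = prop, lp_prop
    chain.append(omega.copy())
chain = np.array(chain)
```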

run_mcmc_for_omega_post_summary_stats(ss_name, mcmc_kwargs, interim_pdf_func, bounds_lower=-np.inf, bounds_upper=np.inf)[source]#

Run EMCEE to obtain the posterior on the test hyperparams, omega, using the matched summary statistics samples rather than BNN posterior samples

Parameters#

ss_name : str

What kind of summary stats to query (one of 'N', 'N_inv_dist')

mcmc_kwargs : dict

Config going into infer_utils.run_mcmc

bounds_lower : np.ndarray or float, optional

Lower bound for target quantities

bounds_upper : np.ndarray or float, optional

Upper bound for target quantities

get_kappa_log_weights(idx, n_samples=None, n_mc_dropout=None, interim_pdf_func=None, grid=None)[source]#

Get log weights for reweighted kappa posterior per sample

Parameters#

idx : int

Index of sightline in test set

n_samples : int

Number of samples per dropout iterate, for getting kappa samples. (May be overridden with what was used previously, if kappa samples were already drawn and stored)

n_mc_dropout : int

Number of dropout iterates, for getting kappa samples. (May be overridden with what was used previously, if kappa samples were already drawn and stored)

interim_pdf_func : callable

Function that returns the density of the interim prior

grid : None, optional

Unused but kept for consistency with get_kappa_log_weights_grid

Returns#

np.ndarray

log weights for each of the BNN samples for this sightline
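The per-sample log weight is the log ratio of the target prior to the interim prior, evaluated at each BNN sample. A sketch for one sightline, with both priors mocked as Gaussians (all parameters hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)

# BNN samples for one sightline, drawn under the interim prior
samples = rng.normal(0.0, 0.04, size=1000)


def log_gauss(k, mu, sigma):
    return -0.5 * ((k - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2 * np.pi))


# Reweight from the interim prior N(0, 0.04) to a narrower target
# prior N(0.01, 0.02): log w = log p_target - log p_interim
log_weights = log_gauss(samples, 0.01, 0.02) - log_gauss(samples, 0.0, 0.04)

# Normalize for use as resampling probabilities
log_weights -= np.max(log_weights)
weights = np.exp(log_weights)
weights /= weights.sum()
```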

get_kappa_log_weights_grid(idx, grid=None, n_samples=None, n_mc_dropout=None, interim_pdf_func=None)[source]#

Get log weights for reweighted kappa posterior, analytically on a grid

Parameters#

idx : int

Index of sightline in test set

grid : np.ndarray, optional

Grid of kappa values at which to evaluate log weights (May be overridden with what was used previously, if kappa samples were already drawn and stored)

n_samples : int, optional

Number of samples per dropout iterate, for getting kappa samples. (May be overridden with what was used previously, if kappa samples were already drawn and stored)

n_mc_dropout : int, optional

Number of dropout iterates, for getting kappa samples. (May be overridden with what was used previously, if kappa samples were already drawn and stored)

interim_pdf_func : callable, optional

Function that returns the density of the interim prior

Function that returns the density of the interim prior

Note#

The log does not actually help with numerical stability here, since the probabilities are divided directly; it is kept only for consistency.

Returns#

np.ndarray

kappa grid, log weights for each of the BNN samples for this sightline

get_reweighted_bnn_kappa(n_resamples, grid_kappa_kwargs, k_max=None)[source]#

Get the reweighted BNN kappa samples, reweighted either on a grid or per sample

Parameters#

n_resamples : int

Number of resamples from the reweighted distribution

grid_kappa_kwargs : dict

Kwargs for

Returns#

tuple

Two arrays of shape [n_test, 1, n_resamples], first of which is resamples using the grid reweighting and second of which is resamples using the per-sample reweighting
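The per-sample variant of this resampling can be sketched as a weighted draw from each sightline's stored samples, producing the documented `[n_test, 1, n_resamples]` shape. Sizes and weights below are mocked:

```python
import numpy as np

rng = np.random.default_rng(0)

n_test, n_resamples = 3, 500
# Per-sightline samples and mock importance weights, normalized per row
samples = rng.normal(0.0, 0.04, size=(n_test, 1000))
w = np.exp(-0.5 * (samples / 0.02) ** 2)
w /= w.sum(axis=1, keepdims=True)

# Draw resamples per sightline according to the weights, producing the
# documented [n_test, 1, n_resamples] array
resampled = np.stack([
    rng.choice(samples[i], size=n_resamples, p=w[i]) for i in range(n_test)
])[:, None, :]
print(resampled.shape)  # (3, 1, 500)
```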

get_omega_samples(chain_path, chain_kwargs, log_idx=None)[source]#
visualize_omega_post(chain_path, chain_kwargs, corner_kwargs, log_idx=None)[source]#
visualize_kappa_post(idx, n_samples, n_mc_dropout, interim_pdf_func, grid=None)[source]#
compute_metrics()[source]#

Evaluate metrics for model selection, based on per-sample reweighting for fair comparison to summary stats metrics

# TODO: move to separate helper module for pre, post, post-grid

get_calibration_plot(k_bnn)[source]#

Plot calibration (should be run on the validation set)

Parameters#

k_bnn : np.ndarray

Reweighted BNN samples, of shape [n_test, Y_dim, n_samples]
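A calibration curve of this kind is usually built from the rank of the truth within each posterior: for a calibrated posterior the ranks are uniform, so the empirical coverage of a central q-credible interval matches q. The plotting itself is omitted; the sketch below computes the coverage curve on a mocked, well-calibrated posterior (all sizes and distributions are assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)

n_test, n_samples = 500, 1000
# Mock a calibrated setup: truth and posterior samples share the
# same conditional distribution N(mu_i, sigma) per sightline
mu = rng.normal(0.0, 0.03, size=n_test)
sigma = 0.01
truth = mu + rng.normal(0.0, sigma, size=n_test)
k_bnn = mu[:, None, None] + rng.normal(0.0, sigma, size=(n_test, 1, n_samples))

# Rank of the truth within each posterior: uniform if calibrated
ranks = np.mean(k_bnn[:, 0, :] < truth[:, None], axis=1)

# Truth lies in the central q-interval iff |rank - 0.5| <= q/2,
# so empirical coverage should track the nominal level
nominal = np.linspace(0.05, 0.95, 19)
coverage = np.array([np.mean(np.abs(ranks - 0.5) <= q / 2) for q in nominal])
```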