inference_manager
Class managing the model inference
Module Contents
Classes
- class inference_manager.InferenceManager(device_type, checkpoint_dir, out_dir, seed=123)
- property include_los
Indices of the test sightlines to include in inference. Useful for excluding faulty examples in the test set.
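The implementation of `include_los` is not shown here. As a minimal sketch, assuming the kept indices are simply all test indices minus a list of known-faulty ones, the idea looks like this. `included_sightlines` and `faulty_idx` are hypothetical names.

```python
import numpy as np

def included_sightlines(n_test, faulty_idx):
    """Hypothetical helper: sorted indices of test sightlines to keep."""
    all_idx = np.arange(n_test)
    # setdiff1d returns the sorted, unique indices not listed as faulty
    return np.setdiff1d(all_idx, np.asarray(faulty_idx, dtype=int))

# Example: 6 test sightlines, two of which are faulty
include_los = included_sightlines(6, [1, 4])
# Any per-sightline array of shape [n_test, ...] can then be subset consistently
kappa_samples = np.random.default_rng(0).normal(size=(6, 1, 100))
kappa_kept = kappa_samples[include_los]
```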
- load_dataset(data_kwargs, is_train, batch_size, sub_features=None, sub_target=None, sub_target_local=None, rebin=False, num_workers=2, noise_kwargs={'mag': {'override_kwargs': None, 'depth': 5}}, detection_kwargs={})
Load the dataset and dataloader for training or validation.
Note
Should be called for the training data first, because that call sets the normalization statistics used for both training and validation.
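The call-order requirement follows the usual pattern of fitting normalization statistics on the training set only and reusing them for validation. The sketch below illustrates that pattern. `Standardizer` is a hypothetical class, and per-feature mean/std scaling is an assumption about what the stats are.

```python
import numpy as np

class Standardizer:
    """Hypothetical helper: fit mean/std on training data, reuse for validation."""

    def __init__(self):
        self.mean = None
        self.std = None

    def fit(self, X_train):
        # Per-feature statistics computed from the training data only
        self.mean = X_train.mean(axis=0)
        self.std = X_train.std(axis=0)
        return self

    def transform(self, X):
        if self.mean is None:
            raise RuntimeError("Call fit() on the training data first")
        return (X - self.mean) / self.std

rng = np.random.default_rng(0)
X_train = rng.normal(loc=2.0, scale=3.0, size=(1000, 4))
X_val = rng.normal(loc=2.0, scale=3.0, size=(200, 4))

scaler = Standardizer().fit(X_train)
X_train_norm = scaler.transform(X_train)
X_val_norm = scaler.transform(X_val)  # uses the training-set statistics
```

Loading validation data first would leave no statistics to apply, which is why the sketch raises rather than silently fitting on validation data.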
- set_valtest_loading(sub_idx)
Set the loading options for val/test set. Should be called whenever there are changes to the test dataset, to update the dataloader.
Parameters
- subsample_pdf_func : callable
Description
- sub_idx : TYPE
Description
- load_state(state_path)
Load the state dict saved by a previous training run.
Parameters
- state_path : str or path-like
Path to the state dict to load
- get_bnn_kappa(n_samples=50, n_mc_dropout=20, flatten=True)
Get kappa samples from the BNN.
Parameters
- n_samples : int
Number of samples per MC iterate
- n_mc_dropout : int
Number of MC dropout iterates
Returns
np.ndarray of shape [n_test, self.Y_dim, n_samples*n_mc_dropout]
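The returned shape comes from drawing `n_samples` values in each of `n_mc_dropout` stochastic forward passes and concatenating them. The sketch below shows only that bookkeeping. It assumes a Gaussian predictive output per pass, and `fake_dropout_forward` is a hypothetical stand-in for the network, not the model used here.

```python
import numpy as np

def fake_dropout_forward(rng, n_test, y_dim):
    """Hypothetical stand-in for one dropout-enabled forward pass.
    Returns a predicted mean and std per test example (illustrative only)."""
    mu = rng.normal(0.0, 0.05, size=(n_test, y_dim))
    sigma = np.full((n_test, y_dim), 0.02)
    return mu, sigma

def bnn_kappa_samples(n_test, y_dim, n_samples=50, n_mc_dropout=20, seed=123):
    rng = np.random.default_rng(seed)
    draws = []
    for _ in range(n_mc_dropout):
        mu, sigma = fake_dropout_forward(rng, n_test, y_dim)
        # n_samples draws from this iterate's predictive Gaussian
        eps = rng.standard_normal(size=(n_test, y_dim, n_samples))
        draws.append(mu[..., None] + sigma[..., None] * eps)
    # Concatenate along the last axis -> [n_test, y_dim, n_samples * n_mc_dropout]
    return np.concatenate(draws, axis=-1)

samples = bnn_kappa_samples(n_test=8, y_dim=1)
```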
- delete_previous()
Delete previously stored files related to the test set and inference results, while leaving any training-set related caches, which take longer to generate.
- get_true_kappa(is_train, compute_summary=True, save=True)
Fetch true kappa (for train/val/test)
Parameters
- is_train : bool
Whether to get true kappas for train (test otherwise)
- compute_summary : bool, optional
Whether to compute summary stats in the loop
- save : bool, optional
Whether to store the kappa to disk
Returns
- np.ndarray
true kappas of shape [n_data, Y_dim]
- get_summary_stats(thresholds, interim_pdf_func=None, match=True, min_matches=1000, k_max=np.inf)
Save accepted samples from summary statistics matching.
Parameters
- thresholds : dict
Matching thresholds for summary stats. Keys should be one or both of ‘N’ and ‘N_inv_dist’.
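Summary statistics matching can be read as a rejection step. For each test sightline, keep the training sightlines whose statistics fall within the threshold, and treat their true kappas as posterior samples. The sketch below is a minimal version of that idea. `match_summary_stats` is hypothetical, and treating `min_matches` as a cutoff below which a sightline is flagged is an assumption.

```python
import numpy as np

def match_summary_stats(ss_train, kappa_train, ss_test, thresholds, min_matches=1):
    """Hypothetical rejection-style matching on summary statistics.

    ss_train: dict name -> array [n_train]; ss_test: dict name -> array [n_test]
    thresholds: dict name -> tolerance, e.g. {'N': 1}
    Returns one array of accepted training kappas per test sightline,
    or None where fewer than min_matches training sightlines pass.
    """
    n_test = len(next(iter(ss_test.values())))
    accepted = []
    for i in range(n_test):
        mask = np.ones(len(kappa_train), dtype=bool)
        for name, tol in thresholds.items():
            # Keep training sightlines within tol of this test sightline's statistic
            mask &= np.abs(ss_train[name] - ss_test[name][i]) <= tol
        accepted.append(kappa_train[mask] if mask.sum() >= min_matches else None)
    return accepted

ss_train = {'N': np.array([3, 5, 5, 6, 9, 12])}
kappa_train = np.array([0.01, 0.02, 0.03, 0.04, 0.06, 0.08])
ss_test = {'N': np.array([5, 20])}
matches = match_summary_stats(ss_train, kappa_train, ss_test, {'N': 1}, min_matches=2)
```

The number of accepted samples naturally differs between sightlines, which is why the loop-based log-density methods below exist.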
- get_log_p_k_given_omega_int(n_samples, n_mc_dropout, interim_pdf_func)
Compute log(p_k|Omega_int) for BNN samples p_k
Parameters
- n_samples : int
Number of BNN samples per MC iterate per sightline
- n_mc_dropout : int
Number of MC dropout iterates per sightline
- interim_pdf_func : callable
Function that evaluates the PDF of the interim prior
Returns
- np.ndarray
Probabilities log(p_k|Omega_int) of shape [n_test, n_mc_dropout*n_samples]
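When every sightline has the same number of BNN samples, the interim log-density can be evaluated in one vectorized call over a `[n_test, n_mc_dropout*n_samples]` array. In the sketch below, the Gaussian interim prior in `normal_pdf` and its parameters are assumptions for illustration only.

```python
import numpy as np

def normal_pdf(x, mu=0.0, sigma=0.04):
    """Hypothetical interim prior on kappa: a Gaussian N(mu, sigma^2)."""
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))

def log_p_k_given_omega_int(k_samples, interim_pdf_func):
    # k_samples: [n_test, n_mc_dropout * n_samples]; evaluated elementwise
    return np.log(interim_pdf_func(k_samples))

rng = np.random.default_rng(0)
n_test, n_samples, n_mc_dropout = 5, 50, 20
k = rng.normal(0.0, 0.03, size=(n_test, 1, n_samples * n_mc_dropout))
log_p = log_p_k_given_omega_int(k[:, 0, :], normal_pdf)
```

Taking `np.log` of a PDF can underflow to `-inf` far in the tails. If the interim prior has a closed-form log-density, evaluating that directly is safer.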
- get_log_p_k_given_omega_int_loop(interim_pdf_func, bnn=False, ss_name='N', k_max=np.inf)
Compute log(p_k|Omega_int) for BNN or summary stats samples p_k. Useful when the number of samples differs across sightlines, so the computation is not trivially vectorizable.
Parameters
- interim_pdf_func : callable
Function that evaluates the PDF of the interim prior
- bnn : bool, optional
Whether the samples come from the BNN. If False, they are taken to be samples matched on summary statistics.
- ss_name : str, optional
Summary stats name. Only used if bnn is False. Default: ‘N’
- get_log_p_k_given_omega_int_per_los(i, samples_i, interim_pdf_func, ss_name='N')
Compute log(p_k|Omega_int) for the BNN or summary stats samples p_k of a single sightline. Useful when the number of samples differs across sightlines, so the computation is not trivially vectorizable.
Parameters
- i : int
ID of sightline
- samples_i : np.ndarray
Matched posterior samples for this sightline
- interim_pdf_func : callable
Function that evaluates the PDF of the interim prior
- ss_name : str, optional
Summary stats name. Only used if bnn is False. Default: ‘N’
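With a different number of samples per sightline, the results cannot be stacked into one rectangular array. A per-sightline function called in a loop returns a ragged list instead. The sketch below mirrors that split between a per-sightline helper and a loop. Both helpers and the uniform interim prior `uniform_pdf` are hypothetical.

```python
import numpy as np

def log_interim_per_los(samples_i, interim_pdf_func):
    # One sightline: samples_i is a 1D array whose length may differ per sightline
    return np.log(interim_pdf_func(np.asarray(samples_i)))

def log_interim_loop(samples_per_los, interim_pdf_func):
    # Ragged input -> ragged output: a list with one array per sightline
    return [log_interim_per_los(s, interim_pdf_func) for s in samples_per_los]

def uniform_pdf(x, lo=-0.2, hi=0.2):
    """Hypothetical interim prior: uniform on [lo, hi]."""
    return np.where((x >= lo) & (x <= hi), 1.0 / (hi - lo), 0.0)

samples_per_los = [np.array([0.01, 0.02, 0.03]), np.array([-0.05]), np.array([0.1, 0.0])]
log_p = log_interim_loop(samples_per_los, uniform_pdf)
```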
- run_mcmc_for_omega_post(n_samples, n_mc_dropout, mcmc_kwargs, interim_pdf_func, bounds_lower=-np.inf, bounds_upper=np.inf)
Run emcee to obtain the posterior on the test-set hyperparameters, omega.
Parameters
- n_samples : int
Number of BNN samples per MC iterate per sightline
- n_mc_dropout : int
Number of MC dropout iterates
- mcmc_kwargs : dict
Config going into infer_utils.run_mcmc
- bounds_lower : np.ndarray or float, optional
Lower bound for target quantities
- bounds_upper : np.ndarray or float, optional
Upper bound for target quantities
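A common way to infer population hyperparameters from per-object posterior samples is importance sampling. Each sightline contributes the log of the sample mean of p(k|omega) / p(k|omega_int) over its samples. The sketch below assumes that formulation, not the exact objective in infer_utils.run_mcmc. It also assumes a Gaussian population with omega = (mu, log_sigma) and a flat hyperprior with hypothetical bounds on log_sigma. All function names are hypothetical.

```python
import numpy as np

def logsumexp(a, axis=-1):
    # Numerically stable log(sum(exp(a))) along an axis
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))

def log_normal(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2.0 * np.pi))

def log_post_omega(omega, k_samples, log_p_interim):
    """Hypothetical importance-sampled log posterior on omega = (mu, log_sigma).

    k_samples, log_p_interim: [n_test, n_samples_total]
    """
    mu, log_sigma = omega
    if not (-5.0 < log_sigma < 0.0):  # hypothetical flat hyperprior bounds
        return -np.inf
    log_ratio = log_normal(k_samples, mu, np.exp(log_sigma)) - log_p_interim
    n_total = k_samples.shape[1]
    # Per sightline: log of the sample mean of p(k|omega) / p(k|omega_int)
    per_los = logsumexp(log_ratio, axis=1) - np.log(n_total)
    return np.sum(per_los)
```

A function with this signature could be passed as the log-probability callable to an MCMC sampler such as `emcee.EnsembleSampler(nwalkers, ndim, log_post_omega, args=(k_samples, log_p_interim))`. Whether infer_utils.run_mcmc wires it up this way is not stated here.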
- run_mcmc_for_omega_post_summary_stats(ss_name, mcmc_kwargs, interim_pdf_func, bounds_lower=-np.inf, bounds_upper=np.inf)
Run emcee to obtain the posterior on the test-set hyperparameters, omega, using the samples matched on summary statistics rather than the BNN posterior samples.
Parameters
- ss_name : str
What kind of summary stats to query (one of ‘N’, ‘N_inv_dist’)
- mcmc_kwargs : dict
Config going into infer_utils.run_mcmc
- bounds_lower : np.ndarray or float, optional
Lower bound for target quantities
- bounds_upper : np.ndarray or float, optional
Upper bound for target quantities
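With matched summary-statistics samples, each sightline can have a different number of samples. The importance-sampled likelihood is then accumulated one sightline at a time, averaging over that sightline's own samples. The sketch below assumes the same Gaussian population model as a BNN-based version would. `log_like_omega_ragged` and `logmeanexp` are hypothetical names.

```python
import numpy as np

def logmeanexp(a):
    # Stable log(mean(exp(a))) for a 1D array
    m = np.max(a)
    return m + np.log(np.mean(np.exp(a - m)))

def log_normal(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2.0 * np.pi))

def log_like_omega_ragged(omega, samples_per_los, log_p_interim_per_los):
    """Hypothetical log-likelihood of omega = (mu, log_sigma) with ragged samples."""
    mu, log_sigma = omega
    sigma = np.exp(log_sigma)
    total = 0.0
    for k_i, log_int_i in zip(samples_per_los, log_p_interim_per_los):
        # Each sightline contributes log of the mean importance ratio over its own matches
        total += logmeanexp(log_normal(k_i, mu, sigma) - log_int_i)
    return total
```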
- get_kappa_log_weights(idx, n_samples=None, n_mc_dropout=None, interim_pdf_func=None, grid=None)
Get log weights for the reweighted kappa posterior, per sample.
Parameters
- idx : int
Index of sightline in test set
- n_samples : int
Number of samples per dropout, for getting kappa samples. (May be overridden with what was used previously, if kappa samples were already drawn and stored)
- n_mc_dropout : int
Number of dropout iterates, for getting kappa samples. (May be overridden with what was used previously, if kappa samples were already drawn and stored)
- interim_pdf_func : callable
Function that returns the density of the interim prior
- grid : None, optional
Unused but kept for consistency with get_kappa_log_weights_grid
Returns
- np.ndarray
log weights for each of the BNN samples for this sightline
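Reweighting a per-sightline posterior from the interim prior to the inferred population amounts to a per-sample log weight, log p(k|omega) - log p(k|omega_int). The sketch below averages the population density over a set of omega draws, such as MCMC samples. That averaging is an assumption about how the posterior on omega is used. The Gaussian population model and both helper names are hypothetical.

```python
import numpy as np

def log_normal(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2.0 * np.pi))

def kappa_log_weights(k_samples, log_p_interim, omega_draws):
    """Hypothetical per-sample log weights for one sightline.

    k_samples, log_p_interim: [n_samples_total]
    omega_draws: [n_omega, 2] rows of (mu, log_sigma)
    """
    mu = omega_draws[:, 0][:, None]
    sigma = np.exp(omega_draws[:, 1])[:, None]
    log_p = log_normal(k_samples[None, :], mu, sigma)  # [n_omega, n_samples_total]
    # Average the population density over omega draws, in log space
    m = log_p.max(axis=0)
    log_p_pop = m + np.log(np.mean(np.exp(log_p - m), axis=0))
    return log_p_pop - log_p_interim

def normalized_weights(log_w):
    # Subtract the max before exponentiating to avoid overflow
    w = np.exp(log_w - np.max(log_w))
    return w / w.sum()
```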
- get_kappa_log_weights_grid(idx, grid=None, n_samples=None, n_mc_dropout=None, interim_pdf_func=None)
Get log weights for the reweighted kappa posterior, analytically on a grid.
Parameters
- idx : int
Index of sightline in test set
- grid : np.ndarray, optional
Grid of kappa values at which to evaluate log weights (May be overridden with what was used previously, if kappa samples were already drawn and stored)
- n_samples : int, optional
Number of samples per dropout, for getting kappa samples. (May be overridden with what was used previously, if kappa samples were already drawn and stored)
- n_mc_dropout : int, optional
Number of dropout iterates, for getting kappa samples. (May be overridden with what was used previously, if kappa samples were already drawn and stored)
- interim_pdf_func : callable, optional
Function that returns the density of the interim prior
Note
Working in log space does not improve numerical stability here, because the probabilities are divided directly; log weights are returned only for consistency.
Returns
- np.ndarray
kappa grid, log weights for each of the BNN samples for this sightline
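One way to reweight on a grid is to estimate the sample-based posterior density in bins, then multiply by the density ratio p(k|omega) / p(k|omega_int) at the bin centers, dividing the densities directly. The sketch below assumes a histogram density estimate, which may differ from the estimator actually used. `grid_reweight` is hypothetical, and it returns normalized bin probabilities rather than log weights.

```python
import numpy as np

def normal_pdf(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))

def grid_reweight(k_samples, grid, pop_pdf, interim_pdf):
    """Hypothetical grid reweighting of a sample-based posterior.

    grid: bin edges. Returns bin centers and normalized reweighted bin probabilities.
    """
    dens, edges = np.histogram(k_samples, bins=grid, density=True)
    centers = 0.5 * (edges[:-1] + edges[1:])
    # Ratio of densities computed directly, without logs
    ratio = pop_pdf(centers) / interim_pdf(centers)
    unnorm = dens * ratio * np.diff(edges)
    return centers, unnorm / unnorm.sum()
```

When the population and interim densities coincide, the ratio is 1 and the result reduces to the histogram's bin probabilities.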
- get_reweighted_bnn_kappa(n_resamples, grid_kappa_kwargs, k_max=None)
Get the reweighted BNN kappa samples, reweighted either on a grid or per sample
Parameters
- n_resamples : int
Number of resamples from the reweighted distribution
- grid_kappa_kwargs : dict
Kwargs for
Returns
- tuple
Two arrays of shape [n_test, 1, n_resamples], first of which is resamples using the grid reweighting and second of which is resamples using the per-sample reweighting
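Both reweighting variants end with weighted resampling into an array of shape `[n_test, 1, n_resamples]`. The values are grid points in the grid case and BNN samples in the per-sample case. The sketch below shows that final step with `numpy.random.Generator.choice`. `resample_weighted` is a hypothetical helper, not the method's internals.

```python
import numpy as np

def resample_weighted(values_per_los, weights_per_los, n_resamples, seed=0):
    """Hypothetical weighted resampling per sightline -> [n_test, 1, n_resamples]."""
    rng = np.random.default_rng(seed)
    out = np.empty((len(values_per_los), 1, n_resamples))
    for i, (vals, w) in enumerate(zip(values_per_los, weights_per_los)):
        p = np.asarray(w, dtype=float)
        p = p / p.sum()  # choice() requires probabilities that sum to 1
        out[i, 0, :] = rng.choice(vals, size=n_resamples, replace=True, p=p)
    return out

values = [np.array([0.1, 0.2, 0.3]), np.array([0.5, 0.6])]
weights = [np.array([0, 1, 0]), np.array([1.0, 1.0])]
resamples = resample_weighted(values, weights, n_resamples=100)
```

In the grid case, this draws only bin centers. Adding a uniform jitter within each bin is one option if a continuous output is preferred.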