:py:mod:`inference_manager`
===========================

.. py:module:: inference_manager

.. autoapi-nested-parse::

   Class managing the model inference.


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   inference_manager.InferenceManager


.. py:class:: InferenceManager(device_type, checkpoint_dir, out_dir, seed=123)

   .. py:property:: include_los

      Indices of the sightlines to include in inference. Useful for excluding faulty examples from the test set.

   .. py:property:: n_test

   .. py:property:: bnn_kappa_path

   .. py:property:: true_train_kappa_path

   .. py:property:: train_summary_stats_path

   .. py:property:: true_test_kappa_path

   .. py:property:: test_summary_stats_path

   .. py:property:: matching_dir

   .. py:property:: log_p_k_given_omega_int_path

   .. py:property:: reweighted_grid_dir

   .. py:property:: reweighted_per_sample_dir

   .. py:property:: reweighted_bnn_kappa_grid_path

   .. py:property:: reweighted_bnn_kappa_per_sample_path

   .. py:property:: pre_reweighting_metrics_path

   .. py:property:: pre_reweighting_metrics

   .. py:property:: post_reweighting_metrics_path

   .. py:property:: post_reweighting_metrics

   .. py:property:: post_reweighting_metrics_grid_path

   .. py:property:: post_reweighting_metrics_grid

   .. py:method:: seed_everything()

      Seed the training and sampling for reproducibility.

   .. py:method:: load_dataset(data_kwargs, is_train, batch_size, sub_features=None, sub_target=None, sub_target_local=None, rebin=False, num_workers=2, noise_kwargs={'mag': {'override_kwargs': None, 'depth': 5}}, detection_kwargs={})

      Load the dataset and dataloader for training or validation.

      Note
      ----
      Must be called for the training data first, because it sets the normalizing stats used for both training and validation.

   .. py:method:: set_valtest_loading(sub_idx)

      Set the loading options for the val/test set. Should be called whenever the test dataset changes, to update the dataloader.

      Parameters
      ----------
      sub_idx : TYPE
          Description

   .. py:method:: configure_model(model_name, model_kwargs={})

   .. py:method:: load_state(state_path)

      Load the state dict saved during a previous training run.

      Parameters
      ----------
      state_path : str or os.PathLike
          Path of the state dict to load

   .. py:method:: get_bnn_kappa(n_samples=50, n_mc_dropout=20, flatten=True)

      Get the samples from the BNN.

      Parameters
      ----------
      n_samples : int
          Number of samples per MC dropout iterate
      n_mc_dropout : int
          Number of MC dropout iterates

      Returns
      -------
      np.ndarray
          Samples of shape ``[n_test, self.Y_dim, n_samples*n_mc_dropout]``

   .. py:method:: delete_previous()

      Delete previously stored files related to the test set and inference results, while leaving any training-set related caches, which take longer to generate.

   .. py:method:: get_true_kappa(is_train, compute_summary=True, save=True)

      Fetch the true kappa (for train/val/test).

      Parameters
      ----------
      is_train : bool
          Whether to get the true kappas for the training set (the test set otherwise)
      compute_summary : bool, optional
          Whether to compute summary stats in the loop
      save : bool, optional
          Whether to store the kappas to disk

      Returns
      -------
      np.ndarray
          True kappas of shape ``[n_data, Y_dim]``

   .. py:method:: get_summary_stats(thresholds, interim_pdf_func=None, match=True, min_matches=1000, k_max=np.inf)

      Save the accepted samples from summary statistics matching.

      Parameters
      ----------
      thresholds : dict
          Matching thresholds for the summary stats. Keys should be one or both of 'N' and 'N_inv_dist'.

   .. py:method:: get_log_p_k_given_omega_int(n_samples, n_mc_dropout, interim_pdf_func)

      Compute :math:`\log p(\kappa \mid \Omega_{\mathrm{int}})` for the BNN kappa samples.

      Parameters
      ----------
      n_samples : int
          Number of BNN samples per MC dropout iterate per sightline
      n_mc_dropout : int
          Number of MC dropout iterates per sightline
      interim_pdf_func : callable
          Function that evaluates the PDF of the interim prior

      Returns
      -------
      np.ndarray
          Log probabilities :math:`\log p(\kappa \mid \Omega_{\mathrm{int}})` of shape ``[n_test, n_mc_dropout*n_samples]``

   .. py:method:: get_log_p_k_given_omega_int_loop(interim_pdf_func, bnn=False, ss_name='N', k_max=np.inf)

      Compute :math:`\log p(\kappa \mid \Omega_{\mathrm{int}})` for the BNN or summary-stats-matched kappa samples. Useful when the number of samples differs across sightlines, so the computation cannot be trivially vectorized.

      Parameters
      ----------
      interim_pdf_func : callable
          Function that evaluates the PDF of the interim prior
      bnn : bool, optional
          Whether the samples come from the BNN. If False, they are taken to be samples matched on summary stats.
      ss_name : str, optional
          Summary stats name. Only used if `bnn` is False. Default: 'N'

   .. py:method:: get_log_p_k_given_omega_int_per_los(i, samples_i, interim_pdf_func, ss_name='N')

      Compute :math:`\log p(\kappa \mid \Omega_{\mathrm{int}})` for the BNN or summary-stats-matched kappa samples of a single sightline. Useful when the number of samples differs across sightlines, so the computation cannot be trivially vectorized.

      Parameters
      ----------
      i : int
          ID of the sightline
      samples_i : np.ndarray
          Matched posterior samples for this sightline
      interim_pdf_func : callable
          Function that evaluates the PDF of the interim prior
      ss_name : str, optional
          Summary stats name. Default: 'N'

   .. py:method:: run_mcmc_for_omega_post(n_samples, n_mc_dropout, mcmc_kwargs, interim_pdf_func, bounds_lower=-np.inf, bounds_upper=np.inf)

      Run emcee to obtain the posterior on the test-set hyperparameters, omega.

      Parameters
      ----------
      n_samples : int
          Number of BNN samples per MC dropout iterate per sightline
      n_mc_dropout : int
          Number of MC dropout iterates
      mcmc_kwargs : dict
          Config going into `infer_utils.run_mcmc`
      interim_pdf_func : callable
          Function that evaluates the PDF of the interim prior
      bounds_lower : np.ndarray or float, optional
          Lower bound for the target quantities
      bounds_upper : np.ndarray or float, optional
          Upper bound for the target quantities

   .. py:method:: run_mcmc_for_omega_post_summary_stats(ss_name, mcmc_kwargs, interim_pdf_func, bounds_lower=-np.inf, bounds_upper=np.inf)

      Run emcee to obtain the posterior on the test-set hyperparameters, omega, using the samples matched on summary statistics rather than the BNN posterior samples.

      Parameters
      ----------
      ss_name : str
          Kind of summary stats to query (one of 'N', 'N_inv_dist')
      mcmc_kwargs : dict
          Config going into `infer_utils.run_mcmc`
      interim_pdf_func : callable
          Function that evaluates the PDF of the interim prior
      bounds_lower : np.ndarray or float, optional
          Lower bound for the target quantities
      bounds_upper : np.ndarray or float, optional
          Upper bound for the target quantities

   .. py:method:: get_kappa_log_weights(idx, n_samples=None, n_mc_dropout=None, interim_pdf_func=None, grid=None)

      Get the per-sample log weights for the reweighted kappa posterior.

      Parameters
      ----------
      idx : int
          Index of the sightline in the test set
      n_samples : int, optional
          Number of samples per dropout iterate, for getting kappa samples. (May be overridden with what was used previously, if kappa samples were already drawn and stored.)
      n_mc_dropout : int, optional
          Number of dropout iterates, for getting kappa samples. (May be overridden with what was used previously, if kappa samples were already drawn and stored.)
      interim_pdf_func : callable, optional
          Function that returns the density of the interim prior
      grid : None, optional
          Unused but kept for consistency with `get_kappa_log_weights_grid`

      Returns
      -------
      np.ndarray
          Log weights for each of the BNN samples for this sightline

   .. py:method:: get_kappa_log_weights_grid(idx, grid=None, n_samples=None, n_mc_dropout=None, interim_pdf_func=None)

      Get the log weights for the reweighted kappa posterior, evaluated analytically on a grid.

      Parameters
      ----------
      idx : int
          Index of the sightline in the test set
      grid : np.ndarray, optional
          Grid of kappa values at which to evaluate the log weights. (May be overridden with what was used previously, if kappa samples were already drawn and stored.)
      n_samples : int, optional
          Number of samples per dropout iterate, for getting kappa samples. (May be overridden with what was used previously, if kappa samples were already drawn and stored.)
      n_mc_dropout : int, optional
          Number of dropout iterates, for getting kappa samples. (May be overridden with what was used previously, if kappa samples were already drawn and stored.)
      interim_pdf_func : callable, optional
          Function that returns the density of the interim prior

      Note
      ----
      Working in log space does not add numerical stability here, since the probabilities are divided directly; the log is kept only for consistency.

      Returns
      -------
      tuple of np.ndarray
          The kappa grid and the corresponding log weights for this sightline

   .. py:method:: get_reweighted_bnn_kappa(n_resamples, grid_kappa_kwargs, k_max=None)

      Get the reweighted BNN kappa samples, using both the grid reweighting and the per-sample reweighting.

      Parameters
      ----------
      n_resamples : int
          Number of resamples from the reweighted distribution
      grid_kappa_kwargs : dict
          Kwargs for

      Returns
      -------
      tuple
          Two arrays, each of shape ``[n_test, 1, n_resamples]``: the first holds resamples from the grid reweighting and the second holds resamples from the per-sample reweighting.

   .. py:method:: get_omega_samples(chain_path, chain_kwargs, log_idx=None)

   .. py:method:: visualize_omega_post(chain_path, chain_kwargs, corner_kwargs, log_idx=None)

   .. py:method:: visualize_kappa_post(idx, n_samples, n_mc_dropout, interim_pdf_func, grid=None)

   .. py:method:: compute_metrics()

      Evaluate metrics for model selection, based on the per-sample reweighting for a fair comparison with the summary stats metrics.

      TODO: move to a separate helper module for pre, post, post-grid

   .. py:method:: get_calibration_plot(k_bnn)

      Plot the calibration (should be run on the validation set).

      Parameters
      ----------
      k_bnn : np.ndarray
          Reweighted BNN samples, of shape ``[n_test, Y_dim, n_samples]``
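
The sample layout returned by `get_bnn_kappa`, ``[n_test, Y_dim, n_samples*n_mc_dropout]``, can be illustrated with a minimal sketch. It assumes, hypothetically, that each dropout-enabled forward pass yields a Gaussian mean and standard deviation per target; `stack_mc_dropout_samples` and `toy_predict` are illustrative names, not part of `inference_manager`.

.. code-block:: python

   import numpy as np


   def stack_mc_dropout_samples(predict_fn, X, n_samples, n_mc_dropout, seed=123):
       """Stack Gaussian draws from several MC dropout passes along the last axis.

       predict_fn is a hypothetical stochastic forward pass returning (mu, sigma),
       each of shape [n_test, Y_dim].
       """
       rng = np.random.default_rng(seed)
       draws = []
       for _ in range(n_mc_dropout):
           mu, sigma = predict_fn(X, rng)  # one dropout realization per pass
           eps = rng.standard_normal(mu.shape + (n_samples,))
           # n_samples draws for this pass: [n_test, Y_dim, n_samples]
           draws.append(mu[..., None] + sigma[..., None] * eps)
       # All passes side by side: [n_test, Y_dim, n_samples * n_mc_dropout]
       return np.concatenate(draws, axis=-1)


   def toy_predict(X, rng):
       """Hypothetical model: random jitter on the mean stands in for dropout noise."""
       mu = X.mean(axis=1, keepdims=True) + 0.01 * rng.standard_normal((X.shape[0], 1))
       sigma = np.full_like(mu, 0.02)
       return mu, sigma


   X_test = np.zeros((4, 3))
   k_bnn = stack_mc_dropout_samples(toy_predict, X_test, n_samples=5, n_mc_dropout=3)
   print(k_bnn.shape)  # (4, 1, 15)

Concatenating along the last axis keeps every MC dropout pass contiguous, so a per-pass view can be recovered later with a reshape if needed.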
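
Per-sample reweighting of the kind `get_kappa_log_weights` and `get_reweighted_bnn_kappa` describe is, in its standard importance-sampling form, a reweighting of samples drawn under the interim prior by :math:`p(\kappa \mid \omega) / p(\kappa \mid \Omega_{\mathrm{int}})`, followed by resampling. The sketch below assumes Gaussian priors for both terms and uses the hypothetical helpers `log_weights` and `resample`; it illustrates the technique and is not the module's implementation.

.. code-block:: python

   import numpy as np


   def gaussian_logpdf(x, mu, sigma):
       return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2.0 * np.pi))


   def log_weights(k_samples, log_prior_new, log_prior_interim):
       """Normalized importance log weights that swap the interim prior for a new one."""
       log_w = log_prior_new(k_samples) - log_prior_interim(k_samples)
       # Normalize with a log-sum-exp shift so the largest term is exp(0)
       m = np.max(log_w)
       return log_w - (m + np.log(np.sum(np.exp(log_w - m))))


   def resample(k_samples, log_w, n_resamples, seed=123):
       """Draw samples with replacement in proportion to their weights."""
       rng = np.random.default_rng(seed)
       p = np.exp(log_w)
       idx = rng.choice(len(k_samples), size=n_resamples, replace=True, p=p / p.sum())
       return k_samples[idx]


   rng = np.random.default_rng(0)
   # Stand-in posterior samples for one sightline, drawn under the interim prior
   k_samples = rng.normal(0.02, 0.05, size=2000)
   lw = log_weights(
       k_samples,
       log_prior_new=lambda k: gaussian_logpdf(k, 0.01, 0.02),
       log_prior_interim=lambda k: gaussian_logpdf(k, 0.0, 0.04),
   )
   k_rew = resample(k_samples, lw, n_resamples=1000)

Because the new prior here is narrower than the interim prior, the resampled distribution is tighter than the original samples.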
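
`run_mcmc_for_omega_post` samples the hyperparameters omega with emcee. A common likelihood for this setting, up to an omega-independent constant, is :math:`\log L(\omega) = \sum_i \left[\log \sum_j \frac{p(\kappa_{ij} \mid \omega)}{p(\kappa_{ij} \mid \Omega_{\mathrm{int}})} - \log N_i\right]`, where sightline :math:`i` has :math:`N_i` posterior samples. The sketch assumes a Gaussian population model with omega = (mu, log_sigma) and accepts ragged per-sightline arrays, the situation `get_log_p_k_given_omega_int_loop` is described as handling; `log_likelihood_omega` is hypothetical and not necessarily what `infer_utils.run_mcmc` evaluates.

.. code-block:: python

   import numpy as np


   def gaussian_logpdf(x, mu, sigma):
       return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2.0 * np.pi))


   def logsumexp(a):
       m = np.max(a)
       return m + np.log(np.sum(np.exp(a - m)))


   def log_likelihood_omega(omega, k_per_los, log_p_interim_per_los):
       """Importance-sampled hierarchical log likelihood of omega = (mu, log_sigma).

       Both list arguments hold one 1D array per sightline, so sightlines may
       carry different numbers of samples.
       """
       mu, log_sigma = omega
       total = 0.0
       for k_i, log_int_i in zip(k_per_los, log_p_interim_per_los):
           log_ratio = gaussian_logpdf(k_i, mu, np.exp(log_sigma)) - log_int_i
           # log of the Monte Carlo mean of p(k | omega) / p(k | Omega_int)
           total += logsumexp(log_ratio) - np.log(len(k_i))
       return total


   rng = np.random.default_rng(1)
   k_per_los = [rng.normal(0.0, 0.05, size=n) for n in (300, 120, 500)]  # ragged
   log_int = [gaussian_logpdf(k, 0.0, 0.04) for k in k_per_los]  # interim prior N(0, 0.04)
   # When omega reproduces the interim prior, every ratio is 1 and the total is 0
   # (up to floating-point rounding)
   ll_ref = log_likelihood_omega((0.0, np.log(0.04)), k_per_los, log_int)

Adding a log prior on omega to this function would give a log posterior that could be passed as the `log_prob_fn` of `emcee.EnsembleSampler`.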
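
`get_summary_stats` accepts samples by matching summary statistics within `thresholds`. As an illustration only, the sketch below accepts, for each test sightline, the training sightlines whose statistic lies within an absolute threshold of the test value and keeps their true kappas. The absolute-difference rule and the name `match_summary_stats` are assumptions, not the module's documented behavior.

.. code-block:: python

   import numpy as np


   def match_summary_stats(ss_train, k_train, ss_test, threshold):
       """For each test sightline, keep the true kappas of training sightlines whose
       summary statistic lies within `threshold` of the test value (hypothetical rule)."""
       matched = []
       for s in ss_test:
           accepted = np.abs(ss_train - s) <= threshold
           matched.append(k_train[accepted])
       return matched  # ragged: one array per test sightline


   ss_train = np.array([3, 5, 6, 10])  # e.g. a count statistic such as 'N'
   k_train = np.array([0.1, 0.2, 0.3, 0.4])
   matched = match_summary_stats(ss_train, k_train, ss_test=np.array([5, 10]), threshold=1)
   print([m.tolist() for m in matched])  # [[0.2, 0.3], [0.4]]

The matched arrays have different lengths per sightline, which is why the downstream log-probability computation is not trivially vectorizable.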
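
`get_calibration_plot` plots calibration from reweighted samples. One common way to compute such a curve is the fraction of sightlines whose true kappa falls inside the central credible interval at each probability level; a calibrated posterior yields a curve close to the diagonal. The sketch assumes a single target (``Y_dim = 1``, squeezed out) and uses the hypothetical `calibration_curve`; the module's own plotting may differ.

.. code-block:: python

   import numpy as np


   def calibration_curve(k_samples, k_true, levels):
       """Empirical coverage of central credible intervals.

       k_samples: [n_test, n_samples] posterior samples (Y_dim squeezed out)
       k_true: [n_test] true values; levels: probabilities in [0, 1]
       """
       # Rank of each truth within its own samples, as a fraction in [0, 1]
       q = np.mean(k_samples < k_true[:, None], axis=1)
       # The truth is inside the central interval at level p iff |q - 0.5| <= p / 2
       return np.array([np.mean(np.abs(q - 0.5) <= p / 2.0) for p in levels])


   rng = np.random.default_rng(2)
   n_test, n_samples = 2000, 500
   k_true = rng.normal(size=n_test)
   # Samples from the same distribution as the truths: a calibrated toy case
   k_samples = rng.normal(size=(n_test, n_samples))
   levels = np.array([0.25, 0.5, 0.75, 0.95])
   coverage = calibration_curve(k_samples, k_true, levels)

Plotting `coverage` against `levels` gives the calibration curve; points above the diagonal indicate overly wide posteriors and points below indicate overconfident ones.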