:py:mod:`infer_utils`
=====================

.. py:module:: infer_utils

.. autoapi-nested-parse::

   Utility methods for managing inference with a trained model



Module Contents
---------------


Functions
~~~~~~~~~

.. autoapisummary::

   infer_utils.get_normal_logpdf
   infer_utils.run_mcmc
   infer_utils.get_chain_plot
   infer_utils.get_log_p_k_given_omega_int_kde
   infer_utils.get_log_p_k_given_omega_int_analytic
   infer_utils.get_omega_post
   infer_utils.get_omega_post_loop
   infer_utils.log_prob_mcmc
   infer_utils.log_prob_mcmc_loop
   infer_utils.get_mcmc_samples
   infer_utils.get_kappa_log_weights
   infer_utils.get_kappa_log_weights_vectorized
   infer_utils.resample_from_pdf
   infer_utils.fit_kde_on_weighted_samples
   infer_utils.resample_from_samples



Attributes
~~~~~~~~~~

.. autoapisummary::

   infer_utils.DEBUG


.. py:data:: DEBUG
   :value: False


.. py:function:: get_normal_logpdf(mu, log_sigma, x, bounds_lower=-np.inf, bounds_upper=np.inf)

   Evaluate the log kappa likelihood of the test set, log p(k_j|Omega), exactly

   Note
   ----
   Only a normal likelihood is supported for now.

   Returns
   -------
   np.ndarray or float
       Log PDF, with shape broadcast across mu, log_sigma, and x


.. py:function:: run_mcmc(log_prob, log_prob_kwargs, p0, n_run, n_burn, chain_path, run_name='mcmc', n_walkers=100, plot_chain=True, clear=False, n_cores=None)

   Run MCMC sampling

   Parameters
   ----------
   p0 : np.array of shape `[n_walkers, n_dim]`

   n_run : int

   n_burn : int

   chain_path : str or os.PathLike

   n_walkers : int

   plot_chain : bool


.. py:function:: get_chain_plot(samples, out_path='mcmc_chain.png')

   Plot MCMC chain

   Note
   ----
   Borrowed from https://emcee.readthedocs.io/en/stable/tutorials/line/
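
For reference, the density evaluated by `get_normal_logpdf` can be sketched as a normal log-PDF parameterized by (mu, log_sigma) and renormalized to the interval [bounds_lower, bounds_upper]. This is a minimal sketch under that assumption, not the module's implementation; the helper name `truncated_normal_logpdf` is hypothetical, and only NumPy and the standard library are used.

```python
import math

import numpy as np


def truncated_normal_logpdf(mu, log_sigma, x, bounds_lower=-np.inf, bounds_upper=np.inf):
    """Hypothetical: log PDF of N(mu, exp(log_sigma)**2) truncated to [bounds_lower, bounds_upper].

    mu, log_sigma, and x broadcast against each other like NumPy ufunc arguments.
    """
    mu, log_sigma, x = np.broadcast_arrays(
        np.asarray(mu, dtype=float),
        np.asarray(log_sigma, dtype=float),
        np.asarray(x, dtype=float),
    )
    sigma = np.exp(log_sigma)
    z = (x - mu) / sigma
    # Untruncated normal log density
    log_pdf = -0.5 * z**2 - log_sigma - 0.5 * np.log(2.0 * np.pi)
    # Standard normal CDF, vectorized over math.erf (erf(+/-inf) = +/-1)
    std_cdf = np.vectorize(lambda t: 0.5 * (1.0 + math.erf(t / math.sqrt(2.0))), otypes=[float])
    a = (bounds_lower - mu) / sigma
    b = (bounds_upper - mu) / sigma
    # Renormalize by the probability mass that lies inside the bounds
    log_pdf = log_pdf - np.log(std_cdf(b) - std_cdf(a))
    # Zero density (log = -inf) outside the bounds
    outside = (x < bounds_lower) | (x > bounds_upper)
    return np.where(outside, -np.inf, log_pdf)
```

With no bounds this reduces to the ordinary normal log-PDF; truncating at the mean, e.g. `bounds_lower=mu`, adds log 2 inside the support.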

.. py:function:: get_log_p_k_given_omega_int_kde(k_train, k_bnn, kwargs=None)

   Evaluate the log likelihood, log p(k|Omega_int), on the BNN kappa samples of test sightlines, using a kernel density estimate (KDE) fit to the training kappa

   Parameters
   ----------
   k_train : np.array of shape `[n_train]`
       kappa in the training set

   k_bnn : np.array of shape `[n_test, n_samples]`
       BNN kappa samples for `n_test` sightlines

   kwargs : dict
       currently unused, placeholder for symmetry with the analytic version

   Returns
   -------
   np.array of shape `[n_test, n_samples]`
       log p(k|Omega_int)


.. py:function:: get_log_p_k_given_omega_int_analytic(k_train, k_bnn, interim_pdf_func)

   Evaluate the log likelihood, log p(k|Omega_int), on the BNN kappa samples of test sightlines, using the analytic interim prior PDF given by `interim_pdf_func`

   Parameters
   ----------
   k_train : np.array of shape `[n_train]`
       kappa in the training set. Unused.

   k_bnn : np.array of shape `[n_test, n_samples]`
       BNN kappa samples for `n_test` sightlines

   interim_pdf_func : callable
       function that evaluates the PDF of the interim prior

   Returns
   -------
   np.array of shape `[n_test, n_samples]`
       log p(k|Omega_int)


.. py:function:: get_omega_post(k_bnn, log_p_k_given_omega_int, mcmc_kwargs, bounds_lower, bounds_upper)

   Sample from p(Omega|{d}) using MCMC

   Parameters
   ----------
   k_bnn : np.array of shape `[n_test, n_samples]`
       BNN samples for `n_test` sightlines

   log_p_k_given_omega_int : np.array of shape `[n_test, n_samples]`
       log p(k_bnn|Omega_int)


.. py:function:: get_omega_post_loop(k_samples_list, log_p_k_given_omega_int_list, mcmc_kwargs, bounds_lower, bounds_upper)

   Sample from p(Omega|{d}) using MCMC

   Parameters
   ----------
   k_samples_list : list
       Each element is the array of samples for a sightline, so the list has length `n_test`

   log_p_k_given_omega_int_list : list
       Each element is the array of log p(k_samples|Omega_int) for a sightline, so the list has length `n_test`
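
`get_omega_post` and `get_omega_post_loop` sample p(Omega|{d}) from BNN samples obtained under the interim prior Omega_int. A common way to build such an objective, assumed here rather than confirmed by the docstrings, is the importance-sampling estimate log p({d}|Omega) ~ sum_j log[(1/n_samples) sum_i p(k_ji|Omega) / p(k_ji|Omega_int)], up to a constant. A minimal NumPy sketch; `log_mean_exp` and `log_prob_omega_sketch` are hypothetical names:

```python
import numpy as np


def log_mean_exp(a, axis=-1):
    """Numerically stable log(mean(exp(a))) along `axis`."""
    a_max = np.max(a, axis=axis, keepdims=True)
    # Guard against rows that are entirely -inf (result stays -inf instead of nan)
    a_max = np.where(np.isfinite(a_max), a_max, 0.0)
    out = np.log(np.mean(np.exp(a - a_max), axis=axis, keepdims=True)) + a_max
    return np.squeeze(out, axis=axis)


def log_prob_omega_sketch(log_p_k_given_omega, log_p_k_given_omega_int):
    """Hypothetical importance-sampling estimate of log p({d}|Omega), up to a constant.

    Both inputs have shape [n_test, n_samples] and hold log p(k|Omega) and
    log p(k|Omega_int) evaluated at the BNN kappa samples.
    """
    log_ratio = log_p_k_given_omega - log_p_k_given_omega_int
    # Average the density ratio over samples within each sightline, then sum over sightlines
    return float(np.sum(log_mean_exp(log_ratio, axis=1)))
```

A full MCMC objective would typically also add a log prior on Omega, for example returning -inf when omega falls outside bounds_lower and bounds_upper.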

.. py:function:: log_prob_mcmc(omega, log_p_k_given_omega_func, log_p_k_given_omega_int)

   Evaluate the MCMC objective

   Parameters
   ----------
   omega : list
       Current MCMC sample of [mu, log_sigma] = Omega

   log_p_k_given_omega_func : callable
       function that returns log p(k|Omega) of shape [n_test, n_samples] for the given omega, with k fixed to the BNN samples

   log_p_k_given_omega_int : np.ndarray
       Values of log p(k|Omega_int) of shape [n_test, n_samples], with k fixed to the BNN samples

   Returns
   -------
   float
       Value of the MCMC objective at omega


.. py:function:: log_prob_mcmc_loop(omega, log_p_k_given_omega_func_list, log_p_k_given_omega_int_list)

   Evaluate the MCMC objective

   Parameters
   ----------
   omega : list
       Current MCMC sample of [mu, log_sigma] = Omega

   log_p_k_given_omega_func_list : list of callable
       List of functions that return log p(k|Omega) of shape [n_samples] for the given omega, with k fixed to the posterior samples

   log_p_k_given_omega_int_list : list of np.ndarray
       List of values of log p(k|Omega_int) of shape [n_samples], with k fixed to the posterior samples

   Returns
   -------
   float
       Value of the MCMC objective at omega


.. py:function:: get_mcmc_samples(chain_path, chain_kwargs)

   Load the samples from a saved MCMC run

   Parameters
   ----------
   chain_path : str
       Path to the stored chain

   chain_kwargs : dict
       Options for chain postprocessing, including flat, thin, and discard

   Returns
   -------
   np.array of shape `[n_omega, 2]`
       omega samples from the MCMC chain


.. py:function:: get_kappa_log_weights(k_bnn, log_p_k_given_omega_int, omega_post_samples=None)

   Evaluate the log weights used to reweight individual kappa posteriors

   Parameters
   ----------
   k_bnn : np.ndarray of shape `[n_samples]`
       BNN posterior samples

   log_p_k_given_omega_int : np.ndarray of shape `[n_samples]`
       Log likelihood of the BNN kappa samples given the interim prior

   omega_post_samples : np.ndarray, optional
       Omega posterior samples used as the prior to apply. Should be np.array of shape `[n_omega, 2]`. If None, only division by the interim prior is done.

   Returns
   -------
   np.ndarray
       log weights evaluated at k_bnn
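
The reweighting in `get_kappa_log_weights` can be read as an importance weight: the population prior, averaged over the Omega posterior samples, divided by the interim prior. The sketch below assumes an untruncated normal population prior and that the columns of `omega_post_samples` are (mu, log_sigma); it ignores any truncation bounds, and `normal_logpdf` and `kappa_log_weights_sketch` are hypothetical names, not the module's functions.

```python
import numpy as np


def normal_logpdf(mu, log_sigma, x):
    """Untruncated normal log PDF; broadcasts over all arguments."""
    return -0.5 * ((x - mu) / np.exp(log_sigma)) ** 2 - log_sigma - 0.5 * np.log(2.0 * np.pi)


def kappa_log_weights_sketch(k_bnn, log_p_k_given_omega_int, omega_post_samples=None):
    """Hypothetical log importance weights for one sightline's kappa samples.

    k_bnn, log_p_k_given_omega_int: shape [n_samples]
    omega_post_samples: shape [n_omega, 2] with columns (mu, log_sigma), or None
    """
    k_bnn = np.asarray(k_bnn, dtype=float)
    if omega_post_samples is None:
        # Only divide out the interim prior
        return -log_p_k_given_omega_int
    mu = omega_post_samples[:, 0:1]  # [n_omega, 1]
    log_sigma = omega_post_samples[:, 1:2]  # [n_omega, 1]
    log_p = normal_logpdf(mu, log_sigma, k_bnn[np.newaxis, :])  # [n_omega, n_samples]
    # Marginalize the population prior over the Omega posterior samples (stable log-mean-exp)
    log_p_max = np.max(log_p, axis=0)
    log_prior = np.log(np.mean(np.exp(log_p - log_p_max), axis=0)) + log_p_max
    return log_prior - log_p_k_given_omega_int
```

If the Omega posterior collapses onto the interim prior, all log weights are zero and the BNN samples are left unchanged.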

.. py:function:: get_kappa_log_weights_vectorized(k_bnn, omega_post_samples, log_p_k_given_omega_int)

   Evaluate the log weights used to reweight individual kappa posteriors, vectorized over sightlines

   Parameters
   ----------
   k_bnn : np.array of shape `[n_test, n_samples]`

   omega_post_samples : np.array of shape `[n_omega, 2]`

   log_p_k_given_omega_int : np.array of shape `[n_test, n_samples]`


.. py:function:: resample_from_pdf(grid, log_pdf, n_samples)


.. py:function:: fit_kde_on_weighted_samples(samples, weights=None)

   Fit a KDE on weighted samples


.. py:function:: resample_from_samples(samples, weights, n_resamples, plot_path=None)

   Resample from a distribution defined by weighted samples

   Parameters
   ----------
   samples : np.ndarray

   weights : np.ndarray

   n_resamples : int

   plot_path : str, optional
       Path for the plot illustrating the KDE fit
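
`resample_from_samples` draws from a distribution defined by weighted samples and refers to a KDE fit. One way to do this, assumed here rather than taken from the module, is to sample from a weighted Gaussian KDE: pick a kernel center with probability proportional to its weight, then add Gaussian kernel noise. The function name, the `seed` argument, and the Scott's-rule bandwidth based on the Kish effective sample size are all assumptions of this sketch.

```python
import numpy as np


def resample_weighted_kde_sketch(samples, weights, n_resamples, seed=None):
    """Hypothetical: draw n_resamples values from a Gaussian KDE fit to 1D weighted samples."""
    rng = np.random.default_rng(seed)
    samples = np.asarray(samples, dtype=float)
    w = np.asarray(weights, dtype=float)
    w = w / np.sum(w)
    # Kish effective sample size of the weighted set
    n_eff = 1.0 / np.sum(w**2)
    # Weighted standard deviation, then Scott's rule for one dimension
    mean = np.sum(w * samples)
    std = np.sqrt(np.sum(w * (samples - mean) ** 2))
    bandwidth = std * n_eff ** (-1.0 / 5.0)
    # Choose kernel centers by weight, then perturb each with Gaussian kernel noise
    idx = rng.choice(len(samples), size=n_resamples, p=w)
    return samples[idx] + rng.normal(0.0, bandwidth, size=n_resamples)
```

Weights are often available in log form, e.g. from `get_kappa_log_weights`; exponentiating after subtracting the maximum, `np.exp(log_w - log_w.max())`, avoids overflow without changing the normalized weights.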