================================== Hierarchical Inference with N2JNet ================================== Once our `N2JNet` is trained, we want to generate convergence predictions on individual sightlines and combine those predictions into a hierarchical inference of the population's convergence statistics. We proceed similarly as with training, e.g. :: $python infer.py nersc_config.yml The "inference" section of the `nersc_config.yml` config file in the repo provides an example of how to configure inference. We take a look at it here. First, we configure the properties of `InferenceManager`, a class that manages inference, with the device type, output directory (where all the inference results will be stored), and sampling seed. :: inference_manager: device_type: 'cpu' out_dir: '/global/cscratch1/sd/jwp/n2j/apj_v4/seed1/inference_E10_N1000' seed: 1025 Then we specify the test healpixes as well as the subsampling distribution for the test sets. For instance, if we had `n_subsample_test` of 1000 and `dist_name` of `'norm'` with `dist_kwargs` such that `loc` and `scale` were 0.04 and 0.005, respectively, we subsample 1,000 sightlines with a Gaussian distribution between with mean 0.04 and standard deviation 0.005. You can use any distribution supported by `scipy.stats`. The distributional parameters are the true hyperparameters governing the test population, which our hierarchical inference scheme will attempt to retrieve. :: test_data: seed: 1 batch_size: 1000 test_hp: [10326, 9686] n_test: [50000, 50000] n_subsample_test: 1000 dist_name: 'norm' dist_kwargs: loc: 0.04 scale: 0.005 The summary statistics matching serves as a useful comparison. We provide a grid of closeness thresholds and a minimum number of matches for a threshold to be considered valid. In the case below, we choose the smallest threshold that resulted in more than 200 matches. :: summary_stats: thresholds: N: [0, 1, 2, 4, 8, 16, 32, 64, 128] N_inv_dist: [0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512] min_matches: 200 # Replace with your own The checkpoint path of the trained `N2JNet` must be passed. :: checkpoint_path: '/global/cscratch1/sd/jwp/n2j/apj_v4/seed1/N2JNet_epoch=118_10-25-2021_08:10.mdl' If we want to run hierarchical inference, we set `run_mcmc` as `True`. If we want to stop with generating individual predictions, this can be set as `False`. We can configure the MCMC, such as the number of "run" iterations, number of "burn" iterations, number of walkers, and the number of CPU cores. :: run_mcmc: True extra_mcmc_kwargs: n_run: 50 n_burn: 20 n_walkers: 10 plot_chain: True clear: True n_cores: 1