"""Summary stats baseline computations
"""
import os
import os.path as osp
import copy
import numpy as np
from scipy import stats
import pandas as pd
from tqdm import tqdm
import torch
from torch_scatter import scatter_add
import n2j.inference.infer_utils as iutils
def get_number_counts(x, batch_indices):
    """Get the unweighted number counts

    Parameters
    ----------
    x : torch.tensor
        Input features of shape [n_nodes, n_features] for a given batch.
        Only its first dimension (the number of nodes) is used.
    batch_indices : torch.tensor
        Batch indices of shape [n_nodes,] for a given batch

    Returns
    -------
    np.ndarray
        Number of nodes assigned to each batch index

    """
    # Allocate the per-node unit weights on the same device as the indices;
    # a CPU-only `torch.ones` cannot be scattered with accelerator indices.
    ones = torch.ones(x.shape[0], device=batch_indices.device)
    counts = scatter_add(ones, batch_indices)
    # `.numpy()` only works on CPU tensors that do not track gradients
    return counts.detach().cpu().numpy()
def get_inv_dist_number_counts(x, batch_indices, pos_indices, eps=1.e-5):
    """Get the inverse-dist weighted number counts

    Parameters
    ----------
    x : torch.tensor
        Input features of shape [n_nodes, n_features] for a given batch
    batch_indices : torch.tensor
        Batch indices of shape [n_nodes,] for a given batch
    pos_indices : list
        List of the two indices corresponding to ra, dec in x
    eps : float, optional
        Small constant added to the distance before inversion so that a
        node sitting exactly at the origin does not produce an infinite
        weight. Default: 1.e-5

    Returns
    -------
    np.ndarray
        Sum of inverse distances of the nodes assigned to each batch index

    """
    # Euclidean distance of every node from the origin of the position
    # columns
    dist = torch.sum(x[:, pos_indices]**2.0, dim=1)**0.5  # [n_nodes,]
    weights = 1.0/(dist + eps)
    weighted_N = scatter_add(weights, batch_indices)
    # `.numpy()` only works on CPU tensors that do not track gradients
    return weighted_N.detach().cpu().numpy()
class SummaryStats:
    """Accumulator of per-sightline summary statistics

    Holds a dictionary `stats` with the keys 'N' (unweighted number counts)
    and 'N_inv_dist' (inverse-distance weighted number counts), each an
    array of length `n_data` that is filled batch by batch via `update`.

    """

    def __init__(self, n_data, pos_indices=None):
        """Summary stats calculator

        Parameters
        ----------
        n_data : int
            Size of dataset
        pos_indices : list, optional
            Indices of `sub_features` corresponding to positions.
            By default, assumed to be the first two indices, `[0, 1]`.

        """
        # A fresh list per instance; a list literal in the signature would
        # be shared (and mutable) across all instances.
        self.pos_indices = [0, 1] if pos_indices is None else pos_indices
        # Largest batch seen so far, used to locate the i-th batch in the
        # flat arrays even when the final batch is shorter
        self._batch_size = 0
        # Init stats
        # TODO: don't hold all elements in memory, append to file
        # in chunks
        stats = dict()
        stats['N_inv_dist'] = np.zeros(n_data)
        stats['N'] = np.zeros(n_data)
        self.stats = stats

    def update(self, batch, i):
        """Update `stats` for a new batch

        Parameters
        ----------
        batch : object
            New batch of data exposing the node features as `batch.x` and
            the node-to-example assignment as `batch.batch`
        i : int
            index indicating that the batch is the i-th batch

        Note
        ----
        All batches are assumed to have the same size except possibly the
        last one, and a full-size batch must have been seen before a
        shorter one (as is the case when iterating in order).

        """
        x = batch.x
        batch_indices = batch.batch
        N = get_number_counts(x, batch_indices)
        N_inv_dist = get_inv_dist_number_counts(x, batch_indices,
                                                self.pos_indices)
        self._store_batch(i, {'N': N, 'N_inv_dist': N_inv_dist})

    def _store_batch(self, i, batch_stats):
        """Write the stats of the i-th batch into their slots in `stats`

        Parameters
        ----------
        i : int
            index indicating that the batch is the i-th batch
        batch_stats : dict
            Maps each stats name to the array of values for this batch. All
            arrays must have the same length.

        """
        n_in_batch = len(next(iter(batch_stats.values())))
        # The offset must be based on the nominal batch size, not on the
        # size of the current batch: a shorter final batch would otherwise
        # be written on top of earlier results.
        previous = getattr(self, '_batch_size', 0)
        self._batch_size = max(previous, n_in_batch)
        start = i*self._batch_size
        for name, values in batch_stats.items():
            self.stats[name][start: start + n_in_batch] = values

    def set_stats(self, stats_path):
        """Loads a previously stored stats

        Parameters
        ----------
        stats_path : str
            Path to the .npy file of the stats dictionary

        """
        # NOTE: allow_pickle=True unpickles arbitrary objects. Only load
        # files from a trusted source (e.g. written by `export_stats`).
        stats = np.load(stats_path, allow_pickle=True).item()
        self.stats = stats

    def export_stats(self, stats_path):
        """Exports the stats attribute to disk as a npy file

        Parameters
        ----------
        stats_path : str
            Path to the .npy file of the stats dictionary

        """
        np.save(stats_path, self.stats, allow_pickle=True)
class Matcher:
    """Matcher of summary statistics between two datasets, train and test

    For every test sightline, training examples whose summary statistic
    lies within a threshold of the test value are "accepted", and the
    training targets of the accepted examples are stored and summarized.

    """

    # Columns that are always present in the overview table, in output
    # order. Any additional per-row keys are appended after these.
    _OVERVIEW_COLUMNS = ['los_i',
                         'summary_stats_name',
                         'threshold',
                         'n_matches',
                         'med',
                         'plus_1sig',
                         'minus_1sig',
                         'logp',
                         'mad',
                         'mae',
                         'is_optimal']

    def __init__(self, train_stats, test_stats,
                 train_y, out_dir, test_y=None):
        """Matcher of summary statistics between two datasets, train
        and test

        Parameters
        ----------
        train_stats : SummaryStats instance
        test_stats : SummaryStats instance
        train_y : np.ndarray
            Training targets, indexed by the boolean match mask
        out_dir : str
            Output dir for matched data products. Created if missing,
            along with an `orig_samples` subdirectory.
        test_y : np.ndarray, optional
            Test targets, read as `test_y[i, 0]` for the i-th sightline.
            If given, truth-comparison metrics are added to the overview.

        """
        self.train_stats = train_stats
        self.test_stats = test_stats
        self.train_y = train_y
        self.test_y = test_y
        self.out_dir = out_dir
        os.makedirs(self.out_dir, exist_ok=True)
        self.orig_samples_dir = osp.join(self.out_dir, 'orig_samples')
        os.makedirs(self.orig_samples_dir, exist_ok=True)
        self.overview_path = osp.join(self.out_dir, 'overview.csv')

    def match_summary_stats(self, thresholds, interim_pdf_func=None,
                            min_matches=1000, k_max=np.inf):
        """Match summary stats between train and test

        Saves the accepted samples for every (sightline, summary stat,
        threshold) combination and writes an overview table with one row
        per combination to `self.overview_path`.

        Parameters
        ----------
        thresholds : dict
            Matching thresholds for summary stats
            Keys should be one or both of 'N' and 'N_inv_dist'.
        interim_pdf_func : callable, optional
            Interim prior PDF with which to reweight the samples
        min_matches : int, optional
            Minimum number of matches a threshold needs in order to be
            eligible as the optimal threshold. Default: 1000
        k_max : float, optional
            Accepted samples greater than or equal to this value are
            discarded. Default: np.inf

        Raises
        ------
        ValueError
            If no threshold has at least `min_matches` matches for some
            sightline and summary stat, or if the reweighting fails.

        """
        ss_names = list(thresholds.keys())
        n_test = len(self.test_stats.stats[ss_names[0]])
        # Rows are collected in a plain list and turned into a table once
        # at the end; growing a DataFrame row by row is quadratic and
        # `DataFrame.append` no longer exists in pandas >= 2.0.
        all_rows = []
        for i in tqdm(range(n_test), desc="matching"):
            for s in ss_names:
                test_x = self.test_stats.stats[s][i]
                optimal_crit = np.empty(len(thresholds[s]))
                rows_for_s = []
                for t_idx, t in enumerate(thresholds[s]):
                    # TODO: do this in chunks
                    accepted, _ = match(self.train_stats.stats[s],
                                        test_x,
                                        self.train_y,
                                        t)
                    accepted = accepted[accepted < k_max]
                    n_matches = len(accepted)
                    np.save(osp.join(self.orig_samples_dir,
                                     f'matched_k_los_{i}_ss_{s}_{t:.3f}.npy'),
                            accepted)
                    # Add descriptive stats to overview table
                    row = dict(los_i=i,
                               summary_stats_name=s,
                               threshold=t,
                               test_x=test_x,
                               n_matches=n_matches)
                    optimal_crit[t_idx] = n_matches
                    # A spread cannot be estimated from fewer than two
                    # samples, so such rows keep only the basic columns
                    if n_matches > 1:
                        row.update(self._summarize_matches(accepted, i, s, t,
                                                           interim_pdf_func))
                    # Each ss name and threshold combo gets a row
                    # Wait until all thresholds are collected to append
                    rows_for_s.append(row)
                # Determine optimal threshold
                try:
                    is_optimal = get_optimal_threshold(thresholds[s],
                                                       optimal_crit,
                                                       min_matches=min_matches)
                except ValueError as err:
                    print("Test x:", test_x)
                    print("Summary stat: ", s)
                    print("Thresholds: ", thresholds[s])
                    print("Matches: ", optimal_crit)
                    raise ValueError("Can't find the optimal threshold!") from err
                # Record whether each row was "optimal"
                # There's only one optimal row for a given ss_name
                for r, flag in zip(rows_for_s, is_optimal):
                    r.update(is_optimal=flag)
                all_rows.extend(rows_for_s)
        overview = self._build_overview(all_rows)
        print(f"Saving overview table at {self.overview_path}...")
        overview.to_csv(self.overview_path, index=False)

    def _summarize_matches(self, accepted, i, s, t, interim_pdf_func):
        """Fit, optionally reweight, and summarize the accepted samples

        Parameters
        ----------
        accepted : np.ndarray
            Accepted samples, at least two of them
        i : int
            ID of sightline, also used as the random seed
        s : str
            Summary stats name
        t : float
            Matching threshold
        interim_pdf_func : callable or None
            Interim prior PDF with which to reweight the samples

        Returns
        -------
        dict
            Overview columns 'med', 'plus_1sig', 'minus_1sig', 'mad' and,
            if `self.test_y` is set, 'logp', 'mae', 'true_k'

        """
        # Fit normal on accepted samples, to resample from
        norm_obj = stats.norm(loc=np.median(accepted),
                              scale=stats.median_abs_deviation(accepted,
                                                               scale='normal'))
        rng = np.random.RandomState(i)
        accepted_norm = norm_obj.rvs(20000, random_state=rng)
        if interim_pdf_func is None:
            resamples = accepted_norm  # do not weight
        else:
            inv_prior = 1.0/interim_pdf_func(accepted_norm)
            # Reweight normal samples
            try:
                resamples = iutils.resample_from_samples(accepted_norm,
                                                         inv_prior,
                                                         n_resamples=20000,
                                                         plot_path=None)
            except ValueError:
                print("Reweighting normal")
                print(f"Sightline {i}")
                print("Accepted samples were of shape", accepted.shape)
                print("Threshold was", t)
                # There are no valid resamples to continue with; carrying
                # on would either fail on an undefined name or silently
                # reuse the samples of a previous iteration.
                raise
            resamples = resamples.squeeze()  # [n_resamples]
            np.save(osp.join(self.out_dir,
                             f'matched_resampled_los_{i}_ss_{s}_{t:.3f}.npy'),
                    resamples)
        # Computing summary stats metrics
        lower, med, upper = np.quantile(resamples,
                                        [0.5-0.34, 0.5, 0.5+0.34])
        summary = dict(med=med,
                       plus_1sig=upper-med,
                       minus_1sig=med-lower,
                       mad=stats.median_abs_deviation(resamples,
                                                      scale='normal'))
        # Comparison with truth, if available
        if self.test_y is not None:
            kde = stats.gaussian_kde(resamples, bw_method='scott')
            true_k = self.test_y[i, 0]
            summary.update(logp=kde.logpdf(true_k).item(),
                           mae=np.median(np.abs(resamples - true_k)),
                           true_k=true_k)
        return summary

    @classmethod
    def _build_overview(cls, rows):
        """Assemble the overview table from a list of row dictionaries

        Parameters
        ----------
        rows : list of dict
            One dictionary per (sightline, summary stat, threshold) combo

        Returns
        -------
        pd.DataFrame
            Table with the standard overview columns first, followed by
            any other keys in the order they were first encountered.
            Values missing from a row are NaN.

        """
        overview = pd.DataFrame(rows)
        extra = [c for c in overview.columns
                 if c not in cls._OVERVIEW_COLUMNS]
        return overview.reindex(columns=cls._OVERVIEW_COLUMNS + extra)

    def get_samples(self, idx, ss_name, threshold=None):
        """Get the pre-weighting (raw) accepted samples

        Parameters
        ----------
        idx : int
            ID of sightline
        ss_name : str
            Summary stats name
        threshold : float, optional
            Matching threshold. If None, use the optimal threshold
            recorded in the overview table.
            Default: None

        Returns
        -------
        np.ndarray
            Samples of shape `[n_matches]`

        """
        if threshold is None:
            # Default to optimal threshold
            overview = self.get_overview_table()
            crit = np.logical_and(
                np.logical_and(overview['los_i'] == idx,
                               overview['summary_stats_name'] == ss_name),
                overview['is_optimal'])
            # `.item()` requires exactly one matching row
            threshold = overview[crit]['threshold'].item()
        path = osp.join(self.orig_samples_dir,
                        f'matched_k_los_{idx}_ss_{ss_name}_{threshold:.3f}.npy')
        samples = np.load(path)
        return samples

    def get_overview_table(self):
        """Read the overview table written by `match_summary_stats`

        Returns
        -------
        pd.DataFrame
            The overview table

        Raises
        ------
        OSError
            If the table has not been generated yet

        """
        if not osp.exists(self.overview_path):
            raise OSError("Table doesn't exist. Please generate it first.")
        return pd.read_csv(self.overview_path, index_col=None)
def get_optimal_threshold(thresholds, n_matches, min_matches=1000):
    """Flag the smallest threshold that reaches a minimum number of matches

    Parameters
    ----------
    thresholds : array-like
        Candidate thresholds
    n_matches : array-like
        Number of matches obtained with each threshold
    min_matches : int
        Number of matches below which a threshold is not eligible

    Returns
    -------
    np.ndarray
        Boolean array with the same length as `thresholds`, True only at
        the position of the selected threshold

    Raises
    ------
    ValueError
        If no threshold is eligible

    """
    n_thresholds = len(thresholds)
    # Work on a float copy so that ineligible entries can be blanked out
    candidates = np.array(thresholds).astype(float)
    too_few = np.array(n_matches) < min_matches
    candidates[too_few] = np.nan
    if np.isnan(candidates).all():
        raise ValueError("No threshold with sufficient matches.")
    # NaN entries are ignored when looking for the smallest threshold
    best = np.nanargmin(candidates)
    flags = np.zeros(n_thresholds).astype(bool)
    flags[best] = True
    return flags
def match(train_x, test_x, train_y, threshold):
    """Select the train targets whose summary stat is close to the test one

    Parameters
    ----------
    train_x : np.ndarray
        train summary stats
    test_x : float
        test summary stats
    train_y : np.ndarray
        train target values; the last axis of the selected rows must have
        size one, as it is squeezed out
    threshold : float
        closeness threshold matching is based on; a train example passes
        if its absolute difference from `test_x` is strictly smaller

    Returns
    -------
    tuple
        the accepted samples of train_y with the last axis removed,
        followed by the boolean mask over the train examples

    """
    deviation = np.abs(train_x - test_x)
    is_passing = deviation < threshold
    selected = train_y[is_passing]
    return selected.squeeze(-1), is_passing