Source code for n2j.trainval_data.graphs.cosmodc2_graph

"""Training input graph X created from the postprocessed CosmoDC2 catalog

"""

import os
import os.path as osp
import multiprocessing
from functools import cached_property
import bisect
import numpy as np
import pandas as pd
from scipy.spatial import cKDTree
import scipy.stats
from tqdm import tqdm
import torch
from torch.utils.data.dataset import ConcatDataset
from torch_geometric.data import DataLoader
from n2j.trainval_data.graphs.base_graph import BaseGraph, Subgraph
from n2j.trainval_data.utils import coord_utils as cu
from n2j.trainval_data.utils.running_stats import RunningStats
from torch.utils.data.sampler import SubsetRandomSampler  # WeightedRandomSampler,


class CosmoDC2Graph(ConcatDataset):
    """Concatenation of multiple CosmoDC2GraphHealpix instances,
    with an added data transformation functionality

    """
    def __init__(self, in_dir, healpixes, raytracing_out_dirs, aperture_size,
                 n_data, features, subsample_pdf_func=None,
                 n_subsample=None, subsample_with_replacement=True,
                 stop_mean_std_early=False, n_cores=20, num_workers=4,
                 out_dir=None, seed=123):
        """Build one CosmoDC2GraphHealpix per healpix and concatenate them.

        Parameters
        ----------
        in_dir : str
            Input root directory, passed to each CosmoDC2GraphHealpix
        healpixes : list
            Healpix IDs; one dataset is created per entry
        raytracing_out_dirs : list
            Raytracing label directory per healpix, aligned with `healpixes`
        aperture_size : float
            Radius of aperture in arcmin, passed to each healpix dataset
        n_data : list
            Number of sightlines per healpix, aligned with `healpixes`
        features : list
            Input features per node
        subsample_pdf_func : callable, optional
            Function that evaluates the target subsampling PDF
        n_subsample : int, optional
            How many examples to subsample, to form the final effective
            dataset size. Required if subsample_pdf_func is not None.
        subsample_with_replacement : bool, optional
            NOTE(review): currently not honored; subsampling is always done
            without replacement (``self.replace`` is False).
        stop_mean_std_early : bool, optional
            If True, stop accumulating statistics after roughly 100 batches
        n_cores : int, optional
            Number of processes each healpix dataset uses to generate graphs
        num_workers : int, optional
            Number of DataLoader workers used when computing statistics
        out_dir : str, optional
            Where processed graphs are stored. Defaults to `in_dir`.
        seed : int, optional
            Seed of the RNG used for subsampling (train and val/test)
        """
        self.stop_mean_std_early = stop_mean_std_early
        self.n_datasets = len(healpixes)
        self.n_cores = n_cores
        self.num_workers = num_workers
        self.subsample_pdf_func = subsample_pdf_func
        self.seed = seed
        if out_dir is None:
            out_dir = in_dir
        if self.subsample_pdf_func is not None:
            assert n_subsample is not None
        # Always define these so attribute access never raises
        self.n_subsample = n_subsample
        self.replace = False
        datasets = []
        for i, healpix in enumerate(healpixes):
            print(f"Appending healpix {healpix}")
            graph_hp = CosmoDC2GraphHealpix(healpix,
                                            in_dir,
                                            raytracing_out_dirs[i],
                                            aperture_size,
                                            n_data[i],
                                            features,
                                            n_cores=self.n_cores,
                                            out_dir=out_dir
                                            )
            datasets.append(graph_hp)
        ConcatDataset.__init__(self, datasets)
        # Optional callables applied in __getitem__; set externally
        self.transform_X_Y_local = None
        self.transform_Y = None

    def _draw_subsample_idx(self, y_values_orig, subsample_weight):
        """Draw `self.n_subsample` indices by importance resampling.

        Each example's probability is its target-PDF weight divided by a
        Gaussian KDE estimate of the original density of `y_values_orig`,
        so the drawn values approximately follow `subsample_pdf_func`.

        Parameters
        ----------
        y_values_orig : np.ndarray
            Original (unstandardized) kappa value per example
        subsample_weight : np.ndarray
            `subsample_pdf_func` evaluated at `y_values_orig`

        Returns
        -------
        list of int
            Selected example indices
        """
        rng = np.random.default_rng(self.seed)
        kde = scipy.stats.gaussian_kde(y_values_orig, bw_method='scott')
        p = subsample_weight/kde.pdf(y_values_orig)
        p /= np.sum(p)
        subsample_idx = rng.choice(np.arange(len(y_values_orig)),
                                   p=p, replace=self.replace,
                                   size=self.n_subsample)
        return subsample_idx.tolist()

    @cached_property
    def data_stats(self):
        """Statistics of the X, Y data used for standardizing

        Returns
        -------
        dict
            Keys X_mean, X_std, Y_mean, Y_std, Y_local_mean, Y_local_std,
            X_meta_mean, X_meta_std, y_weight (per-example inverse class
            frequency, or None when subsampling), subsample_idx (list of
            indices, or None when not subsampling), class_weight (per-bin
            inverse frequency, or None when subsampling).
        """
        loader_dict = dict(X=lambda b: b.x,  # node features x
                           Y_local=lambda b: b.y_local,  # node labels y_local
                           Y=lambda b: b.y,  # graph labels y
                           X_meta=lambda b: b.x_meta)  # graph-level meta
        rs = RunningStats(loader_dict)
        y_class_counts = 0  # [n_classes,] where n_classes = number of bins
        y_class = torch.zeros(len(self), dtype=torch.long)  # [n_train,]
        if self.subsample_pdf_func is None:
            subsample_weight = None
        else:
            subsample_weight = np.zeros(len(self))  # [n_train,]
        y_values_orig = np.zeros(len(self))
        batch_size = 2000 if self.n_cores < 8 else 10000
        dummy_loader = DataLoader(dataset=self,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=self.num_workers,
                                  drop_last=False)
        print("Generating standardizing metadata...")
        for i, b in enumerate(dummy_loader):
            # Update running stats for this new batch
            rs.update(b, i)
            # Update running bin count for kappa; merge bins above 3 into 3
            y_class_b = b.y_class
            y_class_b[y_class_b > 3] = 3
            y_class_counts += torch.bincount(y_class_b, minlength=4)[:4]
            # shuffle=False, so batch i covers this contiguous index range
            start, end = i*batch_size, (i+1)*batch_size
            y_class[start:end] = y_class_b
            # Log original kappa values
            k_values_orig_batch = b.y[:, 0].cpu().numpy()
            y_values_orig[start:end] = k_values_orig_batch
            # Compute subsampling weights
            if self.subsample_pdf_func is not None:
                subsample_weight[start:end] = self.subsample_pdf_func(k_values_orig_batch)
            if self.stop_mean_std_early and i > 100:
                break
        print("Y_mean without resampling: ", rs.stats['Y_mean'])
        print("Y_std without resampling: ", rs.stats['Y_var']**0.5)
        # Each bin is weighted by the inverse frequency
        class_weight = torch.sum(y_class_counts)/y_class_counts  # [n_bins,]
        y_weight = class_weight[y_class]  # [n_train]
        subsample_idx = None
        # Recompute mean, std if subsampling according to a distribution
        if self.subsample_pdf_func is not None:
            print("Re-generating standardizing metadata for subsampling dist...")
            # Re-initialize mean, std
            rs = RunningStats(loader_dict)
            # Define SubsetRandomSampler to follow dist in subsample_pdf_func
            mode = 'with' if self.replace else 'without'
            print(f"Subsampling {mode} replacement to follow provided subsample_pdf_func...")
            # See https://github.com/pytorch/pytorch/issues/11201
            torch.multiprocessing.set_sharing_strategy('file_system')
            subsample_idx = self._draw_subsample_idx(y_values_orig,
                                                     subsample_weight)
            sampler = SubsetRandomSampler(subsample_idx)
            sampling_loader = DataLoader(self,
                                         batch_size=batch_size,
                                         sampler=sampler,
                                         num_workers=self.num_workers,
                                         drop_last=False)
            for i, b in enumerate(sampling_loader):
                # Update running stats for this new batch
                rs.update(b, i)
                if self.stop_mean_std_early and i > 100:
                    break
            class_weight = None
            y_weight = None
            print("Y_mean with resampling: ", rs.stats['Y_mean'])
            print("Y_std with resampling: ", rs.stats['Y_var']**0.5)
            print("X_meta_mean with resampling: ", rs.stats['X_meta_mean'])
            print("X_meta_std with resampling: ", rs.stats['X_meta_var']**0.5)
        stats = dict(X_mean=rs.stats['X_mean'],
                     X_std=rs.stats['X_var']**0.5,
                     Y_mean=rs.stats['Y_mean'],
                     Y_std=rs.stats['Y_var']**0.5,
                     Y_local_mean=rs.stats['Y_local_mean'],
                     Y_local_std=rs.stats['Y_local_var']**0.5,
                     X_meta_mean=rs.stats['X_meta_mean'],
                     X_meta_std=rs.stats['X_meta_var']**0.5,
                     y_weight=y_weight,  # [n_train,] or None
                     subsample_idx=subsample_idx,
                     class_weight=class_weight,  # [n_classes,] or None
                     )
        return stats

    @cached_property
    def data_stats_valtest(self):
        """Statistics of the X, Y data on validation set used for
        resampling to mimic training dist. Mean, std computation skipped.

        Returns
        -------
        dict
            ``{'subsample_idx': list of int}``
        """
        print("Computing resampling stats for val/test set...")
        # If subsample_pdf_func is None, don't need this attribute
        assert self.subsample_pdf_func is not None
        assert self.n_subsample is not None
        B = 1000
        dummy_loader = DataLoader(self,  # val_dataset
                                  batch_size=B,
                                  shuffle=False,
                                  num_workers=self.num_workers,
                                  drop_last=False)
        torch.multiprocessing.set_sharing_strategy('file_system')
        y_values_orig = np.zeros(len(self))  # [n_val,]
        subsample_weight = np.zeros(len(self))  # [n_val,]
        # Evaluate target density on all validation examples
        for i, b in enumerate(dummy_loader):
            # Log original kappa values
            k_batch = b.y[:, 0].cpu().numpy()
            y_values_orig[i*B:(i+1)*B] = k_batch
            # Compute subsampling weights
            subsample_weight[i*B:(i+1)*B] = self.subsample_pdf_func(k_batch)
        subsample_idx = self._draw_subsample_idx(y_values_orig,
                                                 subsample_weight)
        return dict(subsample_idx=subsample_idx)

    def __getitem__(self, idx):
        """Fetch example `idx` across the concatenated healpix datasets,
        applying `transform_X_Y_local` and `transform_Y` when set.

        Raises
        ------
        ValueError
            If a negative `idx` has absolute value larger than the length
        """
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed"
                                 " dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        data = self.datasets[dataset_idx][sample_idx]
        if self.transform_X_Y_local is not None:
            data.x, data.y_local, data.x_meta = self.transform_X_Y_local(data.x,
                                                                         data.y_local,
                                                                         data.x_meta)
        if self.transform_Y is not None:
            data.y = self.transform_Y(data.y)
        return data
class CosmoDC2GraphHealpix(BaseGraph):
    """Set of graphs representing a single healpix of the CosmoDC2 field

    """
    columns = ['ra', 'dec', 'galaxy_id', 'redshift']
    columns += ['ra_true', 'dec_true', 'redshift_true']
    columns += ['ellipticity_1_true', 'ellipticity_2_true']
    columns += ['bulge_to_total_ratio_i']
    columns += ['ellipticity_1_bulge_true', 'ellipticity_1_disk_true']
    columns += ['ellipticity_2_bulge_true', 'ellipticity_2_disk_true']
    columns += ['size_bulge_true', 'size_disk_true', 'size_true']
    columns += ['mag_{:s}_lsst'.format(b) for b in 'ugrizY']

    # Catalog columns that _save_graph_to_disk reads besides `features`
    _required_columns = ['halo_mass', 'stellar_mass', 'ra_true', 'dec_true',
                         'mag_i_lsst', 'redshift_true']

    def __init__(self, healpix, in_dir, raytracing_out_dir,
                 aperture_size, n_data, features, n_cores=20,
                 out_dir=None,
                 debug=False,):
        """Graph dataset for a single healpix

        Parameters
        ----------
        healpix : int
            Healpix ID of NSIDE=32 from CosmoDC2
        in_dir : str
            Directory from which to read input. Catalogs for this healpix
            should be placed in `in_dir/cosmodc2_{healpix}/raw`
        raytracing_out_dir : str
            Directory containing the raytraced labels, which should live
            in `raytracing_out_dir/Y_{healpix}`
        aperture_size : float
            Radius of aperture in arcmin
        n_data : int
            Number of sightlines
        features : list
            Input features per node
        n_cores : int, optional
            Number of cores to parallelize across. Only used when generating
            the data.
        out_dir : str, optional
            Directory to store the generated graphs. Graphs will go to
            `out_dir/cosmodc2_{healpix}/processed`.
        debug : bool, optional
            Debug mode. Default: False
        """
        self.in_dir = in_dir
        self.out_dir = in_dir if out_dir is None else out_dir
        self.healpix = healpix
        self.features = features
        self.n_cores = n_cores
        self.closeness = 0.5/60.0  # deg, edge criterion between neighbors
        self.mag_lower = -np.inf  # lower magnitude cut, excludes stars
        # LSST gold sample i-band mag (Gorecki et al 2014) = 25.3
        # LSST 10-year coadded 5-sigma depth = 26.8
        self.mag_upper = 25.3  # 26.8  # upper magnitude cut, excludes small halos
        # Store output in <root>/processed for processed_dir
        # Read input from in_dir/cosmodc2_{healpix}/raw
        root = osp.join(self.out_dir, f'cosmodc2_{self.healpix}')
        BaseGraph.__init__(self, root, raytracing_out_dir, aperture_size,
                           n_data, debug)

    @property
    def n_features(self):
        """Number of input features per node"""
        return len(self.features)

    @property
    def raw_dir(self) -> str:
        """Directory holding the raw galaxy catalog of this healpix"""
        return osp.join(self.in_dir, f'cosmodc2_{self.healpix}', 'raw')

    @property
    def raw_file_name(self):
        """Catalog file name (a small debug catalog in debug mode)"""
        if self.debug:
            return 'debug_gals.csv'
        return 'gals_{:d}.csv'.format(self.healpix)

    @property
    def raw_file_names(self):
        return [self.raw_file_name]

    @property
    def processed_file_fmt(self):
        """Format string of a processed subgraph file, indexed by sightline"""
        if self.debug:
            return 'debug_subgraph_{:d}.pt'
        return 'subgraph_{:d}.pt'

    @property
    def processed_file_path_fmt(self):
        return osp.join(self.processed_dir, self.processed_file_fmt)

    @property
    def processed_file_names(self):
        """A list of files relative to self.processed_dir which needs to be
        found in order to skip the processing

        """
        return [self.processed_file_fmt.format(n) for n in range(self.n_data)]

    def get_los_node(self):
        """Properties of the sightline galaxy, with all features masked out.

        Returns
        -------
        dict
            Maps each feature name to its own ``[0]`` list
        """
        # One fresh list per key; `[[0]]*n` would alias a single list
        return {feature: [0] for feature in self.features}

    def download(self):
        """Called when `raw_file_names` aren't found

        """
        pass

    def get_gals_iterator(self, healpix, columns, chunksize=100000):
        """Get an iterator over the galaxy catalog defining the
        line-of-sight galaxies

        Parameters
        ----------
        healpix : int
            Unused; the catalog is read from ``self.raw_paths[0]``
        columns : list
            Catalog columns to read (all parsed as float32)
        chunksize : int, optional
            Rows per chunk. Ignored in debug mode (50 rows per chunk, first
            1000 rows only).
        """
        if self.debug:
            cat = pd.read_csv(self.raw_paths[0],
                              chunksize=50, nrows=1000,
                              usecols=columns, dtype=np.float32)
        else:
            cat = pd.read_csv(self.raw_paths[0],
                              chunksize=chunksize, nrows=None,
                              usecols=columns, dtype=np.float32)
        return cat

    def get_edges(self, ra_dec):
        """Get the edge indices from the node positions

        Every node gets an edge pointing to node 0 (including a 0->0 self
        loop), and pairs closer than ``self.closeness`` are linked in both
        directions.

        Parameters
        ----------
        ra_dec : `np.ndarray`
            ra and dec of nodes, of shape `[n_nodes, 2]`

        Returns
        -------
        `torch.LongTensor`
            edge indices, of shape `[2, n_edges]`
        """
        n_nodes = ra_dec.shape[0]
        if n_nodes == 0:
            return torch.empty((2, 0), dtype=torch.long)
        kd_tree = cKDTree(ra_dec)
        # Pairs of galaxies that are close enough
        edges_close = kd_tree.query_pairs(r=self.closeness, p=2,
                                          eps=self.closeness/5.0,
                                          output_type='set')
        edges_close_reverse = [(b, a) for a, b in edges_close]  # bidirectional
        # All neighboring gals have edge to central LOS gal (integer indices)
        edges_to_center = set(zip(range(n_nodes), [0]*n_nodes))
        edge_index = edges_to_center.union(edges_close)
        edge_index = edge_index.union(edges_close_reverse)
        edge_index = torch.LongTensor(list(edge_index)).transpose(0, 1)
        return edge_index

    def _save_graph_to_disk(self, i):
        """Build the subgraph of sightline `i` and save it to disk.

        Nodes are catalog galaxies within `aperture_size` arcmin of the
        sightline (excluding the sightline galaxy itself) whose i-band
        magnitude lies strictly between `mag_lower` and `mag_upper`.
        """
        los_info = self.Y.iloc[i]
        columns = list(self.features)
        columns += [c for c in self._required_columns if c not in columns]
        # Collect filtered chunks and concatenate once
        # (DataFrame.append was removed in pandas 2.0 and is quadratic)
        kept_chunks = []
        for gals_df in self.get_gals_iterator(self.healpix, columns):
            # Query neighboring galaxies within the aperture of the sightline
            dist, ra_diff, dec_diff = cu.get_distance(gals_df['ra_true'].values,
                                                      gals_df['dec_true'].values,
                                                      los_info['ra'],
                                                      los_info['dec'])
            gals_df['ra_true'] = ra_diff  # deg
            gals_df['dec_true'] = dec_diff  # deg
            gals_df['r'] = dist
            dist_keep = np.logical_and(dist < self.aperture_size/60.0,
                                       dist > 1.e-7)  # exclude LOS gal
            mag_i = gals_df['mag_i_lsst'].values
            mag_keep = np.logical_and(mag_i > self.mag_lower,
                                      mag_i < self.mag_upper)
            keep = np.logical_and(dist_keep, mag_keep)
            kept_chunks.append(gals_df.loc[keep, :])
        if kept_chunks:
            nodes = pd.concat(kept_chunks, ignore_index=True)
        else:
            empty_cols = columns + ([] if 'r' in columns else ['r'])
            nodes = pd.DataFrame(columns=empty_cols)
        # to_numpy(dtype=...) also handles an empty (object-dtype) frame
        x = torch.from_numpy(nodes[self.features].to_numpy(dtype=np.float32))
        y_local = torch.from_numpy(nodes[['halo_mass', 'stellar_mass', 'redshift_true']].to_numpy(dtype=np.float32))
        y_global = torch.FloatTensor([[los_info['final_kappa'],
                                       los_info['final_gamma1'],
                                       los_info['final_gamma2']]])  # [1, 3]
        # Node count and sum of inverse distances to the sightline
        x_meta = torch.FloatTensor([[x.shape[0],
                                     np.sum(1.0/(nodes['r'].values + 1.e-5))]])  # [1, 2]
        y_class = self._get_y_class(y_global)
        data = Subgraph(x=x, y=y_global, y_local=y_local, x_meta=x_meta,
                        y_class=y_class)
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        torch.save(data, self.processed_file_path_fmt.format(i))

    def _get_y_class(self, y):
        """Bin kappa (column 0 of `y`) into classes 0-4.

        With `torch.bucketize` (right=False): class 0 for kappa <= 0,
        1 for (0, 0.03], 2 for (0.03, 0.05], 3 for (0.05, 1e6], 4 above.
        """
        y_class = torch.bucketize(y[:, 0],  # only kappa
                                  boundaries=torch.Tensor([0.0, 0.03, 0.05, 1.e6]))
        return y_class

    def process_single(self, i):
        """Process a single sightline indexed i, skipping it if its file
        already exists

        """
        if not osp.exists(self.processed_file_path_fmt.format(i)):
            self._save_graph_to_disk(i)

    def process(self):
        """Process multiple sightline in parallel

        """
        print("Parallelizing across {:d} cores...".format(self.n_cores))
        with multiprocessing.Pool(self.n_cores) as pool:
            return list(tqdm(pool.imap(self.process_single,
                                       range(self.n_data)),
                             total=self.n_data))

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        """Load the saved subgraph of sightline `idx`"""
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True,
        # which refuses to unpickle custom classes like Subgraph; pass
        # weights_only=False there (only for trusted, locally-generated files).
        data = torch.load(self.processed_file_path_fmt.format(idx))
        return data