Source code for n2j.trainval_data.graphs.base_graph

"""Generic catalog-agnostic module for input graph X

"""

import os
import random
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Dataset, Data


class Subgraph(Data):
    """Graph container for a single sightline.

    Subclass of `torch_geometric.data.Data` that additionally carries the
    `y_local`, `x_meta` and `y_class` attributes.

    """

    def __init__(self, x=None, y_local=None, y=None, x_meta=None,
                 y_class=None, edge_index=None):
        """
        Parameters
        ----------
        x : `torch.FloatTensor` of shape `[n_nodes, n_features]`
            the galaxy that defines the sightline (first node = v0)
            together with its neighbors
        y_local : optional
            stored unchanged as `self.y_local`
        y : `torch.FloatTensor` of shape `[3]`
            the label to infer
        x_meta : optional
            stored unchanged as `self.x_meta`
        y_class : optional
            stored unchanged as `self.y_class`
        edge_index : `torch.LongTensor` of shape `[2, n_edges]`
            directed edges going from each neighbor to v0

        """
        # Only x, y and edge_index are handed to the parent constructor;
        # the remaining inputs are attached afterwards as extra attributes.
        Data.__init__(self, x=x, y=y, edge_index=edge_index)
        extra_attrs = (
            ('y_local', y_local),
            ('x_meta', x_meta),
            ('y_class', y_class),
        )
        for attr_name, attr_value in extra_attrs:
            setattr(self, attr_name, attr_value)

class BaseGraph(Dataset):
    """Abstract base for graphs built from photometric catalogs.

    Not meant to be used on its own. Child classes follow the naming
    convention `<name of catalog>Graph`.

    """

    def __init__(self, root, raytracing_out_dir, aperture_size, n_data,
                 debug=False, transform=None, pre_transform=None,
                 pre_filter=None):
        """
        Parameters
        ----------
        root : str
            path to the train or val directory containing the `raw` and
            `processed` folders
        raytracing_out_dir : str
            path to the raytracer output directory containing `Y.csv`
        aperture_size : float
            radius of the field of view around each sightline, in arcmin
        n_data : int
            forwarded as `nrows` to `pd.read_csv` when `Y.csv` is loaded
        debug : bool
            debug mode. Default: False
        transform, pre_transform, pre_filter : optional
            forwarded unchanged to `Dataset.__init__`

        """
        self.debug = debug
        self.aperture_size = aperture_size
        self.n_data = n_data
        self.raytracing_out_dir = raytracing_out_dir
        # The sightlines are loaded *before* the parent constructor runs.
        # NOTE(review): presumably Dataset.__init__ can trigger processing
        # that relies on self.Y -- confirm before reordering.
        self._get_sightlines()
        Dataset.__init__(self, root, transform, pre_transform, pre_filter)

    def _get_sightlines(self):
        """Read the precomputed sightlines (pointings and labels) from
        `Y.csv` in `self.raytracing_out_dir` into `self.Y`.

        At most `self.n_data` rows are read.

        """
        label_cols = ['final_kappa', 'final_gamma1', 'final_gamma2']
        pointing_cols = ['ra', 'dec', 'z']
        wanted_cols = ['galaxy_id'] + label_cols + pointing_cols
        csv_path = os.path.join(self.raytracing_out_dir, 'Y.csv')
        self.Y = pd.read_csv(csv_path,
                             usecols=wanted_cols,
                             index_col=None,
                             nrows=self.n_data)
        # NOTE: a deg -> arcmin rescaling of the 'ra' and 'dec' columns is
        # intentionally left disabled here, so both columns keep the values
        # found in Y.csv:
        # self.Y.loc[:, ['ra', 'dec']] *= 60.0