Source code for n2j.models.n2jnet

from typing import Optional, Tuple
from torch import Tensor
import torch
from torch.nn import (Module, ModuleList, ReLU, LeakyReLU, LayerNorm,
                      Sequential as Seq, Linear as Lin)
from torch_scatter import scatter_add
from torch_geometric.nn import MetaLayer
from torch import nn
import torch.nn.functional as F
from n2j.models.flow import Flow, MAF, Perm
from n2j.losses.gaussian_nll import DiagonalGaussianNLL


__all__ = ['N2JNet']
DEBUG = False


class MCDropout(nn.Dropout):
    """1D dropout that stays on during training and testing

    """
    def forward(self, input: Tensor) -> Tensor:
        return F.dropout(input, self.p, True, self.inplace)


class MCDropout2d(nn.Dropout2d):
    """2D dropout that stays on during training and testing

    """
    def forward(self, input: Tensor) -> Tensor:
        return F.dropout2d(input, self.p, True, self.inplace)


class CustomMetaLayer(MetaLayer):
    def __init__(self, node_model=None, global_model=None):
        super(CustomMetaLayer, self).__init__(edge_model=None,
                                              node_model=node_model,
                                              global_model=global_model)
        pass

    def forward(self,
                x: Tensor,
                u: Optional[Tensor] = None,
                batch: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        if self.node_model is not None:
            x = self.node_model(x, u, batch)

        if self.global_model is not None:
            u = self.global_model(x, u, batch)
        return x, u


[docs]class N2JNet(Module): def __init__(self, dim_in, dim_out_local, dim_out_global, dim_local, dim_global, dim_hidden=20, dim_pre_aggr=20, n_iter=20, n_out_layers=5, global_flow=False, dropout=0.0, class_weight=None, use_ss=True, dim_in_meta=2, weight_sum=True): """Edgeless graph neural network modeling relationships among nodes and between nodes and global Parameters ---------- dim_in : int number of input features per node dim_out_local : int number of targets per node dim_out_global : int number of targets per graph dim_local : int dim_global : int dim_hidden : int dim_pre_aggr : int n_iter : int n_out_layers : int global_flow : bool dropout : float fraction of weights to zero during training and testing, for MC dropout. Default: 0.0 class_weight : torch.tensor use_ss : bool, True Whether to use summary stats to init global encoding. weight_sum : bool, True Whether to learn the weights for aggregating node features. """ super(N2JNet, self).__init__() self.dim_in = dim_in self.dim_out_local = dim_out_local self.dim_out_global = dim_out_global self.dim_hidden = dim_hidden self.dim_local = dim_local self.dim_global = dim_global self.dim_pre_aggr = dim_pre_aggr self.n_iter = n_iter self.n_out_layers = n_out_layers self.global_flow = global_flow self.class_weight = class_weight self.dropout = dropout self.dim_in_meta = dim_in_meta self.weight_sum = weight_sum # MLP for initially encoding local self.mlp_node_init = Seq(Lin(self.dim_in, self.dim_hidden), ReLU(), MCDropout(self.dropout), Lin(self.dim_hidden, self.dim_hidden), ReLU(), MCDropout(self.dropout), Lin(self.dim_hidden, self.dim_local), LayerNorm(self.dim_local)) # MLP for initially encoding global self.use_ss = use_ss if self.use_ss: self.mlp_global_init = Seq(Lin(self.dim_in_meta, self.dim_hidden), ReLU(), MCDropout(self.dropout), Lin(self.dim_hidden, self.dim_hidden), ReLU(), MCDropout(self.dropout), Lin(self.dim_hidden, self.dim_global), LayerNorm(self.dim_global)) # MLPs for encoding local and global meta_layers = ModuleList() for i in range(self.n_iter): node_model = NodeModel(self.dim_local, self.dim_global, self.dim_hidden, self.dropout) global_model = GlobalModel(self.dim_local, self.dim_global, self.dim_hidden, self.dim_pre_aggr, self.dropout, weight_sum=self.weight_sum) meta = CustomMetaLayer(node_model=node_model, global_model=global_model) meta_layers.append(meta) self.meta_layers = meta_layers # Networks for local and global output self.net_out_local = Seq(Lin(self.dim_local, self.dim_hidden), ReLU(), MCDropout(self.dropout), Lin(self.dim_hidden, self.dim_hidden), ReLU(), MCDropout(self.dropout), Lin(self.dim_hidden, self.dim_out_local*2)) if self.global_flow: self.net_out_global = Flow(*[[ MAF(self.dim_global, self.dim_out_global, hidden=dim_hidden), Perm(self.dim_global)][i%2] for i in \ range(self.n_out_layers*2 + 1)]) else: self.net_out_global = Seq(Lin(self.dim_global, self.dim_hidden), ReLU(), MCDropout(self.dropout), Lin(self.dim_hidden, self.dim_hidden), ReLU(), MCDropout(self.dropout), Lin(self.dim_hidden, self.dim_out_global*2)) # Losses self.local_nll = DiagonalGaussianNLL(dim_out_local) self.global_nll = DiagonalGaussianNLL(dim_out_global)
[docs] def forward(self, data): x = data.x # [n_nodes, n_features] x_meta = data.x_meta # [batch_size, 2] batch = data.batch # [batch_size,] batch_size = data.y.shape[0] # Init node and global encodings x, u x = self.mlp_node_init(x) # [n_nodes, dim_local] if self.use_ss: u = self.mlp_global_init(x_meta) else: u = torch.zeros(batch_size, self.dim_global).to(x.dtype).to(x.device) for i, meta in enumerate(self.meta_layers): x, u = meta(x=x, u=u, batch=batch) # x : [n_nodes, dim_local] # u : [batch_size, dim_global] if DEBUG: print("x is nan:", torch.any(torch.isnan(x)), x.mean()) print("u is nan:", torch.any(torch.isnan(u)), u.mean()) x = self.net_out_local(x) # [n_nodes, dim_local_out*2] if not self.global_flow: u = self.net_out_global(u) # [batch_size, dim_global_out*2] return x, u
[docs] def local_loss(self, x, data): y_local = data.y_local # [n_nodes, 2] nlogp_local = self.local_nll(x, y_local) # [n_nodes,] return nlogp_local
[docs] def global_loss(self, u, data): y = data.y if self.global_flow: u_out, log_det = self.net_out_global(u, y) if DEBUG: print("u_out", torch.any(torch.isnan(u_out)), u_out.mean()) print("log_det", torch.any(torch.isnan(log_det)), log_det.mean()) log_prob = -u_out.pow(2).sum(1)/2 # Standard normal base dist normalized_log_prob = log_prob + log_det nlogp_global = - normalized_log_prob # [batch_size,] else: nlogp_global = self.global_nll(u, y) # [batch_size,] return nlogp_global
[docs] def loss(self, x, u, data): local_loss = self.local_loss(x, data) # [n_nodes,] local_loss = scatter_add(local_loss, data.batch, dim=0) # [batch_size,] global_loss = self.global_loss(u, data) # [batch_size,] # Weight by inverse class number counts if self.class_weight is not None: y_weight = 1.0/self.class_weight[data.y_class].squeeze() local_loss *= y_weight global_loss *= y_weight return local_loss.mean(), global_loss.mean()
class NodeModel(Module): def __init__(self, dim_local, dim_global, dim_hidden, dropout): """MLP governing the node representation """ super(NodeModel, self).__init__() self.dim_local = dim_local self.dim_global = dim_global self.dim_hidden = dim_hidden self.dim_concat = self.dim_local + self.dim_global self.dropout = dropout self.mlp = Seq(Lin(self.dim_concat, self.dim_hidden), LayerNorm(self.dim_hidden), ReLU(), MCDropout(self.dropout), Lin(self.dim_hidden, self.dim_hidden), LayerNorm(self.dim_hidden), ReLU(), MCDropout(self.dropout), Lin(self.dim_hidden, self.dim_local), LayerNorm(self.dim_local)) def forward(self, x, u, batch): # x ~ [n_nodes, dim_local] # u ~ [batch, dim_global] but u[batch] ~ [n_nodes, dim_global] out = torch.cat([x, u[batch]], dim=-1) # [n_nodes, dim_local + dim_global] out = self.mlp(out) + x # [n_nodes, dim_local] return out class GlobalModel(Module): def __init__(self, dim_local, dim_global, dim_hidden, dim_pre_aggr, dropout, weight_sum): """MLP governing the global representation """ super(GlobalModel, self).__init__() self.dim_local = dim_local self.dim_global = dim_global self.dim_hidden = dim_hidden self.dim_concat = self.dim_local + self.dim_global self.dim_pre_aggr = dim_pre_aggr self.dropout = dropout self.weight_sum = weight_sum # MLP prior to aggregating node encodings self.mlp_pre_aggr = Seq(Lin(self.dim_concat, self.dim_hidden), ReLU(), MCDropout(self.dropout), Lin(self.dim_hidden, self.dim_hidden), ReLU(), MCDropout(self.dropout), Lin(self.dim_hidden, self.dim_pre_aggr), LayerNorm(self.dim_pre_aggr)) # MLP after aggregating node encodings self.mlp_post_aggr = Seq(Lin(self.dim_pre_aggr+self.dim_global, self.dim_hidden), ReLU(), MCDropout(self.dropout), Lin(self.dim_hidden, self.dim_hidden), ReLU(), MCDropout(self.dropout), Lin(self.dim_hidden, self.dim_global), LayerNorm(self.dim_global)) # Modeled after attention, e.g. # https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/gat_conv.html if self.weight_sum: self.node_alpha = Seq(Lin(self.dim_pre_aggr, 1), # fix n_heads = 1 LeakyReLU(negative_slope=0.01)) self.dropout_weights = MCDropout(self.dropout) def forward(self, x, u, batch): out = torch.cat([x, u[batch]], dim=-1) # [n_nodes, dim_local + dim_global] out = self.mlp_pre_aggr(out) # [n_nodes, self.dim_pre_aggr] if self.weight_sum: alpha = self.node_alpha(out) # [n_nodes, 1] alpha = F.softmax(alpha, dim=0) # [n_nodes, 1] the weights alpha = self.dropout_weights(alpha) out = scatter_add(out*alpha, batch, dim=0) # [batch_size, dim_pre_aggr] else: out = scatter_add(out, batch, dim=0) # [batch_size, dim_pre_aggr] out = torch.cat([out, u], dim=-1) # [batch_size, dim_pre_aggr + dim_global] out = self.mlp_post_aggr(out) # [batch_size, dim_global] out += u # [batch_size, dim_global] return out if __name__ == '__main__': net = N2JNet(dim_in=4, dim_out_local=2, dim_out_global=1, dim_local=11, dim_global=7, dim_hidden=19, dim_pre_aggr=21, n_iter=5, n_out_layers=7) class Batch: def __init__(self, x, y_local, y, y_class, batch): self.x = x self.y_local = y_local self.y = y self.y_class = y_class self.batch = batch batch = Batch(x=torch.randn(5, 4), y_local=torch.randn(5, 2), y=torch.randn(3, 1), y_class=torch.tensor([2, 3, 3]).long(), batch=torch.LongTensor([0, 0, 1, 1, 2])) x, u = net(batch) print(x.shape) print(u.shape) print("local loss: ", net.local_loss(x, batch).shape) print("global loss: ", net.global_loss(u, batch).shape) local_loss, global_loss = net.loss(x, u, batch) print("loss: ", local_loss.shape, global_loss.shape) print(local_loss.cpu().item()) print((local_loss/2.0 + 0.0).item()) n_params = sum(p.numel() for p in net.parameters() if p.requires_grad) print(f"Number of params: {n_params}")