Source code for n2j.models.gnn

"""Various GNN models

"""
import torch
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GravNetConv
import torch.nn as nn
import torch.nn.functional as F


# Public API of this module: only the model classes are exported by
# `from <module> import *`; the helper `get_zero_nodes` is not.
__all__ = ['GCNNet', 'GATNet', 'SageNet', 'GravNet']


def get_zero_nodes(batch_idx):
    """Get a boolean mask flagging the zeroth node of each graph in the batch

    A position is flagged when it is the very first entry of `batch_idx` or
    when its value differs from the value immediately before it.

    Parameters
    ----------
    batch_idx : torch.Tensor
        1D tensor of per-node graph assignments (the models in this module
        pass `data.batch`). Assumed to list the nodes of each graph
        contiguously -- TODO confirm against the data loader.

    Returns
    -------
    torch.Tensor
        Boolean tensor with the same shape and device as `batch_idx`.
        An empty input yields an empty mask.

    """
    # Start from all-True so that position 0 is always flagged, then
    # overwrite the remaining positions with "value changed from previous".
    # Comparing in the input dtype avoids concatenating a float zero with an
    # integer tensor, and slicing keeps the empty input well-defined.
    mask = torch.ones_like(batch_idx, dtype=torch.bool)
    mask[1:] = batch_idx[1:] != batch_idx[:-1]
    return mask


class GCNNet(nn.Module):
    """Stack of `GCNConv` layers read out at the zeroth node of each graph

    Parameters
    ----------
    in_channels : int
        Number of input features per node
    out_channels : int
        Number of features produced by the last convolution
    hidden_channels : int
        Number of features produced by every non-final convolution
    n_layers : int
        Total number of convolutions, including the last one. Values below
        1 are treated like 1 (a single `in_channels -> out_channels` layer).
    dropout : float
        Dropout probability applied after each non-final convolution
    kwargs : dict, optional
        Extra keyword arguments forwarded to every `GCNConv`

    """
    def __init__(self, in_channels, out_channels,
                 hidden_channels=256, n_layers=3, dropout=0.0, kwargs=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.n_layers = n_layers
        self.dropout = dropout
        # A fresh dict per instance instead of a shared mutable default
        self.kwargs = {} if kwargs is None else kwargs
        # Feature sizes at the layer boundaries. With no hidden layers the
        # last convolution must consume the raw input features, not
        # `hidden_channels`.
        n_hidden = max(self.n_layers - 1, 0)
        dims = [self.in_channels]
        dims += [self.hidden_channels]*n_hidden
        dims += [self.out_channels]
        self.convs = nn.ModuleList()
        for n_in, n_out in zip(dims[:-1], dims[1:]):
            self.convs.append(GCNConv(n_in, n_out,
                                      aggr='add',
                                      add_self_loops=False,
                                      **self.kwargs))

    def forward(self, data):
        """Run the convolutions and keep the rows of the zeroth nodes

        Parameters
        ----------
        data : object
            Batch exposing the attributes `x`, `edge_index`, and `batch`
            (presumably a torch_geometric batch -- confirm with the caller)

        Returns
        -------
        torch.Tensor
            Output of the last convolution restricted to the nodes flagged
            by `get_zero_nodes(data.batch)`

        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.leaky_relu(x)
            # NOTE(review): `training=True` is hard-coded, so dropout stays
            # active even in eval mode. Presumably deliberate (e.g. MC
            # dropout) -- confirm before changing.
            x = F.dropout(x, p=self.dropout, training=True)
        x = self.convs[-1](x, edge_index)
        zero_idx_mask = get_zero_nodes(batch)
        return x[zero_idx_mask, :]
class GATNet(nn.Module):
    """Stack of `GATConv` layers read out at the zeroth node of each graph

    Parameters
    ----------
    in_channels : int
        Number of input features per node
    out_channels : int
        `out_channels` argument given to the last `GATConv`
    hidden_channels : int
        `out_channels` argument given to every non-final `GATConv`
    kwargs : dict, optional
        Extra keyword arguments forwarded to every `GATConv`
    n_layers : int
        Total number of convolutions, including the last one. Values below
        1 are treated like 1 (a single `in_channels -> out_channels` layer).
    dropout : float
        Probability used both for the `dropout` argument of each `GATConv`
        and for the feature dropout after each non-final convolution

    """
    def __init__(self, in_channels, out_channels, hidden_channels=256,
                 kwargs=None, n_layers=3, dropout=0.0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.n_layers = n_layers
        self.dropout = dropout
        # A fresh dict per instance instead of a shared mutable default
        self.kwargs = {} if kwargs is None else kwargs
        # Feature sizes at the layer boundaries. With no hidden layers the
        # last convolution must consume the raw input features, not
        # `hidden_channels`.
        n_hidden = max(self.n_layers - 1, 0)
        dims = [self.in_channels]
        dims += [self.hidden_channels]*n_hidden
        dims += [self.out_channels]
        self.convs = nn.ModuleList()
        for n_in, n_out in zip(dims[:-1], dims[1:]):
            self.convs.append(GATConv(n_in, n_out,
                                      aggr='add',
                                      dropout=self.dropout,
                                      add_self_loops=False,
                                      **self.kwargs))

    def forward(self, data):
        """Run the convolutions and keep the rows of the zeroth nodes

        Parameters
        ----------
        data : object
            Batch exposing the attributes `x`, `edge_index`, and `batch`
            (presumably a torch_geometric batch -- confirm with the caller)

        Returns
        -------
        torch.Tensor or tuple
            When `self.training` is True, the output of the last convolution
            restricted to the nodes flagged by `get_zero_nodes(data.batch)`.
            Otherwise a tuple of that tensor and the `(edge_index, w)` pair
            returned by the last convolution when called with
            `return_attention_weights=True`.

        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        zero_idx_mask = get_zero_nodes(batch)
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.leaky_relu(x)
            # NOTE(review): `training=True` is hard-coded, so dropout stays
            # active even in eval mode. Presumably deliberate (e.g. MC
            # dropout) -- confirm before changing.
            x = F.dropout(x, p=self.dropout, training=True)
        if self.training:
            x = self.convs[-1](x, edge_index)
            return x[zero_idx_mask, :]
        # Outside of training, also hand back the attention weights
        x, (edge_index, w) = self.convs[-1](x, edge_index,
                                            return_attention_weights=True)
        return x[zero_idx_mask, :], (edge_index, w)
class SageNet(nn.Module):
    """Stack of `SAGEConv` layers read out at the zeroth node of each graph

    Parameters
    ----------
    in_channels : int
        Number of input features per node
    out_channels : int
        Number of features produced by the last convolution
    hidden_channels : int
        Number of features produced by every non-final convolution
    n_layers : int
        Total number of convolutions, including the last one. Values below
        1 are treated like 1 (a single `in_channels -> out_channels` layer).
    dropout : float
        Dropout probability applied after each non-final convolution
    kwargs : dict, optional
        Extra keyword arguments forwarded to every `SAGEConv`

    """
    def __init__(self, in_channels, out_channels,
                 hidden_channels=256, n_layers=3, dropout=0.0, kwargs=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.n_layers = n_layers
        self.dropout = dropout
        # A fresh dict per instance instead of a shared mutable default
        self.kwargs = {} if kwargs is None else kwargs
        # Feature sizes at the layer boundaries. With no hidden layers the
        # last convolution must consume the raw input features, not
        # `hidden_channels`.
        n_hidden = max(self.n_layers - 1, 0)
        dims = [self.in_channels]
        dims += [self.hidden_channels]*n_hidden
        dims += [self.out_channels]
        self.convs = nn.ModuleList()
        for n_in, n_out in zip(dims[:-1], dims[1:]):
            self.convs.append(SAGEConv(n_in, n_out,
                                       aggr='add',
                                       normalize=True,  # otherwise explode
                                       root_weight=False,
                                       **self.kwargs))

    def forward(self, data):
        """Run the convolutions and keep the rows of the zeroth nodes

        Parameters
        ----------
        data : object
            Batch exposing the attributes `x`, `edge_index`, and `batch`
            (presumably a torch_geometric batch -- confirm with the caller)

        Returns
        -------
        torch.Tensor
            Output of the last convolution restricted to the nodes flagged
            by `get_zero_nodes(data.batch)`

        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.leaky_relu(x)
            # NOTE(review): `training=True` is hard-coded, so dropout stays
            # active even in eval mode. Presumably deliberate (e.g. MC
            # dropout) -- confirm before changing.
            x = F.dropout(x, p=self.dropout, training=True)
        x = self.convs[-1](x, edge_index)
        zero_idx_mask = get_zero_nodes(batch)
        return x[zero_idx_mask, :]
class GravNet(nn.Module):
    """Stack of `GravNetConv` layers read out at the zeroth node of each graph

    Every convolution is built with `space_dimensions=3`,
    `propagate_dimensions=2`, and `k=20`.

    Parameters
    ----------
    in_channels : int
        Number of input features per node
    out_channels : int
        Number of features produced by the last convolution
    hidden_channels : int
        Number of features produced by every non-final convolution
    n_layers : int
        Total number of convolutions, including the last one. Values below
        1 are treated like 1 (a single `in_channels -> out_channels` layer).
    dropout : float
        Dropout probability applied after each non-final convolution
    kwargs : dict, optional
        Extra keyword arguments forwarded to every `GravNetConv`

    """
    def __init__(self, in_channels, out_channels,
                 hidden_channels=256, n_layers=3, dropout=0.0, kwargs=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.n_layers = n_layers
        self.dropout = dropout
        # A fresh dict per instance instead of a shared mutable default
        self.kwargs = {} if kwargs is None else kwargs
        # Feature sizes at the layer boundaries. With no hidden layers the
        # last convolution must consume the raw input features, not
        # `hidden_channels`.
        n_hidden = max(self.n_layers - 1, 0)
        dims = [self.in_channels]
        dims += [self.hidden_channels]*n_hidden
        dims += [self.out_channels]
        self.convs = nn.ModuleList()
        for n_in, n_out in zip(dims[:-1], dims[1:]):
            self.convs.append(GravNetConv(n_in, n_out,
                                          aggr='add',
                                          space_dimensions=3,
                                          propagate_dimensions=2,
                                          k=20,
                                          **self.kwargs))

    def forward(self, data):
        """Run the convolutions and keep the rows of the zeroth nodes

        Parameters
        ----------
        data : object
            Batch exposing the attributes `x` and `batch` (presumably a
            torch_geometric batch -- confirm with the caller). Edge
            information is not used.

        Returns
        -------
        torch.Tensor
            Output of the last convolution restricted to the nodes flagged
            by `get_zero_nodes(data.batch)`

        """
        x, batch = data.x, data.batch  # edge information not used
        for conv in self.convs[:-1]:
            # Pass the batch vector so that neighbors are only searched
            # within the same graph rather than across the whole mini-batch
            x = conv(x, batch)
            x = F.leaky_relu(x)
            # NOTE(review): `training=True` is hard-coded, so dropout stays
            # active even in eval mode. Presumably deliberate (e.g. MC
            # dropout) -- confirm before changing.
            x = F.dropout(x, p=self.dropout, training=True)
        x = self.convs[-1](x, batch)
        zero_idx_mask = get_zero_nodes(batch)
        return x[zero_idx_mask, :]