n2j.models

Submodules

Package Contents

Classes

GCNNet

GATNet

SageNet

GravNet

N2JNet

class n2j.models.GCNNet(in_channels, out_channels, hidden_channels=256, n_layers=3, dropout=0.0, kwargs={})[source]

Bases: torch.nn.Module

forward(data)
class n2j.models.GATNet(in_channels, out_channels, hidden_channels=256, kwargs={}, n_layers=3, dropout=0.0)[source]

Bases: torch.nn.Module

forward(data)
class n2j.models.SageNet(in_channels, out_channels, hidden_channels=256, n_layers=3, dropout=0.0, kwargs={})[source]

Bases: torch.nn.Module

forward(data)
class n2j.models.GravNet(in_channels, out_channels, hidden_channels=256, n_layers=3, dropout=0.0, kwargs={})[source]

Bases: torch.nn.Module

forward(data)
class n2j.models.N2JNet(dim_in, dim_out_local, dim_out_global, dim_local, dim_global, dim_hidden=20, dim_pre_aggr=20, n_iter=20, n_out_layers=5, global_flow=False, dropout=0.0, class_weight=None, use_ss=True, dim_in_meta=2, weight_sum=True)[source]

Bases: torch.nn.Module

forward(data)
local_loss(x, data)
global_loss(u, data)
loss(x, u, data)