n2j.models.n2jnet

Module Contents

Classes

N2JNet

class n2j.models.n2jnet.N2JNet(dim_in, dim_out_local, dim_out_global, dim_local, dim_global, dim_hidden=20, dim_pre_aggr=20, n_iter=20, n_out_layers=5, global_flow=False, dropout=0.0, class_weight=None, use_ss=True, dim_in_meta=2, weight_sum=True) [source]

Bases: torch.nn.Module

forward(data) [source]
local_loss(x, data) [source]
global_loss(u, data) [source]
loss(x, u, data) [source]