Source code for n2j.losses.local_global_loss
import torch.nn as nn
__all__ = ['MSELoss']
[docs]class MSELoss:
def __init__(self):
self.local_mse = nn.MSELoss(reduction='mean')
self.global_mse = nn.MSELoss(reduction='mean')
[docs] def __call__(self, pred, target):
pred_local, pred_global = pred
target_local, target_global = target
mse = self.local_mse(pred_local, target_local)
mse += self.global_mse(pred_global, target_global)
return mse