# l2
def l2_loss(gt, pred, *, divisor=10):
    """Return the mean squared error between ``gt`` and ``pred``, divided by ``divisor``.

    Despite the name, this is not the L2 norm. The element-wise squared
    differences are summed, divided by the number of elements (a mean),
    and then divided by ``divisor``. The default of 10 reproduces the
    originally hard-coded factor; the reason for that factor is not
    visible here.

    Args:
        gt: Ground-truth tensor. Any shape is accepted; 4-D ``(B, C, H, W)``
            batches appear to be the original use case (assumption, confirm
            against callers).
        pred: Prediction tensor, expected to have the same shape as ``gt``
            (broadcastable shapes are averaged over the broadcast result).
        divisor: Extra factor the mean is divided by. Keyword-only.

    Returns:
        A 0-dim tensor holding the scaled mean squared error. An empty input
        yields NaN (0 / 0).
    """
    # NOTE(review): ``t`` is presumably ``torch`` imported at module level -- confirm.
    diff = gt - pred
    # Sum first, then divide by a Python int: for same-shape 4-D inputs this is
    # bit-identical to the former ``t.sum(...) / (B * C * H * W * 10)``.
    return t.sum(diff * diff) / (diff.numel() * divisor)
# l2
def l2_loss(gt, pred, *, divisor=10):
    """Return the mean squared error between ``gt`` and ``pred``, divided by ``divisor``.

    Despite the name, this is not the L2 norm. The element-wise squared
    differences are summed, divided by the number of elements (a mean),
    and then divided by ``divisor``. The default of 10 reproduces the
    originally hard-coded factor; the reason for that factor is not
    visible here.

    Args:
        gt: Ground-truth tensor. Any shape is accepted; 4-D ``(B, C, H, W)``
            batches appear to be the original use case (assumption, confirm
            against callers).
        pred: Prediction tensor, expected to have the same shape as ``gt``
            (broadcastable shapes are averaged over the broadcast result).
        divisor: Extra factor the mean is divided by. Keyword-only.

    Returns:
        A 0-dim tensor holding the scaled mean squared error. An empty input
        yields NaN (0 / 0).
    """
    # NOTE(review): ``t`` is presumably ``torch`` imported at module level -- confirm.
    diff = gt - pred
    # Sum first, then divide by a Python int: for same-shape 4-D inputs this is
    # bit-identical to the former ``t.sum(...) / (B * C * H * W * 10)``.
    return t.sum(diff * diff) / (diff.numel() * divisor)