MSELoss() function

MSELoss (mean squared error loss) in PyTorch. Element-wise, the loss is the squared difference between prediction and target:

def MSELoss(pred, target):
    # Element-wise squared error, i.e. what nn.MSELoss(reduction='none') returns
    return (pred - target) ** 2
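Written out for a prediction x and a target y with N elements, the three `reduction` modes of `nn.MSELoss` compute the following. This summary follows the formulas in the PyTorch `nn.MSELoss` documentation; the hand-written function above corresponds to the `none` case only.

```latex
\ell_i = (x_i - y_i)^2, \qquad
\text{none: } (\ell_1, \dots, \ell_N), \qquad
\text{sum: } \sum_{i=1}^{N} \ell_i, \qquad
\text{mean: } \frac{1}{N} \sum_{i=1}^{N} \ell_i
```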

Code example

import torch
import torch.nn as nn
 
a = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
b = torch.tensor([[3, 5], [8, 6]], dtype=torch.float32)
 
loss_fn1 = nn.MSELoss(reduction='none')   # keep the per-element squared errors
loss1 = loss_fn1(a, b)
print(loss1)   # Output: tensor([[ 4.,  9.],
               #                 [25.,  4.]])

loss_fn2 = nn.MSELoss(reduction='sum')    # add up all squared errors
loss2 = loss_fn2(a, b)
print(loss2)   # Output: tensor(42.)

loss_fn3 = nn.MSELoss(reduction='mean')   # average over all elements (the default)
loss3 = loss_fn3(a, b)
print(loss3)   # Output: tensor(10.5000)
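As a quick sanity check, the three results can be reproduced by hand from the element-wise squared error and compared with `nn.MSELoss`. This is a minimal sketch that uses only standard tensor operations (`.sum()`, `.mean()`, `torch.allclose`); the `manual` dictionary is a hypothetical helper introduced for illustration.

```python
import torch
import torch.nn as nn

# Same inputs as the example above
a = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
b = torch.tensor([[3, 5], [8, 6]], dtype=torch.float32)

# Hand-written element-wise squared error (no reduction)
sq_err = (a - b) ** 2

# Hypothetical helper: manual equivalent of each reduction mode
manual = {
    "none": sq_err,          # per-element losses, same shape as the inputs
    "sum": sq_err.sum(),     # total of all squared errors
    "mean": sq_err.mean(),   # total divided by the number of elements (4 here)
}

for reduction, expected in manual.items():
    builtin = nn.MSELoss(reduction=reduction)(a, b)
    # The built-in loss should match the manual computation for every mode
    assert torch.allclose(builtin, expected), reduction
    print(reduction, builtin)
```

Because `'mean'` divides by the total element count rather than by the batch size, the result (10.5) is the sum (42) divided by 4.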


Reposted from blog.csdn.net/DENGSHUCHAO152/article/details/125042592