torch.mean

求取指定维度的均值

import torch
x=torch.arange(15).view(3,5)
x = x.float()
print(x)
x = x.mean(dim=1,keepdim=True)
print(x)

输出:

tensor([[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[10., 11., 12., 13., 14.]])
tensor([[ 2.], [ 7.], [12.]])

猜你喜欢

转载自blog.csdn.net/weixin_40210307/article/details/90044058