torch.norm的理解

官方文档

torch.norm是对输入的Tensor求范数

1.版本1--------------求张量范数

torch.norm(input, p=2) → float

参数:

  • input (Tensor) – 输入张量
  • p (float,optional) – 范数计算中的幂指数值

这是pytorch中的默认版本。输入为一个Tensor,输出是一个数。没啥多说的,直接上例子:

import torch
import torch.tensor as tensor

a = torch.ones((2,3))  #建立tensor
a2 = torch.norm(a)      #默认求2范数
a1 = torch.norm(a,p=1)  #指定求1范数

print(a)
print(a2)
print(a1)
 

2.版本2---------------------求指定维度上的范数

torch.norm(input, p, dim, out=None,keepdim=False) → Tensor

返回输入张量给定维dim 上每行的p 范数。 

参数:

  • input (Tensor) – 输入张量
  • p (float) – 范数计算中的幂指数值
  • dim (int) – 缩减的维度
  • out (Tensor, optional) – 结果张量
  • keepdim(bool)– 保持输出的维度  (此参数官方文档中未给出,但是很常用)

其中p,input,output与版本1相同,不做赘述。我们重点看dim和keepdim两个参数。

先看dim

import torch
import torch.tensor as tensor

a = tensor([[1, 2, 3, 4],
        [1, 2, 3, 4]]).float()  #norm仅支持floatTensor,a是一个2*4的Tensor
a0 = torch.norm(a,p=2,dim=0)    #按0维度求2范数
a1 = torch.norm(a,p=2,dim=1)    #按1维度求2范数
print(a0)
print(a1)

扫描二维码关注公众号,回复: 4619073 查看本文章

可以看输出,dim=0是对0维度上的一个向量求范数,返回结果数量等于其列的个数,也就是说有多少个0维度的向量,将得到多少个范数。dim=1同理。

再看keepdim

其含义是保持输出的维度,挺抽象的,我们还是通过具体例子来看吧。

a = torch.rand((2,3,4))
at = torch.norm(a,p=2,dim=1,keepdim=True)   #保持维度
af = torch.norm(a,p=2,dim=1,keepdim=False)  #不保持维度

print(a.shape)
print(at.shape)
print(af.shape)

我们重点来看一下,输入tensor,keepdim=True和keepdim=False的形状。

可以发现,当keepdim=False时,输出比输入少一个维度(就是指定的dim求范数的维度)。而keepdim=True时,输出与输入维度相同,仅仅是输出在求范数的维度上元素个数变为1。这也是为什么有时我们把参数中的dim称为缩减的维度,因为norm运算之后,此维度或者消失或者元素个数变为1.

猜你喜欢

转载自blog.csdn.net/goodxin_ie/article/details/84657975