torch.nn.functional.normalize详解

torch.nn.functional.normalize

torch.nn.functional.normalize(input, p=2, dim=1, eps=1e-12, out=None)

功能:将某一个维度除以那个维度对应的范数(默认是2范数)。
v = v max ( v p , ϵ ) v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}

主要讲以下三种情况:

  • 输入为一维Tensor

    a = torch.Tensor([1,2,3])
    
    torch.nn.functional.normalize(a, dim=0)
    
    tensor([0.2673, 0.5345, 0.8018])
    

    可以看到每一个数字都除以了这个Tensor的范数: 1 2 + 2 2 + 3 2 = 3.7416 \sqrt{1^2+2^2+3^2}=3.7416

  • 输入为二维Tensor

    b = torch.Tensor([[1,2,3], [4,5,6]])
    
    torch.nn.functional.normalize(b, dim=0)
    
    tensor([[0.2425, 0.3714, 0.4472],
            [0.9701, 0.9285, 0.8944]])
    

    因为dim=0,所以是对列操作。以第一列为例,整体除以了第一列的范数: 1 2 + 4 2 = 4.1231 \sqrt{1^2+4^2}=4.1231

    b = torch.Tensor([[1,2,3], [4,5,6]])
    
    torch.nn.functional.normalize(b, dim=1)
    
    tensor([[0.2673, 0.5345, 0.8018],
            [0.4558, 0.5698, 0.6838]])
    

    因为dim=1,所以是对行操作。以第一行为例,整体除以了第一行的范数: 1 2 + 2 2 + 3 2 = 3.7416 \sqrt{1^2+2^2+3^2}=3.7416

  • 输入为三维Tensor

    b = torch.Tensor([[[1,2,3], [4,5,6]], [[1,2,3], [4,5,6]]])
    
    torch.nn.functional.normalize(b, dim=2)
    
    tensor([[[0.2673, 0.5345, 0.8018],
             [0.4558, 0.5698, 0.6838]],
    
            [[0.2673, 0.5345, 0.8018],
             [0.4558, 0.5698, 0.6838]]])
    

    注意此时dim=2,所以是对第三个维度,也就是每一行操作。以第一行为例,除以了第一行的范数: 1 2 + 2 2 + 3 2 = 3.7416 \sqrt{1^2+2^2+3^2}=3.7416

    b = torch.Tensor([[[1,2,3], [4,5,6]], [[1,2,3], [4,5,6]]])
    
    torch.nn.functional.normalize(b, dim=1)
    
    tensor([[[0.2425, 0.3714, 0.4472],
             [0.9701, 0.9285, 0.8944]],
    
            [[0.2425, 0.3714, 0.4472],
             [0.9701, 0.9285, 0.8944]]])
    

    注意此时dim=1,所以是对第二个维度操作。第二个维度是二维数组,所以此时相当于对二维数组的第0维操作。
    [[1,2,3], [4,5,6]]为例,此时要对它的列操作。第一列要除以这一列的范数: 1 2 + 4 2 = 4.1231 \sqrt{1^2+4^2}=4.1231

    b = torch.Tensor([[[1,2,3], [4,5,6]], [[1,2,3], [4,5,6]]])
    
    torch.nn.functional.normalize(b, dim=0)
    
    tensor([[[0.7071, 0.7071, 0.7071],
             [0.7071, 0.7071, 0.7071]],
    
            [[0.7071, 0.7071, 0.7071],
             [0.7071, 0.7071, 0.7071]]])
    

    dim=0的时候现在还看不懂,以后再补吧。

参考:pytorch-document

发布了173 篇原创文章 · 获赞 28 · 访问量 6万+

猜你喜欢

转载自blog.csdn.net/ECNU_LZJ/article/details/103653133