pytorch中 max()、view()、 squeeze()、 unsqueeze()

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qian1996/article/details/81265974

查了好多博客都似懂非懂,后来写了几个小例子,瞬间一目了然。

目录

一、torch.max()

二、torch.view()

三、

1.torch.unsqueeze()

2.squeeze()


一、torch.max()

import torch 

a=torch.randn(3)
print("a:\n",a)
print('max(a):',torch.max(a))

b=torch.randn(3,4)
print("b:\n",b)
print('max(b,0):',torch.max(b,0))
print('max(b,1):',torch.max(b,1))

   输出:

a:
 tensor([ 0.9558,  1.1242,  1.9503])
max(a): tensor(1.9503)
b:
 tensor([[ 0.2765,  0.0726, -0.7753,  1.5334],
        [ 0.0201, -0.0005,  0.2616, -1.1912],
        [-0.6225,  0.6477,  0.8259,  0.3526]])
max(b,0): (tensor([ 0.2765,  0.6477,  0.8259,  1.5334]), tensor([ 0,  2,  2,  0]))
max(b,1): (tensor([ 1.5334,  0.2616,  0.8259]), tensor([ 3,  2,  2]))

  max(a),用于一维数据,求出最大值。

  max(a,0),计算出数据中一列的最大值,并输出最大值所在的行号。

  max(a,0),计算出数据中一行的最大值,并输出最大值所在的列号。

print('max(b,1):',torch.max(b,1)[1])

  输出:只输出行最大值所在的列号

max(b,1): tensor([ 3,  2,  2])

  torch.max(b,1)[0], 只返回最大值的每个数

二、view()

a.view(i,j)表示将原矩阵转化为i行j列的形式   

  i为-1表示不限制行数,输出1列

a=torch.randn(3,4)
print(a)

输出:
tensor([[-0.8146, -0.6592,  1.5100,  0.7615],
        [ 1.3021,  1.8362, -0.3590,  0.3028],
        [ 0.0848,  0.7700,  1.0572,  0.6383]])

b=a.view(-1,1)
print(b)

输出:
tensor([[-0.8146],
        [-0.6592],
        [ 1.5100],
        [ 0.7615],
        [ 1.3021],
        [ 1.8362],
        [-0.3590],
        [ 0.3028],
        [ 0.0848],
        [ 0.7700],
        [ 1.0572],
        [ 0.6383]])

i为1,j为-1表示不限制列数,输出1行

b=a.view(1,-1)
print(b)

输出:
tensor([[-0.8146, -0.6592,  1.5100,  0.7615,  1.3021,  1.8362, -0.3590,
          0.3028,  0.0848,  0.7700,  1.0572,  0.6383]])

 i为-1,j为2表示不限制行数,输出2列

b=a.view(-1,2)
print(b)

输出:
tensor([[-0.8146, -0.6592],
        [ 1.5100,  0.7615],
        [ 1.3021,  1.8362],
        [-0.3590,  0.3028],
        [ 0.0848,  0.7700],
        [ 1.0572,  0.6383]])

 i为-1,j为3表示不限制行数,输出3列 

 i为4,j为3表示输出4行3列 

b=a.view(-1,3)
print(b)
b=a.view(4,3)
print(b)

输出:
tensor([[-0.8146, -0.6592,  1.5100],
        [ 0.7615,  1.3021,  1.8362],
        [-0.3590,  0.3028,  0.0848],
        [ 0.7700,  1.0572,  0.6383]])
tensor([[-0.8146, -0.6592,  1.5100],
        [ 0.7615,  1.3021,  1.8362],
        [-0.3590,  0.3028,  0.0848],
        [ 0.7700,  1.0572,  0.6383]])

三、

1.torch.squeeze()

压缩矩阵,我理解为降维

a.squeeze(i)   压缩第i维,如果这一维维数是1,则这一维可有可无,便可以压缩

import torch 

a=torch.randn(1,3,4)
print(a)
b=a.squeeze(0)
print(b)
c=a.squeeze(1)
print(c

输出:

tensor([[[ 0.4627,  1.6447,  0.1320,  2.0946],
         [-0.0080,  0.1794,  1.1898, -1.2525],
         [ 0.8281, -0.8166,  1.8846,  0.9008]]])
一页三行4列的矩阵

第0维为1,则可以通过squeeze(0)删掉,转化为三行4列的矩阵
tensor([[ 0.4627,  1.6447,  0.1320,  2.0946],
        [-0.0080,  0.1794,  1.1898, -1.2525],
        [ 0.8281, -0.8166,  1.8846,  0.9008]])
第1维不为1,则不可以压缩
tensor([[[ 0.4627,  1.6447,  0.1320,  2.0946],
         [-0.0080,  0.1794,  1.1898, -1.2525],
         [ 0.8281, -0.8166,  1.8846,  0.9008]]])

2.torch.unsqueeze()

 unsqueeze(i) 表示将第i维设置为1

对压缩为3行4列后的矩阵b进行操作,将第0维设置为1

c=b.unsqueeze(0)
print(c)

输出一个一页三行四列的矩阵
tensor([[[ 0.0661, -0.2386, -0.6610,  1.5774],
         [ 1.2210, -0.1084, -0.1166, -0.2379],
         [-1.0012, -0.4363,  1.0057, -1.5180]]])

将第一维设置为1 

c=b.unsqueeze(1)
print(c)

输出一个3页,一行,4列的矩阵
tensor([[[-1.0067, -1.1477, -0.3213, -1.0633]],

        [[-2.3976,  0.9857, -0.3462, -0.3648]],

        [[ 1.1012, -0.4659, -0.0858,  1.6631]]])

另外,squeeze、unsqueeze操作不改变原矩阵

ok!!!

猜你喜欢

转载自blog.csdn.net/qian1996/article/details/81265974