Pytorch常用函数功能使用(一)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/sinat_35512245/article/details/88090301

1. view

import torch

number_1 = torch.randn(2, 3)

print(number_1)
print(number_1.shape)

print(number_1.view(1, -1))
print(number_1.view(3, -1))

输出:

tensor([[ 1.0506, -0.5875, -1.2477],
        [ 0.0635,  0.8997,  0.1551]])
        
torch.Size([2, 3])

tensor([[ 1.0506, -0.5875, -1.2477,  0.0635,  0.8997,  0.1551]])

tensor([[ 1.0506, -0.5875],
        [-1.2477,  0.0635],
        [ 0.8997,  0.1551]])

View(a,b)中第一个参数a代表目标张量的行数,b代表列数。为了简便起见,也可以只指定第一个参数a,b这个参数设置成-1,函数会自动计算对应的列数。

2. squeeze

number_2 = torch.randn(2, 1)

print(number_2)
print(torch.squeeze(number_2))
print(torch.squeeze(number_2, 0))
print(torch.squeeze(number_2, 1))

输出:

tensor([[ 0.5856],
        [-1.7095]])
tensor([ 0.5856, -1.7095])
tensor([[ 0.5856],
        [-1.7095]])
tensor([ 0.5856, -1.7095])

Squeeze的功能是进行维度缩减(维度为1的删除)。Squeeze(a,b)中第一个参数a代表传入的张量,b代表要缩减的维数。如果第二个参数没有指定,则默认删除所有维度为1的维度

number_3 = torch.randn(1, 2)

print(number_3)
print(torch.squeeze(number_3))
print(torch.squeeze(number_3, 0))
print(torch.squeeze(number_3, 1))

输出:

tensor([[ 0.1555, -0.4286]])
tensor([ 0.1555, -0.4286])
tensor([ 0.1555, -0.4286])
tensor([[ 0.1555, -0.4286]])

3. unsqueeze

number_4 = torch.randn(3, 2)

print(number_4)
print(torch.unsqueeze(number_4, 0))
print(torch.unsqueeze(number_4, 1))

输出:

tensor([[ 0.0358, -0.2769],
        [-0.3257,  0.1895],
        [ 1.9278, -0.9444]])
tensor([[[ 0.0358, -0.2769],
         [-0.3257,  0.1895],
         [ 1.9278, -0.9444]]])
tensor([[[ 0.0358, -0.2769]],

        [[-0.3257,  0.1895]],

        [[ 1.9278, -0.9444]]])

Unsqueeze的功能与squeeze相反,可以增加张量的维度。Unqueeze(a,b)中第一个参数a代表传入的张量,b代表要增加维度的维数。

4. max

number_5 = torch.randn(2, 3)
print(number_5)
print(torch.max(number_5, 0))
print(torch.max(number_5, 1))

输出:

tensor([[-0.4916,  1.3999,  1.0527],
        [ 1.0194, -2.4695, -0.2378]])
(tensor([1.0194, 1.3999, 1.0527]), tensor([1, 0, 0]))
(tensor([1.3999, 1.0194]), tensor([1, 0]))

Max的功能是返回对应维度最大的数以及对应的索引。Max(a,b)中第一个参数a代表传入的张量,b代表要对应的维数。0代表返回每一列的最大值,1代表返回每一行的最大值。

猜你喜欢

转载自blog.csdn.net/sinat_35512245/article/details/88090301
今日推荐