PyTorch commonly used functions and methods

  • enumerate()

enumerate() returns both the index and the value of each element. Its second parameter specifies the starting index and can be omitted. For example:

langs = ["Python", "Java", "C"]   # renamed from "list", which would shadow the built-in

for index, key in enumerate(langs):
    print(index, key)
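
A short added example of the second parameter: passing start=1 makes the index count begin at 1 instead of 0.

langs = ["Python", "Java", "C"]

for index, key in enumerate(langs, start=1):
    print(index, key)   # 1 Python, 2 Java, 3 C
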
  • torch.randn()

Returns a tensor of the given shape filled with random numbers drawn from the standard normal distribution (mean 0, variance 1).
example1:

import torch
input = torch.randn(1, 1, 3, 4)
print(input)

result:

tensor([[[[ 0.5028, -0.7468,  1.8858,  0.1745],
          [ 0.8540,  0.0401,  1.4751,  0.9010],
          [-0.3230, -0.4141, -0.4215,  0.1705]]]])

example2:

import torch
input = torch.randn(1, 2, 3, 4)
print(input)

result:

tensor([[[[ 1.5492,  0.9120,  0.9391, -0.2901],
          [-0.3356, -0.3431,  1.0347, -1.6674],
          [-0.1109, -0.1498, -1.2600,  0.1818]],

         [[ 1.8994,  0.0805, -2.7722, -1.1939],
          [ 1.4740, -0.2008, -0.1438, -1.1926],
          [-0.6315,  0.8516,  1.9624, -1.2148]]]])
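
As a quick check of the "standard normal" claim (a small example added here, not part of the original post): with many samples, the empirical mean should be close to 0 and the standard deviation close to 1.

import torch

t = torch.randn(100000)     # 100,000 samples from N(0, 1)
print(t.mean(), t.std())    # both values should be close to 0 and 1 respectively
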
  • nn.Conv2d()

The first three positional parameters of nn.Conv2d() are (in_channels, out_channels, kernel_size). The key point here is padding: with padding=N, the input is padded with N extra rows/columns on each side (filled with 0 by default) before the convolution is applied (see the sketch after the output below).
example1:

import torch
import torch.nn as nn
input = torch.randn(1, 1, 3, 4)
print(input)
m = nn.Conv2d(1, 1, 1, stride=1, bias=False)                # 1x1 convolution, no padding
print(m(input))
m1 = nn.Conv2d(1, 1, 1, stride=1, bias=False, padding=1)    # pads one zero row/column on each side
print(m1(input))
m2 = nn.Conv2d(1, 1, 1, stride=1, bias=False, padding=2)    # pads two zero rows/columns on each side
print(m2(input))

result:

tensor([[[[ 1.2922, -1.6056, -0.2292, -1.1778],   
          [-1.1310, -1.9764, -1.2235, -0.5288],   
          [ 1.5305, -0.1229, -1.3054,  1.3235]]]])
tensor([[[[ 0.3306, -0.4108, -0.0586, -0.3014],
          [-0.2894, -0.5057, -0.3131, -0.1353],
          [ 0.3916, -0.0314, -0.3340,  0.3386]]]],
       grad_fn=<ThnnConv2DBackward>)
tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],   
          [ 0.0000,  0.7538, -0.9366, -0.1337, -0.6871,  0.0000],
          [ 0.0000, -0.6598, -1.1529, -0.7137, -0.3085,  0.0000],
          [ 0.0000,  0.8928, -0.0717, -0.7615,  0.7720,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]]],
       grad_fn=<ThnnConv2DBackward>)
tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000],
          [ 0.0000,  0.0000, -0.3415,  0.4243,  0.0606,  0.3113,  0.0000,
            0.0000],
          [ 0.0000,  0.0000,  0.2989,  0.5223,  0.3233,  0.1397,  0.0000,
            0.0000],
          [ 0.0000,  0.0000, -0.4045,  0.0325,  0.3450, -0.3498,  0.0000,
            0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000]]]], grad_fn=<ThnnConv2DBackward>)
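
A quick sketch of how these parameters determine the output shape (added here as an illustration, not part of the original post). For the default dilation of 1, each spatial dimension follows H_out = (H_in + 2*padding - kernel_size) // stride + 1; a 3x3 kernel with padding=1 and stride=1 is a common choice because it preserves the spatial size.

import torch
import torch.nn as nn

x = torch.randn(1, 3, 32, 32)                       # batch=1, in_channels=3, 32x32 input

conv1 = nn.Conv2d(3, 16, 3, stride=1, padding=1)    # (in_channels, out_channels, kernel_size)
print(conv1(x).shape)                               # torch.Size([1, 16, 32, 32]): (32 + 2 - 3) // 1 + 1 = 32

conv2 = nn.Conv2d(3, 16, 3, stride=2, padding=1)
print(conv2(x).shape)                               # torch.Size([1, 16, 16, 16]): (32 + 2 - 3) // 2 + 1 = 16
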
  • permute()

Rearranges the dimensions of a Tensor into the specified order
example1:

import torch

x = torch.empty([6, 7, 8, 9])
print(x.size())
x = x.permute([0, 1, 3, 2])
print(x.size())

Output:

torch.Size([6, 7, 8, 9])
torch.Size([6, 7, 9, 8])

Before the permutation, the third dimension of the tensor has 8 elements and the fourth has 9. permute([0, 1, 3, 2]) swaps these two dimensions, so afterwards the third dimension has 9 elements and the fourth has 8.
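
One consequence worth noting (a small added check, reusing x from the example above): permute() only rearranges strides rather than copying data, so the result is generally no longer contiguous in memory, which matters for view() below.

import torch

x = torch.empty([6, 7, 8, 9])
y = x.permute([0, 1, 3, 2])
print(y.is_contiguous())    # False: permute rearranges strides without moving the underlying data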

  • view()

view() plays the same role as reshape(), but it requires the tensor to be contiguous in memory. A common pattern is therefore .contiguous().view(n, -1, ...), as sketched after the example below.
example1:

import torch

a = torch.arange(1, 17)  # a's shape is (16,)

print(a.view(4, 4))      # output below
print(a.view(2, 2, 4))   # output below

result:

tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16]])

tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8]],
 
        [[ 9, 10, 11, 12],
         [13, 14, 15, 16]]])

a.view(2, 2, 4) reshapes a into two 2x4 matrices. Written as a.view(n, -1, 4), the tensor is reshaped into n blocks of 4 columns each, with the -1 telling PyTorch to infer that dimension's size automatically.
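
To illustrate the contiguity requirement (a minimal added sketch): calling view() on a permuted tensor raises an error, and inserting .contiguous() first resolves it.

import torch

a = torch.arange(1, 17).view(4, 4)
b = a.permute(1, 0)                 # transposed view; its memory layout is no longer contiguous

# b.view(16) would raise a RuntimeError here because of the non-contiguous layout
print(b.is_contiguous())            # False
print(b.contiguous().view(16))      # .contiguous() first copies to a compact layout, then view() works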
