Summary of common code snippets in PyTorch

This article is a collection of commonly used PyTorch code snippets, covering five areas: basic configuration, tensor processing, model definition and operation, data processing, and model training and testing. It also includes a number of noteworthy tips, making the content quite comprehensive.

The best resource for PyTorch is the official documentation. This article gathers common PyTorch code snippets, with some fixes made on the basis of reference [1] (Zhang Hao: PyTorch Cookbook), for convenient reference during everyday use.

Basic configuration

Import packages and version query

import torch
import torch.nn as nn
import torchvision
import numpy as np
print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())
print(torch.cuda.get_device_name(0))

Reproducibility

When the hardware devices (CPU, GPU) differ, complete reproducibility cannot be guaranteed even if the random seed is the same. On the same device, however, reproducibility should hold. The specific method is to fix the random seeds of numpy and torch at the beginning of the program.

np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
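
If you are on PyTorch 1.8 or newer (a version assumption beyond the snippet above), torch.use_deterministic_algorithms can additionally make PyTorch raise an error whenever an operation has no deterministic implementation:

# PyTorch 1.8+: error out on nondeterministic operations
torch.use_deterministic_algorithms(True)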

GPU settings

  • If only one GPU is needed (a usage sketch follows this list):
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  • If you need to specify multiple GPUs, such as GPU 0 and GPU 1:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
  • You can also set the GPUs when running the code from the command line:
CUDA_VISIBLE_DEVICES=0,1 python train.py
  • Clear GPU memory:
torch.cuda.empty_cache()
  • You can also reset the GPU from the command line:
nvidia-smi --gpu-reset -i [gpu_id]
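
A minimal sketch of how the device object is typically used; the small linear layer here is only a stand-in model:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(10, 2).to(device)          # move the model parameters to the device
inputs = torch.randn(4, 10, device=device)   # create the data directly on the device
outputs = model(inputs)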

Tensor processing

Tensor data types

PyTorch has 9 CPU tensor types and 9 GPU tensor types.
Data type                | dtype                  | CPU tensor         | GPU tensor
32-bit floating point    | torch.float32 (float)  | torch.FloatTensor  | torch.cuda.FloatTensor
64-bit floating point    | torch.float64 (double) | torch.DoubleTensor | torch.cuda.DoubleTensor
16-bit floating point    | torch.float16 (half)   | torch.HalfTensor   | torch.cuda.HalfTensor
8-bit integer (unsigned) | torch.uint8            | torch.ByteTensor   | torch.cuda.ByteTensor
8-bit integer (signed)   | torch.int8             | torch.CharTensor   | torch.cuda.CharTensor
16-bit integer           | torch.int16 (short)    | torch.ShortTensor  | torch.cuda.ShortTensor
32-bit integer           | torch.int32 (int)      | torch.IntTensor    | torch.cuda.IntTensor
64-bit integer           | torch.int64 (long)     | torch.LongTensor   | torch.cuda.LongTensor
Boolean                  | torch.bool             | torch.BoolTensor   | torch.cuda.BoolTensor
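
A quick sketch of checking a tensor's type, or setting it explicitly at creation time:

t = torch.zeros(2, 3, dtype=torch.float16)
print(t.dtype)   # torch.float16
print(t.type())  # 'torch.HalfTensor'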

Basic tensor information

tensor = torch.randn(3,4,5)
print(tensor.type())  # data type
print(tensor.size())  # tensor shape; a tuple-like object
print(tensor.dim())   # number of dimensions

Named tensors

Naming tensor dimensions is very useful: you can conveniently index and manipulate dimensions by name, which greatly improves readability and ease of use and prevents errors.

# Before PyTorch 1.3, the dimension order had to be tracked in comments
# Tensor[N, C, H, W]
images = torch.randn(32, 3, 56, 56)
images.sum(dim=1)
images.select(dim=1, index=0)

# From PyTorch 1.3 onwards
NCHW = ['N', 'C', 'H', 'W']
images = torch.randn(32, 3, 56, 56, names=NCHW)
images.sum('C')
images.select('C', index=0)
# Names can also be set like this
tensor = torch.rand(3,4,1,2,names=('C', 'N', 'H', 'W'))
# align_to conveniently reorders the dimensions
tensor = tensor.align_to('N', 'C', 'H', 'W')

Data type conversion

# Set the default type; in PyTorch, FloatTensor is much faster than DoubleTensor
torch.set_default_tensor_type(torch.FloatTensor)
# Type conversions
tensor = tensor.cuda()
tensor = tensor.cpu()
tensor = tensor.float()
tensor = tensor.long()
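
The more general Tensor.to covers both kinds of conversion in a single call:

tensor = tensor.to('cuda')              # same as tensor.cuda()
tensor = tensor.to(torch.float)         # same as tensor.float()
tensor = tensor.to('cuda', torch.long)  # device and dtype together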

Conversion between torch.Tensor and np.ndarray

All CPU tensors except CharTensor support conversion to numpy format and back.

ndarray = tensor.cpu().numpy()
tensor = torch.from_numpy(ndarray).float()
tensor = torch.from_numpy(ndarray.copy()).float() # If ndarray has negative stride.
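
Note that torch.from_numpy (and likewise .numpy() on a CPU tensor) shares memory with the source array rather than copying it, so in-place changes are visible on both sides. A small demonstration:

a = np.ones(3)
t = torch.from_numpy(a)  # shares memory with a
a[0] = 0
print(t)                 # tensor([0., 1., 1.], dtype=torch.float64)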

Conversion between torch.Tensor and PIL.Image

# PyTorch tensors use the [N, C, H, W] order by default, with values in [0, 1],
# so transposition and normalization are needed
import PIL.Image
# torch.Tensor -> PIL.Image
image = PIL.Image.fromarray(torch.clamp(tensor*255, min=0, max=255).byte().permute(1,2,0).cpu().numpy())
image = torchvision.transforms.functional.to_pil_image(tensor)  # equivalent way

# PIL.Image -> torch.Tensor
path = r'./figure.jpg'
tensor = torch.from_numpy(np.asarray(PIL.Image.open(path))).permute(2,0,1).float() / 255
tensor = torchvision.transforms.functional.to_tensor(PIL.Image.open(path)) # equivalent way

Conversion between np.ndarray and PIL.Image

image = PIL.Image.fromarray(ndarray.astype(np.uint8))
ndarray = np.asarray(PIL.Image.open(path))

Extract a value from a tensor containing only one element

value = torch.rand(1).item()

Tensor reshaping

# Reshaping is often needed when feeding a convolutional layer's output into a fully connected layer;
# compared with torch.view, torch.reshape automatically handles non-contiguous input tensors

tensor = torch.rand(2,3,4)
shape = (6, 4)
tensor = torch.reshape(tensor, shape)
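
A small demonstration of the difference mentioned in the comment: view fails on a non-contiguous tensor, while reshape copies when necessary.

t = torch.rand(3, 4).t()  # transposing makes the tensor non-contiguous
# t.view(12)              # would raise a RuntimeError
t = t.reshape(12)         # works; reshape copies if it has to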

Shuffle the order

tensor = tensor[torch.randperm(tensor.size(0))]  # shuffle the first dimension

Horizontal flip

# PyTorch does not support negative-stride operations like tensor[::-1];
# a horizontal flip can be implemented with tensor indexing
# Assume the tensor has dimensions [N, D, H, W].

tensor = tensor[:,:,:,torch.arange(tensor.size(3) - 1, -1, -1).long()]
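
Newer PyTorch versions also provide torch.flip, which expresses the same operation more readably:

tensor = torch.flip(tensor, dims=[3])  # flip along the W dimension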

Copy a tensor

# Operation               | New/Shared memory | Still in computation graph
# tensor.clone()          |        New        |            Yes
# tensor.detach()         |      Shared       |            No
# tensor.detach().clone() |        New        |            No
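
A minimal sketch of these semantics:

x = torch.ones(2)
y = x.clone()   # new memory
z = x.detach()  # shared memory
z[0] = 5
print(x)        # tensor([5., 1.]) -- detach shares storage with x
print(y)        # tensor([1., 1.]) -- clone made an independent copy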

Tensor concatenation

'''
Note the difference between torch.cat and torch.stack: torch.cat concatenates
along a given dimension, while torch.stack adds a new dimension.
For example, given three 10x5 tensors as input, torch.cat produces a 30x5 tensor,
while torch.stack produces a 3x10x5 tensor.
'''
tensor = torch.cat(list_of_tensors, dim=0)
tensor = torch.stack(list_of_tensors, dim=0)
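
The shape difference described in the comment, made concrete:

list_of_tensors = [torch.rand(10, 5) for _ in range(3)]
print(torch.cat(list_of_tensors, dim=0).shape)    # torch.Size([30, 5])
print(torch.stack(list_of_tensors, dim=0).shape)  # torch.Size([3, 10, 5])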

Convert integer labels to one-hot encoding

# PyTorch labels start from 0 by default
tensor = torch.tensor([0, 2, 1, 3])
N = tensor.size(0)
num_classes = 4
one_hot = torch.zeros(N, num_classes).long()
one_hot.scatter_(dim=1, index=torch.unsqueeze(tensor, dim=1), src=torch.ones(N, num_classes).long())
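
On PyTorch 1.1 and newer, torch.nn.functional.one_hot is usually the simpler route:

import torch.nn.functional as F
one_hot = F.one_hot(tensor, num_classes=4)  # shape [4, 4], dtype torch.int64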

Get the non-zero elements

torch.nonzero(tensor)               # index of non-zero elements
torch.nonzero(tensor==0)            # index of zero elements
torch.nonzero(tensor).size(0)       # number of non-zero elements
torch.nonzero(tensor == 0).size(0)  # number of zero elements

Determine whether two tensors are equal

torch.allclose(tensor1, tensor2)  # float tensor
torch.equal(tensor1, tensor2)     # int tensor

Tensor expansion

# Expand tensor of shape 64*512 to shape 64*512*7*7.
tensor = torch.rand(64,512)
torch.reshape(tensor, (64, 512, 1, 1)).expand(64, 512, 7, 7)
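
Note that expand does not allocate new memory; it returns a view with stride 0 along the expanded dimensions. Call .contiguous() (or .clone()) if an independent copy is needed:

expanded = torch.reshape(tensor, (64, 512, 1, 1)).expand(64, 512, 7, 7)
copy = expanded.contiguous()  # materialize an actual copy when one is needed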

Matrix multiplication

# Matrix multiplication: (m*n) * (n*p) -> (m*p).
result = torch.mm(tensor1, tensor2)

# Batch matrix multiplication: (b*m*n) * (b*n*p) -> (b*m*p)
result = torch.bmm(tensor1, tensor2)

# Element-wise multiplication.
result = tensor1 * tensor2
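
For reference, torch.matmul generalizes both of the above and broadcasts batch dimensions:

result = torch.matmul(torch.rand(2, 3, 4), torch.rand(4, 5))  # shape (2, 3, 5)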

Calculate the pairwise Euclidean distance between two sets of data

Using the broadcasting mechanism:

dist = torch.sqrt(torch.sum((X1[:,None,:] - X2) ** 2, dim=2))
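
Newer PyTorch versions also provide torch.cdist, which computes the same pairwise distances directly:

X1 = torch.rand(8, 3)
X2 = torch.rand(5, 3)
dist = torch.cdist(X1, X2, p=2)  # shape (8, 5), Euclidean (p=2) distances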
