This article collects commonly used PyTorch code snippets, covering five areas: basic configuration, tensor processing, model definition and operation, data processing, and model training and testing. It also gives a number of noteworthy tips and aims to be comprehensive.
The best source for PyTorch is the official documentation. This article gathers common PyTorch code snippets, with some fixes made on the basis of reference [1] (Zhang Hao: PyTorch Cookbook), for convenient reference.
Basic configuration
Importing packages and checking versions
import torch
import torch.nn as nn
import torchvision
print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())
print(torch.cuda.get_device_name(0))
Reproducibility
When the hardware devices (CPU, GPU) differ, complete reproducibility cannot be guaranteed even with the same random seed. On the same device, however, results should be reproducible.
The approach is to fix the random seeds of NumPy and torch at the beginning of the program and to make cuDNN deterministic.
import numpy as np
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
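As a minimal sketch, the settings above can be wrapped in a single helper. The name `seed_everything` is hypothetical (it is not a PyTorch API), and seeding Python's built-in `random` module is an addition that the list above does not mention.

```python
import random

import numpy as np
import torch


def seed_everything(seed: int = 0) -> None:
    """Hypothetical helper: seed the RNGs listed above, plus Python's own."""
    random.seed(seed)                 # Python's built-in RNG (an addition to the list above)
    np.random.seed(seed)              # NumPy RNG
    torch.manual_seed(seed)           # torch CPU RNG
    torch.cuda.manual_seed_all(seed)  # all GPU RNGs; silently ignored without CUDA
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(0)
first = torch.rand(3)
seed_everything(0)
second = torch.rand(3)
print(torch.equal(first, second))  # the same seed on the same device gives the same draw
```

Calling the helper once at the top of the training script is usually enough; calling it again, as done here, simply restarts the random streams.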
Graphics card settings
- If only one graphics card is needed:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- If you need to specify multiple graphics cards, such as cards No. 0 and No. 1:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
- You can also set the graphics card when running the code from the command line:
CUDA_VISIBLE_DEVICES=0,1 python train.py
- Clear the video memory cached by PyTorch:
torch.cuda.empty_cache()
- You can also reset the GPU from the command line:
nvidia-smi --gpu-reset -i [gpu_id]
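Once a device has been chosen as above, tensors and modules are typically moved to it with `.to(device)`. The following is only a sketch; the layer and batch sizes are arbitrary placeholders.

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(torch.cuda.device_count())  # GPUs visible to this process (0 on a CPU-only machine)

model = nn.Linear(4, 2).to(device)     # move the parameters to the chosen device
inputs = torch.randn(8, 4).to(device)  # inputs must live on the same device as the model
outputs = model(inputs)
print(outputs.shape, outputs.device)
```

Because the device is chosen at run time, the same script runs unchanged on machines with and without a GPU.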
Tensor processing
Tensor data types
PyTorch has 9 CPU tensor types and 9 GPU tensor types.
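As a small illustration of what these tensor types look like in practice, the sketch below prints the CPU type name for a few dtypes. The selection of dtypes is arbitrary and covers only part of the full list.

```python
import torch

# Each dtype on the CPU corresponds to one tensor type name.
for dtype in (torch.float32, torch.float64, torch.int32, torch.int64, torch.uint8):
    t = torch.zeros(2, dtype=dtype)
    print(dtype, '->', t.type())

float_name = torch.zeros(2, dtype=torch.float32).type()
long_name = torch.zeros(2, dtype=torch.int64).type()
```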
Tensor basic information
tensor = torch.randn(3,4,5)
print(tensor.type()) # data type
print(tensor.size()) # shape of the tensor, a tuple
print(tensor.dim()) # number of dimensions
Named tensors
Naming tensor dimensions is very useful: you can index or operate on a dimension by name, which improves readability and ease of use and helps prevent errors.
# Before PyTorch 1.3, dimension names had to be tracked in comments
# Tensor[N, C, H, W]
images = torch.randn(32, 3, 56, 56)
images.sum(dim=1)
images.select(dim=1, index=0)
# From PyTorch 1.3 onward
NCHW = ['N', 'C', 'H', 'W']
images = torch.randn(32, 3, 56, 56, names=NCHW)
images.sum('C')
images.select('C', index=0)
# Names can also be set like this
tensor = torch.rand(3,4,1,2,names=('C', 'N', 'H', 'W'))
# align_to conveniently reorders dimensions by name
tensor = tensor.align_to('N', 'C', 'H', 'W')
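To make the effect of names concrete, here is a small self-contained sketch with deliberately tiny shapes. It assumes a PyTorch version in which named tensors are available; the feature is experimental and emits a warning, which is silenced here.

```python
import warnings

import torch

warnings.filterwarnings('ignore')  # named tensors are experimental and emit a UserWarning

images = torch.randn(2, 3, 4, 4, names=('N', 'C', 'H', 'W'))
summed = images.sum('C')                             # reduce over the channel dimension by name
channels_last = images.align_to('N', 'H', 'W', 'C')  # reorder dimensions by name

print(summed.names)         # the 'C' dimension is gone
print(channels_last.shape)  # channels are now the last dimension
```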
Data type conversion
# Set the default type; in PyTorch, FloatTensor is much faster than DoubleTensor
torch.set_default_tensor_type(torch.FloatTensor)
# Device and type conversion
tensor = tensor.cuda()
tensor = tensor.cpu()
tensor = tensor.float()
tensor = tensor.long()
torch.Tensor and np.ndarray conversion
All CPU tensors except CharTensor support conversion to NumPy format and back.
ndarray = tensor.cpu().numpy()
tensor = torch.from_numpy(ndarray).float()
tensor = torch.from_numpy(ndarray.copy()).float() # If ndarray has negative stride.
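Two details of the conversions above are easy to miss: `torch.from_numpy` shares memory with the source array, and an array view with negative stride has to be copied first. The sketch below illustrates both on a small array; the variable names are illustrative only.

```python
import numpy as np
import torch

ndarray = np.arange(6, dtype=np.float32)
shared = torch.from_numpy(ndarray)  # shares memory with ndarray
ndarray[0] = 100.0                  # the change is visible through the tensor
print(shared[0].item())

reversed_view = ndarray[::-1]       # a view with negative stride
try:
    torch.from_numpy(reversed_view)
    negative_stride_accepted = True
except (ValueError, RuntimeError):
    negative_stride_accepted = False  # expected when negative strides are rejected

copied = torch.from_numpy(reversed_view.copy())  # copying yields positive strides
```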
torch.Tensor and PIL.Image conversion
# Tensors in PyTorch use the [N, C, H, W] order by default, with values in [0, 1], so transposition and rescaling are needed
# torch.Tensor -> PIL.Image
image = PIL.Image.fromarray(torch.clamp(tensor*255, min=0, max=255).byte().permute(1,2,0).cpu().numpy())
image = torchvision.transforms.functional.to_pil_image(tensor) # Equivalent way
# PIL.Image -> torch.Tensor
path = r'./figure.jpg'
tensor = torch.from_numpy(np.asarray(PIL.Image.open(path))).permute(2,0,1).float() / 255
tensor = torchvision.transforms.functional.to_tensor(PIL.Image.open(path)) # Equivalent way
Conversion of np.ndarray and PIL.Image
image = PIL.Image.fromarray(ndarray.astype(np.uint8))
ndarray = np.asarray(PIL.Image.open(path))
Extract a value from a tensor containing only one element
value = torch.rand(1).item()
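Tensor reshaping
# When feeding the output of a convolutional layer into a fully connected layer, the tensor usually needs to be reshaped.
# Compared with torch.view, torch.reshape automatically handles non-contiguous input tensors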
tensor = torch.rand(2,3,4)
shape = (6, 4)
tensor = torch.reshape(tensor, shape)
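The difference between `view` and `reshape` only shows up on non-contiguous tensors. A minimal sketch, using a transpose to produce such a tensor:

```python
import torch

tensor = torch.rand(2, 3).t()  # transposing makes the tensor non-contiguous
print(tensor.is_contiguous())

try:
    tensor.view(6)             # view needs compatible strides and fails here
    view_failed = False
except RuntimeError:
    view_failed = True

flat = tensor.reshape(6)                     # reshape copies the data when it has to
flat_via_view = tensor.contiguous().view(6)  # the explicit equivalent
```

When the input is already contiguous, `reshape` returns a view and no copy is made.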
Shuffling
tensor = tensor[torch.randperm(tensor.size(0))] # shuffle along the first dimension
Horizontal flip
# PyTorch does not support negative-step slicing such as tensor[::-1]; a horizontal flip can be done with tensor indexing
# Assume the tensor has dimensions [N, D, H, W].
tensor = tensor[:,:,:,torch.arange(tensor.size(3) - 1, -1, -1).long()]
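PyTorch also provides `torch.flip`, which should give the same result as the index-based approach. The sketch below compares the two on a small tensor whose shape is chosen arbitrarily.

```python
import torch

tensor = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)  # [N, D, H, W]

flipped_by_index = tensor[:, :, :, torch.arange(tensor.size(3) - 1, -1, -1).long()]
flipped_by_flip = torch.flip(tensor, dims=[3])  # reverse the W dimension

print(torch.equal(flipped_by_index, flipped_by_flip))
```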
Copying tensors
# Operation                 |  New/Shared memory | Still in computation graph |
tensor.clone()            # |        New         |          Yes               |
tensor.detach()           # |      Shared        |          No                |
tensor.detach().clone()   # |        New         |          No                |
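The table can be checked directly: `data_ptr()` shows whether memory is shared, and `requires_grad` shows whether the result is still attached to the computation graph. A minimal sketch:

```python
import torch

x = torch.ones(3, requires_grad=True)

cloned = x.clone()                   # new memory, still in the graph
detached = x.detach()                # shared memory, out of the graph
detached_clone = x.detach().clone()  # new memory, out of the graph

for name, t in [('clone', cloned), ('detach', detached), ('detach+clone', detached_clone)]:
    print(name, 'shares memory:', t.data_ptr() == x.data_ptr(), 'in graph:', t.requires_grad)
```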
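Tensor concatenation
'''
Note the difference between torch.cat and torch.stack: torch.cat concatenates along the given dimension,
whereas torch.stack adds a new dimension. For example, given three 10x5 tensors, torch.cat produces a 30x5 tensor,
while torch.stack produces a 3x10x5 tensor.
'''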
tensor = torch.cat(list_of_tensors, dim=0)
tensor = torch.stack(list_of_tensors, dim=0)
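The shapes described above can be reproduced with three random 10x5 tensors:

```python
import torch

list_of_tensors = [torch.rand(10, 5) for _ in range(3)]

concatenated = torch.cat(list_of_tensors, dim=0)  # joins along an existing dimension
stacked = torch.stack(list_of_tensors, dim=0)     # creates a new leading dimension

print(concatenated.shape, stacked.shape)
```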
Convert integer labels to one-hot encoding
# Labels in PyTorch start from 0 by default
tensor = torch.tensor([0, 2, 1, 3])
N = tensor.size(0)
num_classes = 4
one_hot = torch.zeros(N, num_classes).long()
one_hot.scatter_(dim=1, index=torch.unsqueeze(tensor, dim=1), src=torch.ones(N, num_classes).long())
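As an alternative to `scatter_`, `torch.nn.functional.one_hot` appears to produce the same encoding in one call. The sketch below compares the two; passing the scalar 1 to `scatter_` instead of a tensor of ones is a simplification of the snippet above.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 2, 1, 3])
num_classes = 4

# scatter_ approach: write the scalar 1 at each label index
one_hot_scatter = torch.zeros(labels.size(0), num_classes).long()
one_hot_scatter.scatter_(1, labels.unsqueeze(1), 1)

# built-in helper
one_hot_builtin = F.one_hot(labels, num_classes=num_classes)

print(one_hot_builtin)
```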
Getting non-zero elements
torch.nonzero(tensor) # index of non-zero elements
torch.nonzero(tensor==0) # index of zero elements
torch.nonzero(tensor).size(0) # number of non-zero elements
torch.nonzero(tensor == 0).size(0) # number of zero elements
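A short worked example on a 2x3 tensor shows what these calls return. The last line is an alternative way to count that avoids building the index tensor.

```python
import torch

tensor = torch.tensor([[0, 1, 0], [2, 0, 3]])

indices = torch.nonzero(tensor)                # one row of coordinates per non-zero element
num_nonzero = indices.size(0)
num_zero = torch.nonzero(tensor == 0).size(0)
num_nonzero_direct = int((tensor != 0).sum())  # count via a boolean mask

print(indices)
```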
Checking whether two tensors are equal
torch.allclose(tensor1, tensor2) # float tensor
torch.equal(tensor1, tensor2) # int tensor
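The reason for using `torch.allclose` on floating-point tensors is rounding error. A minimal sketch, using double precision so that the classic `0.1 + 0.2` case is visible:

```python
import torch

a = torch.tensor([0.1 + 0.2], dtype=torch.float64)
b = torch.tensor([0.3], dtype=torch.float64)

exactly_equal = torch.equal(a, b)           # exact element-wise comparison
approximately_equal = torch.allclose(a, b)  # comparison within a tolerance

print(exactly_equal, approximately_equal)
```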
Tensor expansion
# Expand tensor of shape 64*512 to shape 64*512*7*7.
tensor = torch.rand(64,512)
torch.reshape(tensor, (64, 512, 1, 1)).expand(64, 512, 7, 7)
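It may help to know that `expand` does not copy data: the expanded dimensions get stride 0 and the result shares memory with the original tensor. A sketch that inspects this:

```python
import torch

tensor = torch.rand(64, 512)
expanded = torch.reshape(tensor, (64, 512, 1, 1)).expand(64, 512, 7, 7)

print(expanded.shape)
print(expanded.stride())                         # stride 0 along the expanded dimensions
print(expanded.data_ptr() == tensor.data_ptr())  # no new memory was allocated

materialized = expanded.contiguous()             # a real copy with its own memory
```

Because the expanded tensor aliases the same memory, in-place writes to it are best avoided; `contiguous()` or `clone()` gives an independent copy when one is needed.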
Matrix multiplication
# Matrix multiplication: (m*n) * (n*p) -> (m*p).
result = torch.mm(tensor1, tensor2)
# Batch matrix multiplication: (b*m*n) * (b*n*p) -> (b*m*p)
result = torch.bmm(tensor1, tensor2)
# Element-wise multiplication.
result = tensor1 * tensor2
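The shape rules in the comments above can be verified with concrete sizes. The sizes are arbitrary, and the `@` operator (`torch.matmul`) is shown as an additional option that covers both the plain and the batched case.

```python
import torch

a = torch.rand(3, 4)
b = torch.rand(4, 5)
batch_a = torch.rand(8, 3, 4)
batch_b = torch.rand(8, 4, 5)

product = torch.mm(a, b)                     # (3*4) * (4*5) -> (3*5)
batch_product = torch.bmm(batch_a, batch_b)  # (8*3*4) * (8*4*5) -> (8*3*5)
via_operator = a @ b                         # torch.matmul
elementwise = a * a                          # operands of the same shape

print(product.shape, batch_product.shape, elementwise.shape)
```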
Calculate the pairwise Euclidean distance between two sets of data
Use the broadcasting mechanism:
dist = torch.sqrt(torch.sum((X1[:,None,:] - X2) ** 2, dim=2))
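To spell out the broadcasting step, the sketch below uses small, arbitrarily sized inputs and compares the result with `torch.cdist`, offered here as a built-in alternative rather than as the method of this article.

```python
import torch

X1 = torch.rand(5, 3)  # 5 points with 3 features
X2 = torch.rand(4, 3)  # 4 points with 3 features

# X1[:, None, :] has shape (5, 1, 3); subtracting X2 of shape (4, 3) broadcasts to (5, 4, 3)
dist = torch.sqrt(torch.sum((X1[:, None, :] - X2) ** 2, dim=2))

# built-in alternative, called on explicit batches of size 1
dist_cdist = torch.cdist(X1.unsqueeze(0), X2.unsqueeze(0)).squeeze(0)

print(dist.shape)
```

The broadcast version materializes an intermediate of shape (5, 4, 3), so its memory use grows with the product of both set sizes and the feature dimension.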