Introduction to the torch.copy_() function

The Tensor.copy_() method copies the elements of a source tensor into the calling tensor in place.

Tensor.copy_(src, non_blocking=False) → Tensor

Official documentation definition: copies the elements from src into the self tensor and returns self.
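Beyond the basic copy, src only needs to be broadcastable with self, and copy_() casts to the destination's dtype. A small sketch of both behaviors (tensor names here are illustrative, not from the original post):

```python
import torch

dst = torch.zeros(2, 3)                 # float32 destination
src = torch.tensor([1, 2, 3])           # int64, shape (3,): broadcast over rows
dst.copy_(src)                          # in-place copy; returns dst
print(dst)                              # tensor([[1., 2., 3.], [1., 2., 3.]])

# copy_() casts src to dst's dtype rather than changing dst
dst.copy_(torch.tensor([7, 8, 9], dtype=torch.int64))
print(dst.dtype)                        # torch.float32
```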

1) With direct assignment (=), the variable is simply rebound: the name that held the original weight now points to the storage of the newly assigned tensor. With .copy_(), only the values of the original weight change; its storage area stays the same. This makes copy_() useful for avoiding extra allocations.

import torch
x = torch.tensor([[1,2], [3,4], [5,6]])
y = torch.rand((3,2)) # uniform distribution over [0, 1)
print(y, id(y))
y = x # assignment: y is rebound, so id(y) changes
print(y, id(y))

import torch
x = torch.tensor([[1,2], [3,4], [5,6]])
y = torch.rand((3,2)) # uniform distribution over [0, 1)
print(y, id(y))
y.copy_(x) # copy_(): id(y) is unchanged; only the values are overwritten
print(y, id(y))
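The storage claim above can also be checked directly with data_ptr(), which returns the address of a tensor's underlying storage (a minimal sketch; the variable names are illustrative):

```python
import torch

x = torch.tensor([[1, 2], [3, 4], [5, 6]])
y = torch.rand(3, 2)

ptr_before = y.data_ptr()       # address of y's underlying storage
y.copy_(x)
assert y.data_ptr() == ptr_before   # same storage; values were overwritten

y2 = torch.rand(3, 2)
ptr2 = y2.data_ptr()
y2 = x                          # rebinding: y2 now refers to x's storage
assert y2.data_ptr() == x.data_ptr()
assert y2.data_ptr() != ptr2    # the old storage is no longer referenced by y2
```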

2) Copying into a designated region

import torch
x = torch.tensor([[1,2], [3,4], [5,6]])

z = torch.rand((4,2))
z[:x.shape[0], :x.shape[1]].copy_(x) # copy only into a region the size of x
print(z[:x.shape[0], :x.shape[1]]) # the slice now holds x's values
print(z) # the rest of z is unchanged
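The training use case mentioned in the article follows the same pattern: overwriting a model weight with copy_() keeps the Parameter object and its storage intact, so optimizer references stay valid. A minimal sketch, assuming a simple nn.Linear module (the layer and values are illustrative):

```python
import torch

# A small layer whose weight we overwrite in place.
linear = torch.nn.Linear(2, 2, bias=False)
new_w = torch.ones(2, 2)

ptr_before = linear.weight.data_ptr()
with torch.no_grad():                # suppress autograd tracking for the in-place write
    linear.weight.copy_(new_w)       # same Parameter, same storage, new values

assert linear.weight.data_ptr() == ptr_before
print(linear.weight)
```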

Reference article: "[PyTorch Model Training] The impact of tensor.copy_() and direct assignment (=) on training time"

Origin blog.csdn.net/qimo601/article/details/128019342