The Tensor.copy_() method copies the values of one tensor into another in place.
Tensor.copy_(src, non_blocking=False) → Tensor
Official documentation definition: Copies the elements from src into self tensor and returns self.
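A detail worth noting from the documentation: src only needs to be broadcastable with self, and its dtype may differ, since values are converted to self's dtype on copy. A minimal sketch (the tensor values here are illustrative):

```python
import torch

# copy_() broadcasts src to self's shape and converts dtype in place
dst = torch.zeros(3, 2, dtype=torch.long)
src = torch.tensor([1.9, 2.9])   # float32, shape (2,) -- broadcast over rows
dst.copy_(src)                   # float -> long truncates toward zero
print(dst)                       # each row becomes [1, 2]
```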
1) If direct assignment (=) is used, the variable that held the original weight is rebound to the storage of the newly created weight tensor; if .copy_() is used instead, only the values of the original weight change and its storage area stays the same. This makes copy_() useful for reducing memory allocations.
import torch
x = torch.tensor([[1,2], [3,4], [5,6]])
y = torch.rand((3,2)) # uniform distribution on [0, 1)
print(y,id(y))
y = x # assignment: y is rebound, so its id changes
print(y,id(y))
import torch
x = torch.tensor([[1,2], [3,4], [5,6]])
y = torch.rand((3,2)) # uniform distribution on [0, 1)
print(y,id(y))
y.copy_(x) # copy_(): y's id is unchanged; only its values are overwritten
print(y,id(y))
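Beyond id(), the underlying storage claim can be checked directly with data_ptr(), which returns the address of a tensor's first element. A small sketch confirming both behaviors described above:

```python
import torch

x = torch.tensor([[1, 2], [3, 4], [5, 6]])
y = torch.rand(3, 2)

ptr_before = y.data_ptr()
y.copy_(x)                         # in place: same storage, new values
same_storage = (y.data_ptr() == ptr_before)

y = x                              # rebinding: y now aliases x's storage
aliases_x = (y.data_ptr() == x.data_ptr())

print(same_storage, aliases_x)     # both True
```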
2) Copying into a designated region
import torch
x = torch.tensor([[1,2], [3,4], [5,6]])
z = torch.rand((4,2))
z[:x.shape[0],:x.shape[1]].copy_(x) # copy only into the region matching x's size
print(z[:x.shape[0],:x.shape[1]])
print(z)
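The same in-place behavior is what makes copy_() convenient for refreshing model weights during training: references that the optimizer holds to the parameter's storage stay valid. A minimal sketch (the layer shape and weight values here are illustrative assumptions):

```python
import torch

# Hypothetical example: overwrite a layer's weights in place.
layer = torch.nn.Linear(2, 2, bias=False)
new_w = torch.ones(2, 2)

ptr_before = layer.weight.data_ptr()
with torch.no_grad():            # avoid recording the copy in autograd
    layer.weight.copy_(new_w)    # values change, storage does not

print(layer.weight)
print(layer.weight.data_ptr() == ptr_before)  # True
```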
Reference article: "[PyTorch Model Training Issue] The impact of tensor.copy_() and direct assignment (=) on training time".