pytorch中tensor的复制,tensor.repeat()与tensor.tile()与tensor.repeat_interleave()

前言:

留个坑,稍后补上

博主最近在复现GAMSNet的时候,用到将张量B*1*H*W转化为B*C*H*W的操作,pytorch提供三个张量复制的函数tensor.repeat()与tensor.tile()与tensor.repeat_interleave()。本文主要说明这三个函数的使用方法和梯度求解的注意事项。

一、源码解析

tensor.repeat() 与 tensor.tile() tensor.repeat_interleave() 的源码,pycharm鼠标点进去可以看到

        可以看出,三个函数返回的对象是一致,区别是传入的参数不一样。也就是说当三者传入的参数完全一致的时候,这三个函数的功能是一模一样的

沿着指定的维度复制指定次数,

repeat(*sizes) -> Tensor
Repeats this tensor along the specified dimensions.
Unlike :meth:`~Tensor.expand`, this function copies the tensor's data.

\torch\_C\_TensorBase.py

x = torch.randn([2,1,1,1], requires_grad=True)
print(x)
y = x.repeat([1,3,1,1])
print(y)
print('x size:',x.size())
print('y size:',y.size())
tensor([[[[-1.9459]]],

        [[[-1.2008]]]], requires_grad=True)

tensor([[[[-1.9459]],
         [[-1.9459]],
         [[-1.9459]]],

        [[[-1.2008]],
         [[-1.2008]],
         [[-1.2008]]]], grad_fn=<RepeatBackward>)
x size: torch.Size([2, 1, 1, 1])
y size: torch.Size([2, 3, 1, 1])

猜你喜欢

转载自blog.csdn.net/weixin_44503976/article/details/129472229