前言:
留个坑,稍后补上
博主最近在复现GAMSNet的时候,用到将张量B*1*H*W转化为B*C*H*W的操作,pytorch提供三个张量复制的函数tensor.repeat()与tensor.tile()与tensor.repeat_interleave()。本文主要说明这三个函数的使用方法和梯度求解的注意事项。
一、源码解析
tensor.repeat() 与 tensor.tile() tensor.repeat_interleave() 的源码,pycharm鼠标点进去可以看到
可以看出,三个函数返回的对象是一致,区别是传入的参数不一样。也就是说当三者传入的参数完全一致的时候,这三个函数的功能是一模一样的。
沿着指定的维度复制指定次数,
repeat(*sizes) -> Tensor Repeats this tensor along the specified dimensions. Unlike :meth:`~Tensor.expand`, this function copies the tensor's data.
\torch\_C\_TensorBase.py
x = torch.randn([2,1,1,1], requires_grad=True)
print(x)
y = x.repeat([1,3,1,1])
print(y)
print('x size:',x.size())
print('y size:',y.size())
tensor([[[[-1.9459]]],
[[[-1.2008]]]], requires_grad=True)
tensor([[[[-1.9459]],
[[-1.9459]],
[[-1.9459]]],
[[[-1.2008]],
[[-1.2008]],
[[-1.2008]]]], grad_fn=<RepeatBackward>)
x size: torch.Size([2, 1, 1, 1])
y size: torch.Size([2, 3, 1, 1])