torch.roll(input,shifts,dims=None)
input:输入张量
shifts:滚动的方向和长度,若为正,则向下滚动;若为负,则向上滚动。
dims:张量按什么维度滚动。以二维为例,0就是按行上下或者下上(shifts为负)滚动;1就是左右或者右左(shifts为负)。
这里参考的是(12条消息) 【Pytorch小知识】torch.roll()函数的用法及在Swin Transformer中的应用(详细易懂)_18岁小白想成大牛的博客-CSDN博客
下面是官方举例
>>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
>>> x
tensor([[1, 2],
[3, 4],
[5, 6],
[7, 8]])
>>> torch.roll(x, 1)
tensor([[8, 1],
[2, 3],
[4, 5],
[6, 7]])
>>> torch.roll(x, 1, 0)
tensor([[7, 8],
[1, 2],
[3, 4],
[5, 6]])
>>> torch.roll(x, -1, 0)
tensor([[3, 4],
[5, 6],
[7, 8],
[1, 2]])
>>> torch.roll(x, shifts=(2, 1), dims=(0, 1))
tensor([[6, 5],
[8, 7],
[2, 1],
[4, 3]])
torch.roll(x,1):向下滚动,1换到2的位置,2换到3,.....8换到1
torch.roll(x,1,0):沿着上下滚动,[1,2]换到[3,4],[3,4]换到[5,6],......[7,8]换到[1,2]