[Pytorch]知识---torch.roll

torch.roll(input,shifts,dims=None)


input:输入张量
shifts:滚动的方向和长度,若为正,则向下滚动;若为负,则向上滚动。
dims:张量按什么维度滚动。以二维为例,0就是按行上下或者下上(shifts为负)滚动;1就是左右或者右左(shifts为负)。

这里参考的是(12条消息) 【Pytorch小知识】torch.roll()函数的用法及在Swin Transformer中的应用(详细易懂)_18岁小白想成大牛的博客-CSDN博客

 下面是官方举例

>>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
>>> x
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
>>> torch.roll(x, 1)
tensor([[8, 1],
        [2, 3],
        [4, 5],
        [6, 7]])
>>> torch.roll(x, 1, 0)
tensor([[7, 8],
        [1, 2],
        [3, 4],
        [5, 6]])
>>> torch.roll(x, -1, 0)
tensor([[3, 4],
        [5, 6],
        [7, 8],
        [1, 2]])
>>> torch.roll(x, shifts=(2, 1), dims=(0, 1))
tensor([[6, 5],
        [8, 7],
        [2, 1],
        [4, 3]])

torch.roll(x,1):向下滚动,1换到2的位置,2换到3,.....8换到1

torch.roll(x,1,0):沿着上下滚动,[1,2]换到[3,4],[3,4]换到[5,6],......[7,8]换到[1,2]

猜你喜欢

转载自blog.csdn.net/qq_46073783/article/details/130260856