PyTorch典型函数——torch.roll()

最近在看一些论文的源码过程中,在循环窗口以为过程中,出现了 torch.roll() 这个函数,roll的英文含义为翻滚的意思,在torch中主要表示在对应维度上进行循环移位,其实类似于数据的扩增方式了,由于平时在其他的代码中不常见,这里记录下改代码的具体操作。

torch.roll()

PyTorch官网API点击前往
在这里插入图片描述
这里需要关注的是roll函数中的shifts 和 dims 两者的维度必须是一样的,即 shifts中主要是指定对应维度移动多少个像素,dims指定在哪个维度上进行移动。

实例代码

# -*- coding:utf-8 -*-
import torch    
import numpy as np   
import matplotlib.pyplot as plt

shift_size = 3
'''构造多维张量'''
x=np.arange(301056).reshape(1,56,56,96)
x=torch.from_numpy(x)

if shift_size > 0:
    shifted_x = torch.roll(x, shifts=(-20, -20), dims=(1, 2))
    #shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    print("---------经过循环位移了---------")
else:
    shifted_x = x
   
'''可视化部分'''
plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(x[0,:,:,0])
plt.title("orgin_img")
plt.subplot(1,2,2)
plt.imshow(shifted_x[0,:,:,0])
if torch.equal(shifted_x, x):
    plt.title("non_shifted")
else:
    plt.title("shifted_img")
plt.show()
plt.pause(5)
plt.close()


这幅图是 shifts = (-20,-20) 在 dim(1,2)即 H,W两个维度上进行循环移动的结果,关注移位图的位置,我们不难发现,当 shifts 值为负时,相当于将图像向上,向左进行移动,即将原来最上方和最左端的像素,移动至底部和右端。大家可以进一步尝试将 shifts 设置为不同的值,进行观察。
在这里插入图片描述
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/qq_29750461/article/details/128844310