PyTorch中contiguous、view、Sequential、permute函数的用法

        在pytorch中,tensor的实际数据以一维数组(storage)的形式存储于某个连续的内存中,“行优先”进行存储。

1. tensor的连续性

        tensor连续(contiguous)是指tensor的storage元素排列顺序与其按行优先时的元素排列顺序相同。如下图所示:

        出现不连续现象,本质上是由于pytorch中不同tensor可能共用同一个storage导致的。
pytorch的很多操作都会导致tensor不连续,如tensor.transpose()(tensor.t())、tensor.narrow()、tensor.expand()。
以转置为例,因为转置操作前后共用同一个storage,但显然转置后的tensor按照行优先排列成1维后与原storage不同了,因此转置后结果属于不连续(见下例)。

2. tensor.is_contiguous()

        tensor.is_contiguous()用于判断tensor是否连续,以转置为例说明:

>>>a = torch.tensor([[1,2,3],[4,5,6]])
>>>print(a)
tensor([[1, 2, 3],
        [4, 5, 6]])
>>>print(a.storage())
 1
 2
 3
 4
 5
 6
[torch.LongStorage of size 6]
>>>print(a.is_contiguous()) #a是连续的
True

>>>b = a.t() #b是a的转置
>>>print(b)
tensor([[1, 4],
        [2, 5],
        [3, 6]])
>>>print(b.storage())
 1
 2
 3
 4
 5
 6
[torch.LongStorage of size 6]
>>>print(b.is_contiguous()) #b是不连续的
False

# 之所以出现b不连续,是因为转置操作前后是共用同一个storage的
>>>print(a.storage().data_ptr())
>>>print(b.storage().data_ptr())
2638924341056
2638924341056

3. tensor不连续的后果

        tensor不连续会导致某些操作无法进行,比如view()就无法进行。在上面的例子中:由于 b 是不连续的,所以对其进行view()操作会报错;b.view(3,2)没报错,因为b本身的shape就是(3,2)。

>>>b.view(2,3)
RuntimeError                              Traceback (most recent call last)
>>>b.view(1,6)
RuntimeError                              Traceback (most recent call last)
>>>b.view(-1)
RuntimeError                              Traceback (most recent call last)

>>>b.view(3,2)
tensor([[1, 4],
        [2, 5],
        [3, 6]])

4. tensor.contiguous()

        tensor.contiguous()返回一个与原始tensor有相同元素的 “连续”tensor,如果原始tensor本身就是连续的,则返回原始tensor。
        注意:tensor.contiguous()函数不会对原始数据做任何修改,他不仅返回一个新tensor,还为这个新tensor创建了一个新的storage,在这个storage上,该新的tensor是连续的。
继续使用上面的例子:

>>>c = b.contiguous()

# 形式上两者一样
>>>print(b)
>>>print(c)
tensor([[1, 4],
        [2, 5],
        [3, 6]])
tensor([[1, 4],
        [2, 5],
        [3, 6]])

# 显然storage已经不是同一个了
>>>print(b.storage())
>>>print(c.storage())
 1
 2
 3
 4
 5
 6
[torch.LongStorage of size 6]
 1
 4
 2
 5
 3
 6
[torch.LongStorage of size 6]
False

# b不连续,c是连续的
>>>print(b.is_contiguous())
False
>>>print(c.is_contiguous())
True

#此时执行c.view()不会出错
>>>c.view(2,3)
tensor([[1, 4, 2],
        [5, 3, 6]])

以上原文出自:tensor的连续性、tensor.is_contiguous()、tensor.contiguous() - 简书 (jianshu.com) 

 5. view()

        类似于resize操作,基于前面所说的tensor连续存储,view()函数把原tensor中的数据按照行优先的顺序排成一个一维的数据,然后按照参数组合成其他维度的tensor。

        举个例子:

a=torch.Tensor([[[1,2,3],[4,5,6]]])
b=torch.Tensor([1,2,3,4,5,6])

print(a.view(1,6))
print(b.view(1,6))

# 输出结果都是 tensor([[1, 2, 3, 4, 5, 6]]) 

再如输出3维向量: 

a=torch.Tensor([[[1,2,3],[4,5,6]]])
print(a.view(3,2))
#输出结果为:
#tensor([[1, 2],
#        [3, 4],
#        [5, 6]])

6. nn.Sequential()

        一个序列容器,用于搭建神经网络的模块被按照被传入构造器的顺序添加到nn.Sequential()容器中。除此之外,一个包含神经网络模块的OrderedDict也可以被传入nn.Sequential()容器中。利用nn.Sequential()搭建好模型架构,模型前向传播时调用forward()方法,模型接收的输入首先被传入nn.Sequential()包含的第一个网络模块中。然后,第一个网络模块的输出传入第二个网络模块作为输入,按照顺序依次计算并传播,直到nn.Sequential()里的最后一个模块输出结果。
        即nn.Sequential()相当于把多个模块封装成一个模块。它与nn.ModuleList()不同,nn.ModuleList()只是存储网络模块的list,其中的网络模块之间没有连接关系和顺序关系

7.permute()

        permute()函数将tensor的维度换位,相当于同时操作tensor的若干维度,与transpose()函数不同,transpose()只能同时作用于tensor的两个维度。

如:

>>>torch.randn(2,3,4,5).permute(3,2,0,1).shape
# 输出结果为torch.size([5,4,2,3])
# 上面的结果等价于:
>>>torch.randn(2,3,4,5).transpose(3,0).transpose(2,1).transpose(3,2).shape
# 输出结果为torch.size([5,4,2,3])

猜你喜欢

转载自blog.csdn.net/baidu_41774120/article/details/128666944