张量的连续性、contiguous函数

        在pytorch中,tensor的实际数据以一维数组(storage)的形式存储于某个连续的内存中,“行优先”进行存储。

tensor的连续性

         tensor连续(contiguous)是指tensor的storage元素排列顺序与其按行优先时的元素排列顺序相同。如下图所示:

        上图中,tensor b是tensor a经过转置而来的,即使用了 tensor.t() 方法。

        出现不连续现象,本质上是由于pytorch中不同tensor可能共用同一个storage导致的。
    pytorch的很多操作都会导致tensor不连续,如tensor.transpose()(tensor.t())、tensor.narrow()、tensor.expand()。
        以转置为例,因为转置操作前后共用同一个storage,但显然转置后的tensor按照行优先排列成1维后与原storage不同了,因此转置后结果属于不连续(见下例)。

2. tensor.is_contiguous()

   is_contiguous直观的解释是Tensor底层一维数组元素的存储顺序与Tensor按行优先一维展开的元素顺序是否一致

        如果想要变得连续使用contiguous方法,如果Tensor不是连续的,则会重新开辟一块内存空间保证数据是在内存中是连续的;如果Tensor是连续的,则contiguous无操作。

        tensor.is_contiguous()用于判断tensor是否连续,以转置为例说明:

import torch
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
print(a)
print(a.storage())
print(a.is_contiguous())  # a是连续的
"""
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
 1
 2
 3
 4
 5
 6
 7
 8
 9
[torch.LongStorage of size 9]
True
"""

b = a.t()  # b是a的转置
print(b)
print(b.storage())
print(b.is_contiguous())  # b是不连续的
"""
tensor([[1, 4, 7],
        [2, 5, 8],
        [3, 6, 9]])
 1
 2
 3
 4
 5
 6
 7
 8
 9
[torch.LongStorage of size 9]
False
"""

3. tensor不连续的后果

        tensor不连续会导致某些操作无法进行,比如view()就无法进行。在上面的例子中:由于 b 是不连续的,所以对其进行view()操作会报错;b.view(3,3)没报错,因为b本身的shape就是(3,3)。

print(b.view(3, 3))
"""
tensor([[1, 4, 7],
        [2, 5, 8],
        [3, 6, 9]])
"""
print(b.view(1, 9))# 报错
print(b.view(-1))# 报错

 4. tensor.contiguous()


        tensor.contiguous()返回一个与原始tensor有相同元素的 “连续”tensor,如果原始tensor本身就是连续的,则返回原始tensor。
        注意:tensor.contiguous()函数不会对原始数据做任何修改,他不仅返回一个新tensor,还为这个新tensor创建了一个新的storage,在这个storage上,该新的tensor是连续的。
继续使用上面的例子:

c = b.contiguous()
print(b)
print(c)
print(b.storage())
print(c.storage())

 输出结果:

# b
tensor([[1, 4, 7],
        [2, 5, 8],
        [3, 6, 9]])
# c
tensor([[1, 4, 7],
        [2, 5, 8],
        [3, 6, 9]])
# b.storage
 1
 2
 3
 4
 5
 6
 7
 8
 9
[torch.LongStorage of size 9]

#c.storage 
 1
 4
 7
 2
 5
 8
 3
 6
 9
[torch.LongStorage of size 9]

接着运行如下代码: 

print(b.is_contiguous()) # False
print(c.is_contiguous()) # True
print(c.view(1, 9)) # tensor([[1, 4, 7, 2, 5, 8, 3, 6, 9]])


参考自:https://blog.csdn.net/baidu_41774120/article/details/128666944

猜你喜欢

转载自blog.csdn.net/m0_48241022/article/details/132804698