Pytorch中tensor.expand()和tensor.expand_as()函数

Tensor.expand()函数详解

函数语法:

# 官方解释:
Docstring:
expand(*sizes) -> Tensor

Returns a new view of the :attr:`self` tensor with singleton dimensions expanded
to a larger size.

基本功能:

tensor.expand()函数可以将维度值包含 1 的Tensor(如:torch.Size([1, n])或者torch.Size([n, 1]))的维度进行扩展。其具体的扩展规则如下:

  1. 只能对维度值包含 1 的张量Tensor进行扩展,即:Tensor的size必须满足:torch.Size([1, n]) 或者 torch.Size([n, 1]) 。
  2. 只能对维度值等于 1 的那个维度进行扩展,无需扩展的维度务必保持维度值不变,或者置为-1,否则,报错。(简言之,只要是单维度均可进行扩展,但是若非单维度会报错。
  3. 扩展的Tensor不会分配新的内存,只是原来的基础上创建新的视图并返回;
  4. 新扩展维度的取值范围为: − 1 以 及 [ 1 , + ∞ ] 区 间 内 的 任 意 整 数 -1以及[1, +∞]区间内的任意整数 1[1,+],例如:将 torch.Size([1, n]) 扩展为torch.Size([m, n])时,新扩展维度 m 的可能取值为-1,或者 m ≥ 1的任意整数;
  5. 只能对张量Tensor进行维度扩展,而不能降维;否则,报错。
  6. tensor通过.expand()函数扩展某一维度后,tensor自身不会发生变化。

备注:
1、将 -1 传递给新扩展维度或者无需扩展维度均表示不更改该维度的尺寸。(Passing -1 as the size for a dimension means not changing the size of that dimension.)
2、如果令m=0,则会将原tensor变为空张量。示例如下:

a = torch.tensor([[2], [3], [4]])
print("a:\n", a)
print(a.size())
>>>
a:
 tensor([[2],
        [3],
        [4]])
torch.Size([3, 1])
# 将新扩展维度m置0时,原tensor变为空张量。
a.expand(3,0)
>>>
tensor([], size=(3, 0), dtype=torch.int64)

应用实例01: torch.Size([n, 1]) 扩展为 torch.Size([n, m])

import torch

a = torch.tensor([[2], [3], [4]])   # 创建size为3行1列的张量
print("a:\n", a)
print(a.size())
>>>
a:
 tensor([[2],
        [3],
        [4]])
torch.Size([3, 1])
# (1)将torch.Size([3, 1])扩展为torch.Size([3, 2])
a.expand(3,2)
>>>
tensor([[2, 2],
        [3, 3],
        [4, 4]])
# (2)将 -1 赋值给“无需扩展维度”,同时将torch.Size([3, 1])扩展为torch.Size([3, 4])
a.expand(-1,4)   # 此处a.expand(-1,4)与a.expand(3,4)是等价的
>>>
tensor([[2, 2, 2, 2],
        [3, 3, 3, 3],
        [4, 4, 4, 4]])

#(3)将 -1 赋值给“新扩展维度”,此时torch.Size([3, 1])将保持原状,不扩展。
a.expand(3,-1) # -1 means not changing the size of that dimension
>>>
tensor([[2],
        [3],
        [4]])
# (4)同时将 -1 赋值给“新扩展维度”和“无需扩展维度”,此时torch.Size([3, 1])将保持原状,不扩展。
a.expand(-1,-1)  # 此处a.expand(-1,-1)与a.expand(3,-1)是等价的
>>>
tensor([[2],
        [3],
        [4]])

如果原始Tensor的维度值中不包含1,则不能使用tensor.expand()函数进行扩展;否则,报错。 示例如下:

b = torch.tensor([[2, 1], [3, 5], [4, 7]])
print("b:\n", b)
print(b.size())
>>>
b:
 tensor([[2, 1],
        [3, 5],
        [4, 7]])
torch.Size([3, 2])
# 欲将torch.Size([3, 2])扩展为torch.Size([3, 4]),结果报错。
b.expand(3, 4)
>>>
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-24-0fa681906d91> in <module>
----> 1 b.expand(3, 4)

RuntimeError: The expanded size of the tensor (4) must match the existing size (2) at non-singleton dimension 1.  Target sizes: [3, 4].  Tensor sizes: [3, 2]
# 欲将torch.Size([3, 2])降维至torch.Size([3, 1]),结果报错。
b.expand(3, 1)
>>>
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-26-3ca2598393c4> in <module>
----> 1 b.expand(3, 1)

RuntimeError: The expanded size of the tensor (1) must match the existing size (2) at non-singleton dimension 1.  Target sizes: [3, 1].  Tensor sizes: [3, 2]

应用实例02: torch.Size([1, n]) 扩展为 torch.Size([m, n])

c = torch.tensor([[2, 1, 5, 8, 9]])  # 创建size为1行5列的张量
print("c:\n", c)
print(c.size())
>>>
c:
tensor([[2, 1, 5, 8, 9]])
torch.Size([1, 5])
# (1)将torch.Size([1, 5])扩展为torch.Size([3, 5])
c.expand(3,5)
>>>
tensor([[2, 1, 5, 8, 9],
        [2, 1, 5, 8, 9],
        [2, 1, 5, 8, 9]])
# (2)将 -1 赋值给“无需扩展维度”,同时将torch.Size([1, 5])扩展为torch.Size([3, 5])
c.expand(3,-1)  # 此处c.expand(3,-1)与c.expand(3,5)是等价的
>>>
tensor([[2, 1, 5, 8, 9],
        [2, 1, 5, 8, 9],
        [2, 1, 5, 8, 9]]) 
#(3)将 -1 赋值给“新扩展维度”,此时torch.Size([1, 5])将保持原状,不扩展。
c.expand(-1,5) # -1 means not changing the size of that dimension
>>>
tensor([[2, 1, 5, 8, 9]])
# (4) 欲将torch.Size([1, 5])扩展为torch.Size([3, 4]),结果报错。
c.expand(3,4)
>>>
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-37-f39aa37f0529> in <module>
----> 1 c.expand(3,4)

RuntimeError: The expanded size of the tensor (4) must match the existing size (5) at non-singleton dimension 1.  Target sizes: [3, 4].  Tensor sizes: [1, 5]

Tensor.expand_as()函数

函数语法:

# 官方解释:
Docstring:
expand_as(other) -> Tensor

Expand this tensor to the same size as :attr:`other`.
``self.expand_as(other)`` is equivalent to ``self.expand(other.size())``.


Please see :meth:`~Tensor.expand` for more information about ``expand``.

Args:
    other (:class:`torch.Tensor`): The result tensor has the same size
        as :attr:`other`.
Type:      builtin_function_or_method

基本功能:

张量b和a.expand_as(b)的size是一样大的,并且是不共享内存的。

应用实例:

a = torch.tensor([[2], [3], [4]])
print("a:\n", a)
print(a.size())
>>>
a:
 tensor([[2],
        [3],
        [4]])
torch.Size([3, 1])

b = torch.tensor([[2,2],[3,3],[5,5]])
print("b:\n", b)
print(b.size())
>>>
b:
 tensor([[2, 2],
        [3, 3],
        [5, 5]])
torch.Size([3, 2])
# 使用a.expand_as(b)将张量a的size——torch.Size([3, 1])扩展为与张量b的size——torch.Size([3, 2])同形的高维张量
a.expand_as(b)
>>>
tensor([[2, 2],
        [3, 3],
        [4, 4]])

猜你喜欢

转载自blog.csdn.net/weixin_42782150/article/details/108615706