Pytorch関数expand()の詳細説明

Pytorch 関数.expand()

単一の次元をより大きな次元に拡張し、新しいテンソルを返します。詳細については、次の例を参照してください

import torch

a = torch.Tensor([[1], [2], [3],[4]])
# 未使用expand()函数前的a
print('a.size: ', a.size())
print('a: ', a)

b = a.expand(4, 2)
# 使用expand()函数后的输出
print('a.size: ', a.size())
print('a: ', a)
print('b.size: ', b.size())
print('b: ', b)

a は、expand() 関数を使用する前後で変化せず、出力は次のようになります。

a.size: torch.Size([4, 1])
a:
1
2
3
4
[サイズ 4x1 の torch.FloatTensor]

b の出力は次のとおりです。

b.size: torch.Size([4, 2])
b:
1 1
2 2
3 3
4 4
[サイズ 4x2 の torch.FloatTensor]
これから、 a は、expand() を通じて特定の次元でそれ自体を拡張すると結論付けることができます。機能は変わりません

a = torch.Tensor([[[[1,2], [2,3], [3,4],[4,5]]]])
b = a.expand(2, 1, 4, 2)
c = a.expand(1, 2, 4, 2)
# 使用expand()函数后的输出
print('a.size: ', a.size())

print('b.size: ', b.size())
print('b: ', b)

print('c.size: ', c.size())
print('c: ', c)

 b2 = b.expand(3, 1, 4, 2)  # b: torch.Size([2, 1, 4, 2])
 print('b2.size: ', b2.size())

出力:

a.size: torch.Size([1, 1, 4, 2])

b.size: torch.Size([2, 1, 4, 2])
b:
(0 ,0 ,.,.) =
1 2
2 3 3
4
4 5
(1 ,0 ,.,.) =
1 2
2 3
3 4
4 5
[サイズ 2x1x4x2 の torch.FloatTensor]

c.size: torch.Size([1, 2, 4, 2])
c:
(0 ,0 ,.,.) =
1 2 2
3 3
4
4 5
(0 ,1 ,.,.) =
1 2
2 3
3 4
4 5
[サイズ 1x2x4x2 の torch.FloatTensor]

b2 出力:

トレースバック (最新の呼び出しは最後): RuntimeError のファイル
「」、行 1 : テンソルの拡張サイズ (3) は、非単一次元次元 0 の既存のサイズ (2) と一致する必要があります。 /opt/conda/conda-
bld/pytorch_1525796793591/work/torch/lib/TH/generic/THTensor.c:309

単一次元であれば展開できるが、単一次元でない場合はエラーが報告されることがわかります。

おすすめ

転載: blog.csdn.net/weixin_43994864/article/details/106244379