pytorch 展開

基本的な使い方

以下はchatgptによって生成されます

PyTorch のexpandメソッドは、テンソル次元を拡張するために使用されます。これは、小さなテンソルを大きなテンソルにコピーして、テンソルの形状をコピーするために使用できるテンソルのメソッドです。

最も一般的な使用シナリオは、特定の次元を拡張して他のテンソルの形状に合わせることです。たとえば、N 次元のテンソル A があり、それを B 次元のテンソルで操作したいが、B の次元が A よりも大きいとします。この場合、expandメソッドを使用して A の次元を B の次元に合わせて拡張できます。

具体的には、このメソッドの使用法は次のとおりです。

expanded_tensor = tensor.expand(sizes)

その中で、tensor は拡張する必要がある tensor であり、sizes は拡張された形状を示す整数のリストです。たとえば、テンソルの形状が (3, 4) でサイズが (3, 4, 5) の場合、展開されたテンソルの形状は (3, 4, 5) になります。

簡単な例:

import torch

a = torch.tensor([1, 2, 3])
b = a.expand(3, 3)
print(b)

# output:
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [1, 2, 3]])

上記の例では、expand メソッドを使用して 1 次元テンソル a を 2 次元テンソル b に展開し、サイズ パラメータを指定して、展開されたテンソルの形状を (3, 3) として指定します。

サイズに-1が含まれる場合

-1 がパラメーター リストのサイズに含まれている場合、PyTorch は元の次元と等しくなるように次元のサイズを一致させようとします。これにより、サイズが指定された次元のみが拡張されます。

簡単な例:

import torch

a = torch.tensor([1, 2, 3, 4])
b = a.expand(2, -1)
print(b)

# output:
# tensor([[1, 2, 3, 4],
#         [1, 2, 3, 4]])

上記の例では、テンソル a の形状を (2, 4) に拡張しました。テンソル a は次元 0 に沿って 2 回拡張されます。

おすすめ

転載: blog.csdn.net/duoyasong5907/article/details/128991391