pytorch テンソル次元変換百科事典について

# view()    转换维度
# reshape() 转换维度
# permute() 坐标系变换
# squeeze()/unsqueeze() 降维/升维
# expand()   扩张张量
# narraw()   缩小张量
# resize_()  重设尺寸
# repeat(), unfold() 重复张量
# cat(), stack()     拼接张量

1 tensor.view()

view()はtensor の形状を変更するために使用されますが、tensor 内の要素の値は変更しません
使用法 1:
たとえば、ビューを使用して、形状 (2, 3) のテンソルを形状 (3, 2) のテンソルに変換できます。

import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = x.view(3, 2)    

上記の操作は、最初に形状 **(2, 3) のテンソルを (1, 6) に平坦化し、次に (3, 2) ** に平坦化することと同じです。

使用法 2:
変換の前後でテンソルの要素の数は変わりません。view() 内の特定の次元の次元が -1 の場合、この次元の次元は要素の総数と他の次元のサイズに応じて適応的に調整されることを意味します。view() の最大 1 つの次元の次元を-1 に設定できることに注意してください。

z = x.view(-1,2)

画像.png

例:
畳み込みニューラル ネットワークでは、テンソル次元を拡張するためにビューが全結合層でよく使用されます。
入力特徴がB C H*Wの 4 次元テンソルであると仮定します。ここで、B はバッチサイズを表し、C は特徴を表します。チャネル数、H と W はフィーチャの高さと幅を表します。フィーチャを完全に接続された層に送信する前に、 .view を使用してB*(C H W)の 2 次元テンソルに変換されます。つまり、バッチは変更されませんが、各特徴を 1 次元ベクトルに変換します。

2 tensor.reshape()

reshape() は view() と同じように使用されます。
画像.png

3 tensor.squeeze() と tensor.unsqueeze()

3.1 tensor.squeeze() の次元削減

(1)squeeze() 括弧が空の場合、テンソル内の次元 1 を持つすべての次元が圧縮されます ( 1、2、1、9 のテンソルの次元を 2、9 次元に削減するなど)。テンソルの次元が同じ場合、ソース次元は変更されません。たとえば、2 3 4 次元のテンソルが圧縮された場合、変換後に次元は変わりません。
(2) squeeze(idx) を使用すると、テンソルの対応するidx 番目の次元が圧縮されます。たとえば、squeeze(2) が 1、2、1、および 9 のテンソルに対して実行されると、次元は次のようになります。 1、2、9 次元に縮小テンソル; idx 次元の次元が 1 でない場合、スクイーズ後に次元は変化しません。
例えば:
画像.png

3.2 tensor.unsqueeze(idx)升维

次元のアップグレードは idx 次元で実行され、テンソルは元の次元 n からn+1 次元にアップグレードされますたとえば、テンソルの次元は 2*3 ですが、unsqueeze(0) 後は、1、2、3 の次元のテンソルになります。
画像.png

4 tensor.permute()

座標系の変換、つまり行列の転置は、 numpy array のtranspose と同じ方法で使用されます。permute() 括弧内のパラメータ番号は、各次元のインデックス値を参照します。Permute は深層学習でよく使用される手法であり、一般にBCHW の特徴テンソルは転置によってBHWC の特徴テンソルに変換されます。つまり、 **tensor.permute(0 , 2、3、1)**実現。
torch.transpose は 2D 行列の転置のみを操作できますが、 permute() 関数は任意の高次元行列を転置できます;
単純な理解: permute() はテンソルの複数の次元を同時に操作するのと同等であり、transpose は二次元のテンソルに同時に作用します。

画像.png

permute と view/reshape は両方ともテンソルを特定の次元に変換できますが、原理は完全に異なるため、区別に注意してください。表示および形状変更の処理後、テンソル内の要素の順序は変わりませんが、座標系が変更されるため、置換転置後に要素の配置が変わります。

5 torch.cat([a,b],dim)

ディム次元でテンソル スプライシングを実行する場合は、次元の一貫性を保つことに注意する必要があります
a がh1 w1 の 2 次元テンソル、b がh2 w2 の 2 次元テンソルであるとします。 torch.cat(a,b,0) は1 次元でのスプライシング、つまり列方向でのスプライシングを意味します。 w1 と w2 は等しくなければなりません。torch.cat(a,b,1) は 2 次元でのスプライシング、つまり行方向での、 h1 と h2 は等しくなければなりません。
a がc1 h1 w1の 2 次元テンソル、b が c2 h2 w2 の 2 次元テンソルであると仮定すると、torch.cat(a,b,0) は 1 次元でのスプライシング、つまりチャネルでのスプライシングを意味します。他の寸法は一貫性を保つ必要があります (w1=w2、h1=h2)。torch.cat(a,b,1) は 2 次元でのスプライシング、つまり列方向のスプライシングを意味します。w1=w2, c1=c2 であることを確認する必要があります。torch.cat(a,b,2) は、 3 次元でのスプライシング、つまり行、h1=h2、c1=c2 を確保する必要があります。
画像.png

6 tensor.expand()

値のコピーを通じて1 つの次元をより大きなサイズに拡張するには、テンソルを拡張しますExpand() 関数を使用しても元のテンソルは変更されないため、結果を再代入する必要があります。以下に具体的な例を示します:
2 次元テンソルを例にします: tensor は 1 n または n の1 次元テンソルです。行内でそれぞれ tensor.expand(s, n) または tensor.expand(n, s) を呼び出します。方向と列方向、展開する方向。
Expand() の fill-in パラメータは size です

画像.png

7 tensor.narrow(dim, start, len)

Narrow() 関数は、特定の次元でデータをフィルタリングする役割を果たします。

torch.narrow(input, dim, start, length)->Tensor

input はスライスする必要があるテンソル、dim はスライスの次元、start は開始インデックス、length はスライスの長さです。実際のアプリケーションは次のとおりです。

画像.png

8 tensor.resize_()

サイズが変更され、リサイズ後の次元にテンソルが切り詰められます。
画像.png

9 tensor.repeat()

tensor.repeat(a,b) は、テンソル全体を a 行方向にコピーし、 b を列方向にコピーします。

画像.png

参考:

pytorch のテンソル次元の変更に関連する関数 (継続的に更新) - weili21 の記事 - Zhihu
https://zhuanlan.zhihu.com/p/438099006

【pytorch tensor テンソル次元変換(テンソル次元変換)】
https://blog.csdn.net/x_yan033/article/details/104965077

おすすめ

転載: blog.csdn.net/Alexa_/article/details/134171416