pytorch テンソル次元変換とテンソル演算について

tensor 乘
tensor 加
# view()    转换维度
# reshape() 转换维度
# permute() 坐标系变换
# squeeze()/unsqueeze() 降维/升维
# expand()   扩张张量
# narraw()   缩小张量
# resize_()  重设尺寸
# repeat(), unfold() 重复张量
# cat(), stack()     拼接张量

1 tensor.view()

view()はtensor の形状を変更するために使用されますが、tensor 内の要素の値は変更しません
使用法 1:
たとえば、ビューを使用して、形状 (2, 3) のテンソルを形状 (3, 2) のテンソルに変換できます。

import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = x.view(3, 2)    

上記の操作は、最初に形状 **(2, 3) のテンソルを (1, 6) に平坦化し、次に (3, 2) ** に平坦化することと同じです。

使用法 2:
変換の前後でテンソルの要素の数は変わりません。view() 内の特定の次元の次元が -1 の場合、この次元の次元は要素の総数と他の次元のサイズに応じて適応的に調整されることを意味します。view() の最大 1 つの次元の次元を-1 に設定できることに注意してください。

z = x.view(-1,2)

画像.png

例:
畳み込みニューラル ネットワークでは、テンソル次元を拡張するためにビューが全結合層でよく使用されます。
入力特徴がB C H*Wの 4 次元テンソルであると仮定します。ここで、B はバッチサイズを表し、C は特徴を表します。チャネル数、H と W はフィーチャの高さと幅を表します。フィーチャを完全に接続された層に送信する前に、 .view を使用してB*(C H W)の 2 次元テンソルに変換されます。つまり、バッチは変更されませんが、各特徴を 1 次元ベクトルに変換します。

2 tensor.reshape()

reshape() は view() と同じように使用されます。
画像.png

3 tensor.squeeze() と tensor.unsqueeze()

3.1 tensor.squeeze() の次元削減

(1)squeeze() 括弧が空の場合、テンソル内の次元 1 を持つすべての次元が圧縮されます ( 1、2、1、9 のテンソルの次元を 2、9 次元に削減するなど)。テンソルの次元が同じ場合、ソース次元は変更されません。たとえば、2 3 4 次元のテンソルが圧縮された場合、変換後に次元は変わりません。
(2) squeeze(idx) を使用すると、テンソルの対応するidx 番目の次元が圧縮されます。たとえば、squeeze(2) が 1、2、1、および 9 のテンソルに対して実行されると、次元は次のようになります。 1、2、9 次元に縮小テンソル; idx 次元の次元が 1 でない場合、スクイーズ後に次元は変化しません。
例えば:
画像.png

3.2 tensor.unsqueeze(idx)升维

次元のアップグレードは idx 次元で実行され、テンソルは元の次元 n からn+1 次元にアップグレードされますたとえば、テンソルの次元は 2*3 ですが、unsqueeze(0) 後は、1、2、3 の次元のテンソルになります。
画像.png

4 tensor.permute()

座標系の変換、つまり行列の転置は、 numpy array のtranspose と同じ方法で使用されます。permute() 括弧内のパラメータ番号は、各次元のインデックス値を参照します。Permute は深層学習でよく使用される手法であり、一般にBCHW の特徴テンソルは転置によってBHWC の特徴テンソルに変換されます。つまり、 **tensor.permute(0 , 2、3、1)**実現。
torch.transpose は 2D 行列の転置のみを操作できますが、 permute() 関数は任意の高次元行列を転置できます;
単純な理解: permute() はテンソルの複数の次元を同時に操作するのと同等であり、transpose は二次元のテンソルに同時に作用します。

画像.png

permute と view/reshape は両方ともテンソルを特定の次元に変換できますが、原理は完全に異なるため、区別に注意してください。表示および形状変更の処理後、テンソル内の要素の順序は変わりませんが、座標系が変更されるため、置換転置後に要素の配置が変わります。

5 torch.cat([a,b],dim)

ディム次元でテンソル スプライシングを実行する場合は、次元の一貫性を保つことに注意する必要があります
a がh1 w1 の 2 次元テンソル、b がh2 w2 の 2 次元テンソルであるとします。 torch.cat(a,b,0) は1 次元でのスプライシング、つまり列方向でのスプライシングを意味します。 w1 と w2 は等しくなければなりません。torch.cat(a,b,1) は 2 次元でのスプライシング、つまり行方向での、 h1 と h2 は等しくなければなりません。
a がc1 h1 w1の 2 次元テンソル、b が c2 h2 w2 の 2 次元テンソルであると仮定すると、torch.cat(a,b,0) は 1 次元でのスプライシング、つまりチャネルでのスプライシングを意味します。他の寸法は一貫性を保つ必要があります (w1=w2、h1=h2)。torch.cat(a,b,1) は 2 次元でのスプライシング、つまり列方向のスプライシングを意味します。w1=w2, c1=c2 であることを確認する必要があります。torch.cat(a,b,2) は、 3 次元でのスプライシング、つまり行、h1=h2、c1=c2 を確保する必要があります。
画像.png

6 トーチ.スタック()

この関数は、同じ形状を持つ複数のテンソルをある次元で接続し、最終結果は次元的に増加します。つまり、いくつかのテンソルが特定の次元で接続されて、拡張されたテンソルが生成されます。積み重なった感じ。
画像.png

7 torch.chunk() と torch.split()

torch.chunk(input, chunks, dim)

**torch.chunk()** の機能は、テンソルをいくつかの小さなテンソルに均等に分割することです。入力は分割されたテンソルです。チャンクは均等に分割された部分の数です。分割の次元のサイズがチャンクで割り切れない場合、最後のテンソルはわずかに小さくなります (または空になる可能性があります)。dim は、特定の次元に沿った分割を決定します。この関数は、小さなテンソルで構成されるタプルを返します。
画像.png

**torch.split()** は torch.chunk() のバージョンアップ版とも言え、コピー数に応じて均等に分割するだけでなく、特定の計画に従って分割することもできます。

torch.split(input, split_size_or_sections, dim=0)

torch.chunk() との違いは 2 番目のパラメータにあります。第二引数が分割数の場合は torch.chunk() と同じで、第二引数は分割計画でリスト型データで、分割されるテンソルは len (リスト) 個の部分に分割されます. 、各部分のサイズはリスト内の要素によって異なります。
画像.png

8 テンソルを使用した乗算演算

  • 要素ごとに、つまり、同じ形状の行列の対応する要素を乗算します。得られた要素は、結果行列の各要素の値です。対応する関数はtorch.mul() (* と同じ効果です) )。

画像.png

  • 行列乗算の場合、対応する関数はtorch.mm() ( 2D テンソルにのみ使用可能) または torch.matmul() (記号 @ と同じ効果) です。torch.matmul()の場合、行列の乗算は最後の 2 つの次元でのみ定義し、前の次元は一貫している必要があります。以前の次元がbroadcast_tensorメカニズムに準拠している場合、2つの行列の前の次元が一貫していることを保証するために、次元は自動的に拡張されます。

画像.png
画像.png

9 テンソルによる加算演算

次の 2 つの点に従ってください。

  • 2 つのテンソルの次元が同じ場合、対応する軸の値が同じ (各次元のサイズが等しい) か、一部の次元のサイズが 1 である必要があります。追加する場合は、1 であるすべての軸をコピーして展開し、同じ次元の 2 つのテンソルを取得し、対応する位置を追加します。
  • 追加した 2 つのテンソルの次元が一致しない場合は、まず次元の低いテンソルを右から次元の高いテンソルに位置合わせし、次元を 1 で拡張して高次元のテンソルの次元と一致させます。 <1>の操作を実行します。

画像.png

10 テンソル.expand()

値のコピーを通じて1 つの次元をより大きなサイズに拡張するには、テンソルを拡張しますExpand() 関数を使用しても元のテンソルは変更されないため、結果を再代入する必要があります。以下に具体的な例を示します:
2 次元テンソルを例にします: tensor は 1 n または n の1 次元テンソルです。行内でそれぞれ tensor.expand(s, n) または tensor.expand(n, s) を呼び出します。方向と列方向、展開する方向。
Expand() の fill-in パラメータは size です

画像.png

11 tensor.narrow(dim, start, len)

Narrow() 関数は、特定の次元でデータをフィルタリングする役割を果たします。

torch.narrow(input, dim, start, length)->Tensor

input はスライスする必要があるテンソル、dim はスライスの次元、start は開始インデックス、length はスライスの長さです。実際のアプリケーションは次のとおりです。

画像.png

12 tensor.resize_()

サイズが変更され、リサイズ後の次元にテンソルが切り詰められます。
画像.png

13 tensor.repeat()

tensor.repeat(a,b) は、テンソル全体を a 行方向にコピーし、 b を列方向にコピーします。

画像.png

14 アンバインド()

torch.unbind() は指定された次元を削除した後、指定された次元に沿った各スライスを含むタプルを返します。

torch.unbind(input, dim=0)->seq

画像.png

参考:

pytorch のテンソル次元の変更に関連する関数 (継続的に更新) - weili21 の記事 - Zhihu
https://zhuanlan.zhihu.com/p/438099006

【pytorch tensor テンソル次元変換(テンソル次元変換)】
https://blog.csdn.net/x_yan033/article/details/104965077

おすすめ

転載: blog.csdn.net/Alexa_/article/details/134171971