ニューラルネットワークトレーニングでargmaxの非微分を克服する方法は?

上記の人工知能アルゴリズムとPythonビッグデータをクリックして、より多くのドライグッズを入手してください

右上  星に設定★、初めてリソースをゲット

学術的な共有の場合のみ、侵害がある場合は、削除するために連絡してください

転載元:著者| Zhenyue Qin、hoooz、OwlLite

ソース|質問と回答を知る

住所|https://www.zhihu.com/question/460500204

元の質問

最近、トーチを使用してnlpスタイルの変換を行っていました。ganを使用して学習したところ、seq2seqの出力は(バッチサイズ、最大長、語彙長)の形のテンソルであり、最後の次元はソフトマックス確率の後の辞書の各単語の出現。

GANの原理に従って、ジェネレーターの出力を入力として受け取り、ディサイダーの出力の形状(バッチサイズ、ラベル番号)のテンソルを取得したいと思います。次に、標準ラベルとクロスエントロピーを実行して損失を取得します。ただし、ディサイダーの入力テンソルは(バッチサイズ、最大長)形状である必要があります。

ここで、seq2seqからの出力がout.argmax(-1)によって処理される場合、loss.backward()はネットワークに勾配を生成できません。ここに良い解決策があるかどうか皆さんにお聞きしたいと思います。

01

回答1:著者-Zhenyue Qin

ガンベル(推定量)を介してひずみと呼ばれるものがあります、あなたは見てみることができます〜

一般的な考え方は次のとおりです。入力ベクトルがvであるとすると、softmaxを使用してsoftmax(v)を取得します。このようにして、最大値は1に非常に近くなり、他の場所は0に非常に近くなります。 argmax(v)を計算すると、定数c = argmax(v)-softmax(v)が得られます。このとき、argmax(v)の結果としてsoftmax(v)+cを使用できます。これの利点つまり、softmax(v)+ cは逆伝播が可能です。つまり、softmax(v)の勾配を逆伝播として使用します。

不明な点がございましたら、コメントをお待ちしております。ありがとうございます。

PS元の回答を修正してくれたChunchuanLuとTowserに感謝します。

02

回答2:著者-hoooz

オプション1:ストップグラジエント操作を追加します。VQVAEおよび対応するpytorchの実装を参照してください[1] [2]

972264003f5256ff81b298080c93c299.png

一文の説明:順方向の伝播は通常と同じです。逆方向に伝播するときは、操縦不可能なポイントから操縦不可能なポイントの前にある最も近い派生可能なポイントに勾配をコピーします。

(赤い線の右端のグラデーションを参照し、中央の辞書モジュールをスキップして、赤い線の左端にまっすぐ進みます)

ここに問題があります

1 /グラデーションチェーンを切断して、辞書モジュールを通過しないようにするにはどうすればよいですか?Pytorchにはdetach()があり、グラデーションをカットオフできるため、グラデーションが非導電性領域に入り、コンパイラがエラーを報告することはありません。

2 /グラデーションを複製する方法は?最も簡単な例を見てください

quantize = input + (quantize - input).detach()
# 正向传播和往常一样,
# 反向传播时,detach()这部分梯度为0,quantize和input的梯度相同,
# 即实现将quantize复制给input
# quantize即红线右端点,input即红线左端点

参照する:

  • [1]。神経離散表現学習

  • [2]。https://github.com/rosinality/v

03

回答3:作成者-OwlLite

argmax / argminの微分不可能な演算は直接無視できます。つまり、次のようにロックします。

class ArgMax(torch.autograd.Function):
	@staticmethod
	def forward(ctx, input):
        idx = torch.argmax(input, 1)
        output = torch.zeros_like(input)
        output.scatter_(1, idx, 1)
        return output
	
	@staticmethod
	def backward(ctx, grad_output):
        return grad_output

---------♥---------

声明:このコンテンツはインターネットからのものであり、著作権は元の作者に帰属します

写真はインターネットから提供されたものであり、この公式アカウントの位置を表すものではありません。侵害がある場合は、削除するために連絡してください

AI博士のプライベートWeChat、まだいくつかの欠員があります

4da468c799ea2353c748269db7979b6f.png

f4df524d7667bd89beb9059ed83e2dde.gif

美しい深層学習モデル図を描く方法は?

美しいニューラルネットワーク図を描く方法は?

ディープラーニングのさまざまな畳み込みを理解するための1つの記事

クリックしてサポートを表示95a69ceea53001d204663efd0685a161.png6af4172ff2e9b92e14ee32d003151b59.png

おすすめ

転載: blog.csdn.net/qq_15698613/article/details/121586581