How to overcome the non-differentiability of argmax in neural network training?


Authors | Zhenyue Qin, hoooz, OwlLite

Source | Zhihu Q&A

Address | https://www.zhihu.com/question/460500204

Original question

Recently I have been using PyTorch for NLP style transfer. While training with a GAN, I found that my seq2seq generator outputs a tensor of shape (batch size, max length, vocab length), where the last dimension holds the post-softmax probability of each word in the vocabulary.

Following the GAN setup, I want to feed the generator's output into the discriminator to get a tensor of shape (batch size, num labels), and then compute a cross-entropy loss against the ground-truth labels. However, the discriminator expects an input tensor of shape (batch size, max length).

If I convert the seq2seq output with out.argmax(-1), loss.backward() can no longer propagate gradients through the network. Is there a good way to handle this?

01

Answer 1: Author - Zhenyue Qin

There is a technique called the straight-through Gumbel (estimator), you can take a look~

The general idea: suppose the input vector is v. We first compute softmax(v), so the maximum entry becomes very close to 1 and the other entries become very close to 0. We then compute the one-hot result of argmax(v) and form the constant c = argmax(v) - softmax(v). Now softmax(v) + c can be used in place of argmax(v): its forward value equals the hard argmax result, yet it can still be backpropagated through. In other words, we use the gradient of softmax(v) during backpropagation.
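A minimal PyTorch sketch of this trick is below; the function name and toy shapes are assumptions for illustration, not from the answer. Note that PyTorch also provides torch.nn.functional.gumbel_softmax(logits, hard=True), which implements the straight-through Gumbel-Softmax estimator directly.

import torch
import torch.nn.functional as F

def straight_through_argmax(logits):
    soft = F.softmax(logits, dim=-1)                       # differentiable surrogate
    idx = soft.argmax(dim=-1, keepdim=True)
    hard = torch.zeros_like(soft).scatter_(-1, idx, 1.0)   # one-hot of argmax
    # forward value: hard one-hot; backward gradient: that of softmax(v)
    return hard - soft.detach() + soft

logits = torch.randn(4, 10, requires_grad=True)
out = straight_through_argmax(logits)
out.sum().backward()        # gradients flow via the softmax path
print(logits.grad.shape)    # torch.Size([4, 10])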

If anything is not clear, comments are welcome. Thank you.

PS Thanks to Chunchuan Lu and Towser for their corrections to the original answer.

02

Answer 2: Author - hoooz

Option 1: add a stop-gradient operation; see VQ-VAE and the corresponding PyTorch implementation [1][2].

(Figure: the straight-through gradient in VQ-VAE; the red line marks the path along which the gradient is copied past the codebook/dictionary module.)

One-sentence explanation: the forward pass is the same as usual; during backpropagation, the gradient at the non-differentiable point is copied to the nearest differentiable point before it.

(See the red line: the gradient at its right end skips the dictionary module in the middle and goes straight to its left end.)

~

This raises two questions:

1/ How do we cut off the gradient chain so that it does not pass through the dictionary module? PyTorch provides detach(), which cuts off the gradient, so it never enters the non-differentiable region and autograd does not raise an error.

2/ How do we copy the gradient? Take the simplest example:

quantize = input + (quantize - input).detach()
# The forward pass is the same as usual.
# In the backward pass the detach()-ed term has zero gradient, so quantize and input get the same gradient,
# i.e. the gradient of quantize is copied to input.
# quantize is the right end of the red line, input is the left end.
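A minimal runnable check of this trick, with hypothetical toy tensors (rounding stands in for the non-differentiable codebook lookup):

import torch

inp = torch.randn(4, 8, requires_grad=True)
quantize = torch.round(inp * 10) / 10        # stands in for a non-differentiable lookup
quantize = inp + (quantize - inp).detach()   # forward: quantized values; backward: gradient of inp
quantize.sum().backward()
print(inp.grad)                              # all ones: the gradient passed straight through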

References:

  • [1]. Neural Discrete Representation Learning

  • [2]. https://github.com/rosinality/v

03

Answer 3: Author - OwlLite

The non-differentiable argmax/argmin op can simply be ignored, i.e. treated as the identity in the backward pass:

import torch

class ArgMax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Forward: hard one-hot of the argmax along dim 1.
        idx = torch.argmax(input, dim=1, keepdim=True)
        output = torch.zeros_like(input)
        output.scatter_(1, idx, 1.0)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Backward: pass the gradient through unchanged (identity).
        return grad_output
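A quick usage sketch (the tensor shapes here are assumptions): custom autograd Functions are invoked through .apply, and the gradient w.r.t. the input comes back unchanged.

x = torch.randn(4, 10, requires_grad=True)
y = ArgMax.apply(x)   # forward: one-hot argmax; backward: identity
y.sum().backward()
print(x.grad)         # all ones, passed straight through the argmax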


Origin: blog.csdn.net/qq_15698613/article/details/121586581