Derivation of the softmax cross-entropy loss function for multi-class classification problems

Binary classification and multi-class classification

The derivation of the cross-entropy loss function for multi-class classification follows the same idea as the binary-classification case, so you may want to read the derivation of the cross-entropy loss function for the binary classification problem first.

This article follows the Deep Learning book, Section 6.2.2.3, Softmax Units for Multinoulli Output Distributions.

Alright, let’s get down to business~

The first thing to be clear about is that models for multi-class classification and for binary classification differ noticeably in what the last layer outputs directly.

Assuming the output layer of the model is a fully connected layer, we can sketch what the last layer outputs directly for the binary and multi-class cases, as shown in the figure below (the multi-class problem in the figure is a three-class problem).

Each blue circle in the figure represents a real value. For binary classification, only one real value needs to be output; it is then transformed by the sigmoid function into the probability that the model predicts class 1.
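As a tiny illustration of the binary case, here is a NumPy sketch (the values of h, w, and b are made up for the example):

```python
import numpy as np

h = np.array([0.5, -1.2, 0.3])           # hidden-layer activations (toy values)
w = np.array([0.8, 0.1, -0.4])           # weights of the single output unit
b = 0.2                                  # bias of the single output unit

logit = w @ h + b                        # the one real value the last layer outputs
p_class1 = 1.0 / (1.0 + np.exp(-logit))  # sigmoid: probability that the label is 1
print(p_class1)
```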

For an N-class problem, N output values are required, and the i-th output value, after the softmax transformation, represents the probability that the model predicts class i.

Why the softmax output can be interpreted as a probability

The following derives why the outputs, after the softmax transformation, can be interpreted as probabilities. Let's derive it now~

Step 1: Assume the three output blue circles (generalizing to N) form a vector \mathbf{z}. It can be written as \mathbf{z}=\mathbf{W}^{T}\mathbf{h}+\mathbf{b}, where \mathbf{h} is the output of the previous (hidden) layer.
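In code, Step 1 is just an affine map from the hidden layer to N raw scores. A minimal NumPy sketch, with shapes made up for a three-class problem and a four-dimensional \mathbf{h}:

```python
import numpy as np

rng = np.random.default_rng(0)
h = rng.normal(size=4)        # hidden-layer output h, shape (4,)
W = rng.normal(size=(4, 3))   # weight matrix W: one column per class
b = np.zeros(3)               # bias b: one entry per class

z = W.T @ h + b               # z = W^T h + b, shape (3,): one raw score per class
print(z)
```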

Step 2: Construct \tilde{P}(y) (note that this \tilde{P}(y) is not yet a true probability distribution). Let z_{i}=\log \tilde{P}(y=i|\mathbf{x}), that is, \tilde{P}(y=i|\mathbf{x})=\exp (z_{i}).

Why \tilde{P}(y=i|\mathbf{x})=\exp (z_{i}) and not simply \tilde{P}(y=i|\mathbf{x})=z_{i}? The reason for using \exp (z_{i}) is that the cross-entropy loss wraps the probability in a \log, and the \log undoes the \exp, so the loss does not easily saturate when computing gradients (the book's word is "saturate"). This is a gradient-computation concern; it is not essential for understanding the derivation itself.
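To see concretely why the loss does not saturate (a standard calculation, added here for reference), take the loss -\log softmax(\mathbf{z})_{i} derived below and differentiate it with respect to any z_{k}:

\frac{\partial }{\partial z_{k}}\left ( -\log softmax(\mathbf{z})_{i} \right )=softmax(\mathbf{z})_{k}-1_{\{k=i\}}

The gradient only shrinks to zero when the predicted distribution already matches the label, so exponentiating inside and taking the \log outside keeps the gradients usable.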

Step 3: Obtain the probability distribution. Exponentiating and normalizing gives:

P(y=i|\mathbf{x})=\frac{\exp(z_{i})}{\sum _{j}\exp(z_{j})}

Step 4: At this point, each z_{i} has been converted into the probability that the model predicts class i. This transformation is the softmax function, which is the familiar formula:

softmax(\mathbf{z})_{i}=\frac{\exp(z_{i})}{\sum _{j}\exp(z_{j})}

Step 5: Summary. Exponentiating and normalizing the z_{i} gives the probability that the model predicts each class.
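A minimal NumPy version of the softmax just derived. The max-subtraction inside is the usual numerical-stability trick and does not change the result, since it cancels in the ratio:

```python
import numpy as np

def softmax(z):
    """Exponentiate and normalize raw scores z into a probability vector."""
    z = np.asarray(z, dtype=float)
    exp_z = np.exp(z - z.max())   # subtracting the max avoids overflow; the ratio is unchanged
    return exp_z / exp_z.sum()

z = np.array([2.0, 1.0, 0.1])     # toy raw scores for three classes
p = softmax(z)
print(p, p.sum())                 # probabilities, summing to 1
```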

Deriving the cross-entropy loss function after the softmax transformation

After the softmax transformation, the cross-entropy loss function is easy to obtain.

The cross-entropy loss function is -\log(P(y)).

(Here y is the true label of the sample. Suppose there are classes 1, 2, 3, corresponding to z_{1}, z_{2}, z_{3}. If the true label is the class corresponding to z_{2}, then P(y) is P(y=2|\mathbf{x}).)

Substituting softmax(\mathbf{z})_{i}=\frac{\exp(z_{i})}{\sum _{j}\exp(z_{j})} into -\log(P(y)) gives the cross-entropy loss function:

-\log softmax(\mathbf{z})_{i}=-(z_{i}-\log \sum _{j}\exp(z_{j}))

If the true class corresponds to the output value z_{2} (assuming only classes 1, 2, 3, corresponding to z_{1}, z_{2}, z_{3}), then the loss is -\log softmax(\mathbf{z})_{2}=-(z_{2}-\log \sum _{j}\exp(z_{j})).
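Putting this together as code, here is a small NumPy sketch of -\log softmax(\mathbf{z})_{i} computed directly in the \log \sum _{j}\exp(z_{j})-z_{i} form (note that indices are 0-based in code, so the class corresponding to z_{2} is index 1):

```python
import numpy as np

def cross_entropy_loss(z, true_class):
    """-log softmax(z)[true_class], computed as log-sum-exp(z) - z[true_class]."""
    z = np.asarray(z, dtype=float)
    z_max = z.max()                                   # shift for numerical stability
    log_sum_exp = z_max + np.log(np.exp(z - z_max).sum())
    return log_sum_exp - z[true_class]

z = np.array([1.5, 0.3, -0.8])                # raw scores z_1, z_2, z_3
loss = cross_entropy_loss(z, true_class=1)    # true label is the class of z_2 -> index 1
print(loss)                                   # equals -log softmax(z)[1]
```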

At this point, the derivation is complete. If anything is wrong, please leave a message~

Original post: blog.csdn.net/qq_32103261/article/details/116590825