Understanding target data for softmax output layer

Mick :

I found some example code for an MNIST handwritten digit classification problem. The start of the code is as follows:

import tensorflow as tf

# Load in the data
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
print("x_train.shape:", x_train.shape)

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])
# Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# Train the model
r = model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=10)

Looking at the code, the output layer of the network consists of ten nodes. If the network were working perfectly after training, the appropriate one of the ten outputs would have an activation very close to one and the rest would have activations very close to zero.
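
Here is a minimal NumPy sketch of softmax that makes the "one output near one, the rest near zero" picture concrete. The logit values are made up for illustration. Keras applies the same mathematical function inside `Dense(10, activation='softmax')`, but this is not its code. Ten inputs go in, ten probabilities come out, and they sum to one.

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax: n values in, n probabilities out."""
    # Subtracting the row max improves numerical stability without changing the result.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# Hypothetical raw outputs (logits) of the last Dense layer for a single image.
logits = np.array([[0.5, -1.2, 8.0, 0.1, -0.3, 0.0, 1.1, -2.0, 0.7, 0.2]])
probs = softmax(logits)
print(probs.shape)            # (1, 10): one activation per output node
print(probs.sum())            # 1, up to floating-point rounding
print(probs.argmax(axis=-1))  # [2]: node 2 has an activation close to one
```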

I knew that the training set contained 60000 example patterns, so I assumed that the target output data (y_train) would be a 2D NumPy array with a shape of (60000, 10). To double-check, I executed print(y_train.shape) and was very surprised to see (60000,). Normally you would expect each target pattern to have the same size as the number of nodes in the output layer. I thought to myself, "OK, softmax must be an unusual special case where we only need one target value"... My next thought was: how could I have known this from any documentation? So far I have failed to find anything.
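
The two layouts hold the same information. The following NumPy sketch uses a few made-up integer labels as a stand-in for y_train and expands them into the (n, 10) one-hot layout the question expected. Keras also provides `tf.keras.utils.to_categorical(y, num_classes)` for this conversion. The `np.eye` indexing trick here is just one common way to do it.

```python
import numpy as np

NUM_CLASSES = 10

# Hypothetical stand-in for y_train: one integer class label (0-9) per image.
y_int = np.array([5, 0, 4, 1, 9])
print(y_int.shape)  # (5,), analogous to (60000,) for the full training set

# Each label selects a row of the identity matrix, giving a one-hot row per example.
y_onehot = np.eye(NUM_CLASSES, dtype=np.float32)[y_int]
print(y_onehot.shape)  # (5, 10), analogous to the expected (60000, 10)
print(y_onehot[0])     # 1.0 at index 5, 0.0 everywhere else

# The conversion loses nothing: argmax recovers the integer labels.
assert np.array_equal(y_onehot.argmax(axis=1), y_int)
```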

ddoGas :

I think you were searching in the wrong direction. It's not because of the softmax: the softmax function (not a layer) takes n values and produces n values. It's because of the sparse_categorical_crossentropy loss.
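
A single integer per example is enough because the sparse loss only uses the label as an index into the n softmax outputs. It picks out the probability of the correct class and takes its negative log. The NumPy sketch below assumes already-normalized probabilities. The function name and the clipping epsilon are choices made for this illustration. Keras's own implementation differs in details; for example, it can also work directly from logits via `from_logits=True`.

```python
import numpy as np

def sparse_categorical_crossentropy(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Per-example loss from integer labels y_true (shape (n,)) and
    softmax probabilities y_pred (shape (n, num_classes))."""
    # Clip to avoid log(0); the 1e-7 epsilon is an assumption for this sketch.
    eps = 1e-7
    p = np.clip(y_pred, eps, 1.0 - eps)
    # Use each integer label as a column index to select the correct-class probability.
    p_correct = p[np.arange(len(y_true)), y_true]
    return -np.log(p_correct)

# Hypothetical softmax outputs for 2 examples over 3 classes.
y_pred = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.1, 0.8]])
y_true = np.array([0, 2])  # one integer label per example
loss = sparse_categorical_crossentropy(y_true, y_pred)
print(loss)  # approx [0.357 0.223], i.e. -log(0.7) and -log(0.8)
```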

The official documentation shows that this loss expects the targets as integer class labels. It also describes an otherwise identical loss, CategoricalCrossentropy, that expects one-hot targets, i.e. the (60000, 10) shape you were expecting.
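
You can check numerically that the two losses differ only in the target format. This sketch reimplements both in NumPy and compares them on made-up predictions. The helper names are hypothetical and these are not the Keras functions. Multiplying by a one-hot row zeroes out every term except the log-probability that the integer label would index directly, so the results match.

```python
import numpy as np

def categorical_crossentropy(y_true_onehot: np.ndarray, y_pred: np.ndarray, eps: float = 1e-7) -> np.ndarray:
    """Loss from one-hot targets of shape (n, num_classes)."""
    p = np.clip(y_pred, eps, 1.0 - eps)
    return -np.sum(y_true_onehot * np.log(p), axis=-1)

def sparse_categorical_crossentropy(y_true_int: np.ndarray, y_pred: np.ndarray, eps: float = 1e-7) -> np.ndarray:
    """Loss from integer targets of shape (n,)."""
    p = np.clip(y_pred, eps, 1.0 - eps)
    return -np.log(p[np.arange(len(y_true_int)), y_true_int])

# Hypothetical softmax outputs for 2 examples over 3 classes.
y_pred = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.1, 0.8]])
y_int = np.array([0, 2])      # sparse format, shape (2,)
y_onehot = np.eye(3)[y_int]   # one-hot format, shape (2, 3)

cce = categorical_crossentropy(y_onehot, y_pred)
scce = sparse_categorical_crossentropy(y_int, y_pred)
print(np.allclose(cce, scce))  # True: same loss, different target layout
```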

You choose which loss to use depending on the format of your label data. Since the MNIST labels are stored as integers rather than one-hot vectors, the tutorial uses the SparseCategoricalCrossentropy loss.
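
The rule of thumb above amounts to a decision based on the label array's shape. `choose_loss` is a hypothetical helper written for this illustration and is not part of Keras. The strings it returns are loss identifiers that `model.compile(loss=...)` accepts.

```python
import numpy as np

def choose_loss(y: np.ndarray, num_classes: int) -> str:
    """Hypothetical helper: pick a Keras loss identifier from the label layout."""
    if y.ndim == 1:
        # One integer class index per example, e.g. MNIST's y_train of shape (60000,).
        return "sparse_categorical_crossentropy"
    if y.ndim == 2 and y.shape[1] == num_classes:
        # One one-hot row per example, e.g. shape (60000, 10).
        return "categorical_crossentropy"
    raise ValueError(f"unexpected label shape {y.shape} for {num_classes} classes")

y_int = np.array([5, 0, 4])
y_onehot = np.eye(10)[y_int]
print(choose_loss(y_int, 10))     # sparse_categorical_crossentropy
print(choose_loss(y_onehot, 10))  # categorical_crossentropy
```

With MNIST loaded as in the question, `choose_loss(y_train, 10)` would return the sparse variant, which matches the tutorial. Converting the labels with `tf.keras.utils.to_categorical(y_train, 10)` and compiling with `'categorical_crossentropy'` instead should give equivalent training. The helper deliberately rejects any other label layout rather than guessing.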
