MNIST recognition optimization: training a model on the new Fashion-MNIST dataset

Today I happened to learn from a forum that, after MNIST, a dataset called Fashion-MNIST has appeared, intended to replace the classic MNIST dataset. Like MNIST, it is used as the "hello world" of deep learning programs. It also consists of 70,000 grayscale images of 28 × 28 pixels divided into 10 categories, with 60,000 used for training and 10,000 for testing. The only difference is that Fashion-MNIST replaces the ten classes of handwritten digits with ten classes of clothing items. The ten categories are:

'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
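In the dataset itself each label is stored as an integer from 0 to 9, and the class names above are listed in label order (0 is 'T-shirt/top', 9 is 'Ankle boot'). A minimal sketch of decoding integer labels into names; the sample labels are made up for illustration and are not taken from the real dataset:

```python
# Class names in label order: index i is the name of integer label i.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

def decode_labels(labels):
    """Map integer labels (0-9) to their Fashion-MNIST class names."""
    return [class_names[int(label)] for label in labels]

sample_labels = [9, 0, 3, 7]  # hypothetical labels for illustration only
print(decode_labels(sample_labels))
# prints ['Ankle boot', 'T-shirt/top', 'Dress', 'Sneaker']
```

The plotting code later in the post performs the same lookup with class_names[train_labels[i]].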

The overall workflow is as follows:

  · First, load the dataset. TensorFlow loads Fashion-MNIST the same way as MNIST: simply call keras.datasets.fashion_mnist.load_data()

  · Split the data into training and test sets (load_data() already returns them as two separate tuples)

  · Preprocess the data: pixel values range from 0 to 255, so scale them to the range 0 to 1 by dividing by 255

  · Build the network model, a fully connected network: 784 (the flattened 28 × 28 image) → 128 (ReLU) → 10 (softmax)

  · Compile the model, specifying the loss function (logarithmic loss, i.e. sparse categorical cross-entropy), the optimizer (Adam) and the training metric (accuracy)

  · Train the model

  · Evaluate the accuracy on the test set and visualize predictions on the test data with matplotlib
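The steps above can be sketched without TensorFlow to show what each one computes. The NumPy code below is an illustration under stated assumptions, not the Keras implementation: the "images" are random uint8 arrays standing in for Fashion-MNIST, and the weights are small random placeholders rather than trained (or Keras-initialized) values. It scales and flattens the images, runs the 784 → 128 (ReLU) → 10 (softmax) forward pass, computes the log loss against integer labels, and counts the trainable parameters.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical mini-batch standing in for Fashion-MNIST: uint8 pixels in 0-255.
images = rng.integers(0, 256, size=(4, 28, 28)).astype(np.uint8)
labels = np.array([9, 0, 0, 3])  # hypothetical integer labels

# Preprocess: scale pixels to [0, 1] (dividing a uint8 array by 255.0 gives floats).
x = images / 255.0

# Flatten each 28 x 28 image into a 784-dimensional vector.
x = x.reshape(len(x), -1)

# Placeholder weights for the two fully connected layers (not trained values).
W1, b1 = rng.normal(0.0, 0.01, (784, 128)), np.zeros(128)
W2, b2 = rng.normal(0.0, 0.01, (128, 10)), np.zeros(10)

def softmax(z):
    """Row-wise softmax; subtracting the row max avoids overflow in exp."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

hidden = np.maximum(0.0, x @ W1 + b1)  # 784 -> 128 with ReLU
probs = softmax(hidden @ W2 + b2)      # 128 -> 10 class probabilities

# Sparse categorical cross-entropy: mean of -log(probability of the true class).
loss = -np.mean(np.log(probs[np.arange(len(labels)), labels]))

# Trainable parameters = weights + biases of both dense layers.
n_params = W1.size + b1.size + W2.size + b2.size
print(probs.shape, loss, n_params)
```

Because the random weights are close to zero, every class gets a probability near 0.1 and the loss comes out near ln 10 ≈ 2.30, which is the loss of a classifier that assigns equal probability to all ten classes. The parameter count, 784 × 128 + 128 + 128 × 10 + 10 = 101,770, should match the total that model.summary() reports for the Keras model in the code.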

 

For the origins and characteristics of the Adam optimizer, see: https://www.jianshu.com/p/aebcaf8af76e
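As a rough illustration of the idea, here is a plain-Python sketch of the standard Adam update (exponential moving averages of the gradient and of its square, with bias correction) minimizing a toy one-dimensional function. The toy function, the learning rate of 0.1 and the step count are assumptions chosen for the demo; Keras' 'adam' optimizer applies the same kind of update to every weight, with a default learning rate of 0.001.

```python
import math

def adam_minimize(grad, x0, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, steps=500):
    """Minimize a 1-D function with Adam, given a function returning its gradient."""
    x, m, v = x0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad(x)
        m = beta1 * m + (1 - beta1) * g      # moving average of the gradient
        v = beta2 * v + (1 - beta2) * g * g  # moving average of the squared gradient
        m_hat = m / (1 - beta1 ** t)         # bias correction (m and v start at 0)
        v_hat = v / (1 - beta2 ** t)
        x -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return x

# Toy problem (assumption): f(x) = (x - 3)^2, gradient 2(x - 3), minimum at x = 3.
x_star = adam_minimize(lambda x: 2 * (x - 3), x0=0.0)
print(x_star)
```

Because each step is scaled by m̂ / √v̂, the very first step moves x by almost exactly the learning rate, whatever the magnitude of the gradient.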

For data visualization with matplotlib, see: https://blog.csdn.net/xHibiki/article/details/84866887

 

Visualization of part of the training set:

 

Training ran for a total of 50 epochs. The loss and accuracy at the start of training were as follows:
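For a sense of scale: an epoch is one full pass over the 60,000 training images. The call model.fit(train_images, train_labels, epochs=50) passes no batch_size, so Keras uses its default of 32, and the number of weight updates works out as follows:

```python
import math

# Values from the post: 60,000 training images and 50 epochs.
# batch_size=32 is Keras' default for model.fit when none is given.
num_train, batch_size, epochs = 60_000, 32, 50

steps_per_epoch = math.ceil(num_train / batch_size)  # a final partial batch would still be one step
total_updates = steps_per_epoch * epochs
print(steps_per_epoch, total_updates)  # prints 1875 93750
```

With TF 2.x the progress bar counts these steps, so each epoch should show 1875/1875.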

 

The loss and accuracy at the end of training were as follows:

 

The model's performance on the test set was as follows:
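The test accuracy reported by model.evaluate is the fraction of test images whose highest-probability class (the argmax of the softmax output) equals the true label. A NumPy sketch of that computation on a hypothetical 4-image prediction array; the probabilities are invented for illustration, whereas real ones would come from model.predict(test_images):

```python
import numpy as np

# Hypothetical softmax outputs for 4 test images over 10 classes (rows sum to 1).
predictions = np.full((4, 10), 0.02)
predictions[0, 9] = 0.82  # image 0: most confident in class 9
predictions[1, 2] = 0.82  # image 1: class 2
predictions[2, 1] = 0.82  # image 2: class 1
predictions[3, 6] = 0.82  # image 3: class 6
test_labels = np.array([9, 2, 1, 4])  # hypothetical true labels

predicted_labels = np.argmax(predictions, axis=1)    # most probable class per image
accuracy = np.mean(predicted_labels == test_labels)  # fraction of correct predictions
print(predicted_labels, accuracy)  # prints [9 2 1 6] 0.75
```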

 

Visualization of the predictions for selected test images:

 

The code is as follows:

import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

# Load the Fashion-MNIST dataset
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Clothing categories (list index = integer label)
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
               'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
print(train_images.shape, len(train_labels))
print(test_images.shape, len(test_labels))

# View an image
plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)
plt.show()

# Preprocess: divide pixel values by 255 to scale them to the range 0 to 1
train_images = train_images / 255.0
test_images = test_images / 255.0

# Check the data format: show the first 25 training images with their class names
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
plt.show()

# Build the network
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

# Set the loss function, optimizer and training metric
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train the model
model.fit(train_images, train_labels, epochs=50)

# Evaluate the model
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print('\nTest accuracy:', test_acc)

# Predict on the images in the test set
predictions = model.predict(test_images)

# Inspect the first prediction
print("Predicted label:", np.argmax(predictions[0]))
# Print the true label to compare with the prediction
print("True label:", test_labels[0])

# Plot the prediction over all ten classes graphically
def plot_image(i, predictions_array, true_label, img):
    true_label, img = true_label[i], img[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])

    plt.imshow(img, cmap=plt.cm.binary)

    # Blue caption if the prediction is correct, red otherwise
    predicted_label = np.argmax(predictions_array)
    if predicted_label == true_label:
        color = 'blue'
    else:
        color = 'red'

    # Caption: predicted class, confidence in percent, (true class)
    plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
                                         100 * np.max(predictions_array),
                                         class_names[true_label]),
               color=color)

def plot_value_array(i, predictions_array, true_label):
    true_label = true_label[i]
    plt.grid(False)
    plt.xticks(range(10))
    plt.yticks([])
    thisplot = plt.bar(range(10), predictions_array, color="#777777")
    plt.ylim([0, 1])
    predicted_label = np.argmax(predictions_array)

    # Predicted class in red; true class in blue (blue wins when they match)
    thisplot[predicted_label].set_color('red')
    thisplot[true_label].set_color('blue')

i = 10
plt.figure(figsize=(6, 3))
plt.subplot(1, 2, 1)
plot_image(i, predictions[i], test_labels, test_images)
plt.subplot(1, 2, 2)
plot_value_array(i, predictions[i], test_labels)
plt.show()

Original post: www.cnblogs.com/zdm-code/p/12198403.html