Today I learned from a forum that, after MNIST, a dataset called Fashion-MNIST appeared, intended to replace the classic MNIST dataset. Like MNIST, it is widely used as the "hello world" program for deep learning. It likewise consists of 70,000 grayscale images of 28 × 28 pixels divided into 10 categories, with 60,000 images used for training and 10,000 for testing. The only difference is that Fashion-MNIST replaces the ten classes of handwritten digits with ten classes of clothing. The ten categories are as follows:
'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
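The integer labels returned by load_data() run from 0 to 9 and index into this list in order, so label 0 is 'T-shirt/top' and label 9 is 'Ankle boot'. Below is a minimal sketch of that lookup. The helper name `label_to_name` is hypothetical and used only for illustration.

```python
# Class names in label order: integer label i corresponds to CLASS_NAMES[i]
CLASS_NAMES = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']


def label_to_name(label):
    """Map an integer Fashion-MNIST label (0-9) to its class name (hypothetical helper)."""
    if not 0 <= label < len(CLASS_NAMES):
        raise ValueError(f"label must be in 0..9, got {label}")
    return CLASS_NAMES[label]


print(label_to_name(0))  # T-shirt/top
print(label_to_name(9))  # Ankle boot
```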
The design process is as follows:
· Load the dataset. Getting Fashion-MNIST in TensorFlow works the same way as getting MNIST: just call keras.datasets.fashion_mnist.load_data()
· Split the dataset into training and test sets
· Since image pixel values range from 0 to 255, preprocess the dataset by scaling the pixel values to the range 0 to 1 (i.e., dividing by 255)
· Build the network model: a fully connected network, 784 → 128 (ReLU) → 10 (softmax)
· Compile the model, specifying the loss function (log loss, here sparse categorical cross-entropy), the optimizer (Adam) and the training metric (accuracy)
· Train the model
· Evaluate the accuracy (and visualize predictions on the test data with matplotlib)
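Below is a NumPy-only sketch of the preprocessing, model and loss steps listed above, built on two stated assumptions. First, the batch is random stand-in data rather than real Fashion-MNIST images. Second, the weights are random rather than trained, so the loss it prints only illustrates the formula.

The sketch shows three things:
· how the arrays change shape through 784 → 128 (ReLU) → 10 (softmax)
· how many trainable parameters the two Dense layers have
· what sparse categorical cross-entropy computes from integer labels

It is not how Keras implements these layers internally.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in batch: 4 random "images" with pixel values 0..255 (NOT real Fashion-MNIST data)
images = rng.integers(0, 256, size=(4, 28, 28)).astype(np.uint8)
labels = np.array([0, 3, 7, 9])

# Preprocessing: scale pixels to [0, 1], then flatten 28x28 -> 784 (the job of Flatten)
x = (images / 255.0).reshape(len(images), -1)

# Randomly initialised (untrained) weights for Dense(128, relu) and Dense(10, softmax)
W1 = rng.normal(0.0, 0.05, size=(784, 128))
b1 = np.zeros(128)
W2 = rng.normal(0.0, 0.05, size=(128, 10))
b2 = np.zeros(10)


def softmax(z):
    # Subtracting the row max avoids overflow and does not change the result
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


hidden = np.maximum(0.0, x @ W1 + b1)  # ReLU layer, shape (4, 128)
probs = softmax(hidden @ W2 + b2)      # softmax layer, shape (4, 10); rows sum to 1

# Sparse categorical cross-entropy: mean of -log(probability assigned to the true class)
loss = -np.mean(np.log(probs[np.arange(len(labels)), labels]))

# Trainable parameters: weights plus biases of both Dense layers
n_params = W1.size + b1.size + W2.size + b2.size

print(x.shape, hidden.shape, probs.shape)  # (4, 784) (4, 128) (4, 10)
print(n_params)                            # 101770
print(loss)
```

The parameter count is (784 × 128 + 128) + (128 × 10 + 10) = 100,480 + 1,290 = 101,770. Since Flatten has no parameters, model.summary() for the Keras model in this post should report the same total.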
For the origins and characteristics of the Adam optimizer, see: https://www.jianshu.com/p/aebcaf8af76e
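The core of Adam works in three steps. It keeps exponential moving averages of the gradient and of the squared gradient, corrects both averages for their bias toward zero in the first steps, and scales each step by their ratio.

Below is a minimal pure-Python sketch on a one-dimensional toy problem, minimizing f(w) = (w - 3)^2. It is not the Keras implementation. The settings are as follows:
· beta1 = 0.9, beta2 = 0.999 and epsilon = 1e-7 match the defaults of Keras' Adam.
· The learning rate of 0.1 is an assumption chosen so that the toy converges quickly. The 'adam' string in the code uses Keras' default learning rate of 0.001.

```python
import math


def adam_minimize(grad, w0, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-7, steps=500):
    """Run plain Adam on a single scalar parameter (toy sketch, not the Keras code)."""
    w, m, v = w0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad(w)
        m = beta1 * m + (1 - beta1) * g      # first moment: moving average of gradients
        v = beta2 * v + (1 - beta2) * g * g  # second moment: moving average of squared gradients
        m_hat = m / (1 - beta1 ** t)         # bias correction, since m and v start at zero
        v_hat = v / (1 - beta2 ** t)
        w -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return w


# The gradient of f(w) = (w - 3)^2 is 2 * (w - 3); the minimum is at w = 3
w_star = adam_minimize(lambda w: 2 * (w - 3), w0=0.0)
print(w_star)  # close to 3.0
```

Thanks to the bias correction, the very first step has size of about lr regardless of the gradient's scale. This per-parameter normalization is one of the properties usually cited in Adam's favor.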
For data visualization with matplotlib, see: https://blog.csdn.net/xHibiki/article/details/84866887
Part of the training set, visualized:
Training ran for 50 epochs in total. The loss and accuracy at the start of training were as follows:
The loss and accuracy at the end of training were as follows:
The model's performance on the test set was as follows:
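The test accuracy reported by model.evaluate is the fraction of test images whose highest-probability class matches the true label. Below is a minimal NumPy sketch of that computation. The probabilities in it are made up for illustration (4 images over 3 classes) and are not results from this model. With the real model you would pass model.predict(test_images) and test_labels instead.

```python
import numpy as np


def accuracy_from_probs(probs, labels):
    """Fraction of rows whose most likely class equals the integer label."""
    predicted = np.argmax(probs, axis=1)  # most likely class for each image
    return float(np.mean(predicted == labels))


# Made-up probabilities for 4 images over 3 classes (each row sums to 1)
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.6, 0.3, 0.1]])
labels = np.array([0, 1, 2, 1])

print(accuracy_from_probs(probs, labels))  # 0.75 (the last image is misclassified)
```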
Predictions for a selected set of test images, visualized:
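In these plots, each test image gets a caption with the predicted class, its confidence and the true class in parentheses. The caption is drawn in blue when the prediction is correct and in red when it is wrong, which is what plot_image in the code does. Below is that caption logic on its own, as a small pure-Python sketch. The helper name `caption_for` is hypothetical.

```python
CLASS_NAMES = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']


def caption_for(probs, true_label):
    """Return (caption, color) for one prediction (hypothetical helper)."""
    predicted = max(range(len(probs)), key=lambda k: probs[k])  # argmax, first max wins
    color = 'blue' if predicted == true_label else 'red'
    caption = "{} {:2.0f}% ({})".format(CLASS_NAMES[predicted],
                                        100 * max(probs),
                                        CLASS_NAMES[true_label])
    return caption, color


# Correct prediction: 95% on 'Ankle boot' (label 9), true label 9
correct_probs = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.05, 0.0, 0.95]
# Wrong prediction: 60% on 'Shirt' (label 6), true label 0 ('T-shirt/top')
wrong_probs = [0.3, 0.0, 0.0, 0.1, 0.0, 0.0, 0.6, 0.0, 0.0, 0.0]

print(caption_for(correct_probs, 9))  # ('Ankle boot 95% (Ankle boot)', 'blue')
print(caption_for(wrong_probs, 0))    # ('Shirt 60% (T-shirt/top)', 'red')
```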
The code is as follows:
```python
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

# Load the Fashion-MNIST dataset
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Clothing categories (integer label i -> class_names[i])
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
               'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
print(train_images.shape, len(train_labels))
print(test_images.shape, len(test_labels))

# Look at one image
plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)
plt.show()

# Preprocess: divide pixel values by 255 to scale them to the range 0 to 1
train_images = train_images / 255.0
test_images = test_images / 255.0

# Check the data format: show the first 25 training images with their class names
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
plt.show()

# Build the network
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

# Set the loss function, optimizer and training metric
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train the model
model.fit(train_images, train_labels, epochs=50)

# Evaluate the model
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print('\nTest accuracy:', test_acc)

# Predict on the test images
predictions = model.predict(test_images)

# Look at the first prediction
print("Predicted:", np.argmax(predictions[0]))
# Print the true label to compare with the prediction
print("Actual:", test_labels[0])

# Plot one image with its predicted class, confidence and true class
def plot_image(i, predictions_array, true_label, img):
    true_label, img = true_label[i], img[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])

    plt.imshow(img, cmap=plt.cm.binary)

    predicted_label = np.argmax(predictions_array)
    # Blue caption for a correct prediction, red for a wrong one
    if predicted_label == true_label:
        color = 'blue'
    else:
        color = 'red'

    plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
                                         100 * np.max(predictions_array),
                                         class_names[true_label]),
               color=color)

# Bar chart of the predicted probabilities for all ten classes
def plot_value_array(i, predictions_array, true_label):
    true_label = true_label[i]
    plt.grid(False)
    plt.xticks(range(10))
    plt.yticks([])
    thisplot = plt.bar(range(10), predictions_array, color="#777777")
    plt.ylim([0, 1])
    predicted_label = np.argmax(predictions_array)

    # Predicted class in red, true class in blue (blue wins if they coincide)
    thisplot[predicted_label].set_color('red')
    thisplot[true_label].set_color('blue')

i = 10
plt.figure(figsize=(6, 3))
plt.subplot(1, 2, 1)
plot_image(i, predictions[i], test_labels, test_images)
plt.subplot(1, 2, 2)
plot_value_array(i, predictions[i], test_labels)
plt.show()
```