01-Tensorflow2.0对于Fashion-Mnist数据集的训练与编译

Tensorflow2.0对于Fashion-Mnist数据集的训练与编译主要步骤

  1. tf.keras.datasets导入fashion_mnist数据集
  2. tf.keras.Sequential()搭建模型
  3. model.compile()模型编译
  4. model.fit模型训练,得到history
  5. 绘制history曲线图
  6. model.evaluate()测试集上测试模型
#-*- coding:utf-8 _*-
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras
# 输出库的名字和版本
print(sys.version_info)
for module in tf, mpl, np, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)

# 指定GPU
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)

# 导入数据集 fashion_mnist
fashion_mnist = keras.datasets.fashion_mnist
(x_train_all,y_train_all),(x_test_all,y_test_all) = fashion_mnist.load_data()
x_valid , x_train = x_train_all[:5000],x_train_all[5000:]
y_valid , y_train = y_train_all[:5000],y_train_all[5000:]
x_test , y_test = x_test_all,y_test_all

print(x_train.shape,y_train.shape)
print(x_valid.shape,y_valid.shape)
print(x_test.shape,y_test.shape)

# 显示一张图片
def show_single_image(img_arr):
    plt.imshow(img_arr,cmap='binary')
    plt.show()
show_single_image(x_train[0])

# 利用tf.keras.Sequential()创建
# model = tf.keras.Sequential()
# model.add(keras.layers.Flatten(input_shape=[28,28]))
# model.add(keras.layers.Dense(300,activation='relu'))
# model.add(keras.layers.Dense(100,activation='relu'))
# model.add(keras.layers.Dense(10,activation='softmax'))
model = tf.keras.Sequential([
    keras.layers.Flatten(input_shape=[28,28]),
    keras.layers.Dense(300,activation='relu'),
    keras.layers.Dense(100,activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

# 模型编译
model.compile(loss='sparse_categorical_crossentropy',
               optimizer='sgd',
               metrics=['accuracy'])

# 输出模型的一些参数
model.layers
model.summary() #模型参数打印

# 模型训练 history.history是一个重要的参数
history = model.fit(x_train,y_train,
                    epochs=10,
                    validation_data=(x_valid,y_valid))

# 绘制history图像
def plot_learning_curves(history):
    pd.DataFrame(history.history).plot(figsize=(8,5))
    plt.grid(True)
    plt.gca().set_ylim(0,1)
    plt.show()
plot_learning_curves(history)
# 测试集上进行测试
model.evaluate(x_test_scaled,y_test)
  • 最后的分类精确度图像如下:
  • 我们可以看到,分类精度非常低,效果很不好,具体怎么改进,之后的文章中会提到
    最后的分类精确度图像

猜你喜欢

转载自blog.csdn.net/qq_44783177/article/details/108081952