01-Tensorflow2.0トレーニングとFashion-Mnistデータセットの編集

Tensorflow2.0トレーニングとFashion-Mnistデータセットのコンパイルの主な手順

  1. tf.keras.datasets importfashion_mnistデータセット
  2. tf.keras.Sequential()ビルドモデル
  3. model.compile()モデルのコンパイル
  4. model.fitモデルトレーニング、履歴を取得
  5. 履歴曲線グラフを描く
  6. model.evaluate()テストセットでモデルをテストします
#-*- coding:utf-8 _*-
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras
# 输出库的名字和版本
print(sys.version_info)
for module in tf, mpl, np, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)

# 指定GPU
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)

# 导入数据集 fashion_mnist
fashion_mnist = keras.datasets.fashion_mnist
(x_train_all,y_train_all),(x_test_all,y_test_all) = fashion_mnist.load_data()
x_valid , x_train = x_train_all[:5000],x_train_all[5000:]
y_valid , y_train = y_train_all[:5000],y_train_all[5000:]
x_test , y_test = x_test_all,y_test_all

print(x_train.shape,y_train.shape)
print(x_valid.shape,y_valid.shape)
print(x_test.shape,y_test.shape)

# 显示一张图片
def show_single_image(img_arr):
    plt.imshow(img_arr,cmap='binary')
    plt.show()
show_single_image(x_train[0])

# 利用tf.keras.Sequential()创建
# model = tf.keras.Sequential()
# model.add(keras.layers.Flatten(input_shape=[28,28]))
# model.add(keras.layers.Dense(300,activation='relu'))
# model.add(keras.layers.Dense(100,activation='relu'))
# model.add(keras.layers.Dense(10,activation='softmax'))
model = tf.keras.Sequential([
    keras.layers.Flatten(input_shape=[28,28]),
    keras.layers.Dense(300,activation='relu'),
    keras.layers.Dense(100,activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

# 模型编译
model.compile(loss='sparse_categorical_crossentropy',
               optimizer='sgd',
               metrics=['accuracy'])

# 输出模型的一些参数
model.layers
model.summary() #模型参数打印

# 模型训练 history.history是一个重要的参数
history = model.fit(x_train,y_train,
                    epochs=10,
                    validation_data=(x_valid,y_valid))

# 绘制history图像
def plot_learning_curves(history):
    pd.DataFrame(history.history).plot(figsize=(8,5))
    plt.grid(True)
    plt.gca().set_ylim(0,1)
    plt.show()
plot_learning_curves(history)
# 测试集上进行测试
model.evaluate(x_test_scaled,y_test)
  • 最終的な分類精度の画像は次のとおりです。
  • 分類精度が非常に低く、効果が非常に悪いことがわかります。それを改善する方法については、次の記事で説明します。
    最終分類精度画像

おすすめ

転載: blog.csdn.net/qq_44783177/article/details/108081952