clasificación de la imagen tensorflow2.0 de combate --- clasificación de conjunto de datos de la moda-mnist

De hecho, la idea de escribir este blog se describe principalmente algunos de tf2.0 uso de la API común y cómo construir rápida y fácilmente una red neuronal utilizando los tf.keras

1. En primer lugar hablar sobre tf.keras, con la que podemos fácilmente modelos de redes de construcción que quieren construir, luchar como bloques de construcción, capa por capa de red de la pila hacia arriba. Pero la red gradiente de profundidad desaparecerá y así sucesivamente, por lo que sólo será capaz de construir un modelo de red a los resultados del modelo también necesita algo de conocimiento de otras formas de optimizar.

Para introducir los conjuntos de datos de la moda-mnist puede echar un vistazo a los enlaces a continuación
describe la manera en Github-mnist

2. A continuación, la charla sobre la clasificación general de uso general para la optimización de imágenes

  • 1. Imagen de normalización de datos (estandarización): acelerar la convergencia de redes, principios específicos puede imaginar en un gradiente concéntrico a lo largo el más rápido para llegar al centro sin alcance gráfica formal, el centro será giros y vueltas a lo largo del gradiente
    Aquí Insertar imagen Descripción
  • 2. función de datos mejoras: Enlaces
  • 3. Red parámetros de búsqueda súper: obtener los mejores parámetros del modelo, principalmente red de búsqueda, búsqueda aleatoria, algoritmos genéticos, búsqueda heurística
  • 4.dropout de aplicaciones, earlystopping, regularización y otros métodos: mediante la adición de una capa de olvido, regularización y parada temprana para prevenir modelo exceso de ajuste

3. La sección de código de aplicación y los resultados

#先导入一些常用库,后续用到再增加
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import sklearn
import os
import sys

#看一下版本,确认是2.0
print(tf.__version__)

Aquí Insertar imagen Descripción

#使用keras自带的模块导入数据,并且切分训练集、验证集、测试集,对训练数据进行标准化处理
fashion_mnist=keras.datasets.fashion_mnist
(x_train_all,y_train_all),(x_test,y_test)=fashion_mnist.load_data()
print(x_train_all.shape)
print(y_train_all.shape)
print(x_test.shape)
print(y_test.shape)

#切分训练集和验证集
x_train,x_valid=x_train_all[5000:],x_train_all[:5000]
y_train,y_valid=y_train_all[5000:],y_train_all[:5000]

print(x_train.shape)
print(y_train.shape)
print(x_valid.shape)
print(y_valid.shape)


#标准化
from sklearn.preprocessing import StandardScaler

scaler=StandardScaler()
x_train_scaled=scaler.fit_transform(x_train.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)
x_valid_scaled=scaler.fit_transform(x_valid.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)
x_test_scaled=scaler.fit_transform(x_test.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)
#可视化一下图片以及对应的标签
#展示多张图片
def show_imgs(n_rows,n_cols,x_data,y_data,class_names):
    assert len(x_data)==len(y_data)#判断输入数据的信息是否对应一致
    assert n_rows*n_cols<=len(x_data)#保证不会出现数据量不够
    plt.figure(figsize=(n_cols*2,n_rows*1.6))
    for row in range(n_rows):
        for col in range(n_cols):
            index=n_cols*row+col   #得到当前展示图片的下标
            plt.subplot(n_rows,n_cols,index+1)
            plt.imshow(x_data[index],cmap="binary",interpolation="nearest")
            plt.axis("off")
            plt.title(class_names[y_data[index]])
    plt.show()
    
class_names=['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
show_imgs(5,5,x_train,y_train,class_names)

Aquí Insertar imagen Descripción

#搭建网络模型

model=keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=[28,28]))
model.add(keras.layers.Dense(300,activation="relu"))
model.add(keras.layers.Dense(100,activation="relu"))
model.add(keras.layers.Dense(10,activation="softmax"))
model.compile(loss="sparse_categorical_crossentropy",optimizer="adam",metrics=["acc"])
model.summary()

Aquí Insertar imagen Descripción
Aquí params información de la red digital en la forma de hacer?
W = Y X + b y de acuerdo con las reglas de la multiplicación de matrices (Ninguno, 784) a (Ninguna, 300) es una matriz intermedia (784300) y la magnitud del término sesgo b es 300, por lo que 784 300 + 300 = 235.500, esto es un pequeño detalle justo mencionar.

#训练,并且保存最好的模型、训练的记录以及使用早停防止过拟合
import datetime
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join('logs', current_time)
output_model=os.path.join(logdir,"fashionmnist_model.h5")
callbacks=[
    keras.callbacks.TensorBoard(log_dir=logdir),
    keras.callbacks.ModelCheckpoint(output_model,save_best_only=True),
    keras.callbacks.EarlyStopping(patience=5,min_delta=1e-3)
          ]


history=model.fit(x_train_scaled,y_train,epochs=30,validation_data=(x_valid_scaled,y_valid),callbacks=callbacks)

Aquí Insertar imagen Descripción
Antes de que yo era dueño de una carpeta llamada usando TensorBoard y correr ModelCheckpoint a estar mal, me pareció un poco como un insecto en las ventanas, por encima del cual es una solución, y el aspecto tensorboard entonces abierta.
Aquí Insertar imagen Descripción
Aquí Insertar imagen Descripción
El mejor modelo también se guarda como un archivo h5, llamada fácil

def plot_learning_curves(history):
    pd.DataFrame(history.history).plot(figsize=(8,5))
    plt.grid()
    plt.gca().set_ylim(0,1)
    plt.show()

plot_learning_curves(history)

Este es un tiempo para sacar sus propios cambios en la formación, y casi por encima de
Aquí Insertar imagen Descripción

#最后在测试集上的准确率
loss,acc=model.evaluate(x_test_scaled,y_test,verbose=0)
print("在测试集上的损失为:",loss)
print("在测试集上的准确率为:",acc)

Aquí Insertar imagen Descripción

#得到测试集上的预测标签,可视化和真实标签的区别
y_pred=model.predict(x_test_scaled)
predict = np.argmax(y_pred,axis=1) 

show_imgs(3,5,x_test,predict,class_names)
show_imgs(3,5,x_test,y_test,class_names)

Los resultados predicen
Aquí Insertar imagen Descripción
resultados reales
Aquí Insertar imagen Descripción

4. Resumen:

Leer el ejemplo anterior, el uso de tf.keras modelo de construcción está escrito

model=keras.models.Sequential()
model.add(...)
model.add(...)
...
model.compile(...)
model.fit(...)

#当然也可以写成
model=keras.models.Sequential([
	...
	...
	...
])
#这两者差别不大


#还有函数式的写法
inputs=...
hidden1=...(inputs)
....
#子类的写法
class ...:
	...

Sin embargo, para los parámetros del modelo, como la selección ( "sparse_categorical_crossentropy" y "categorical_crossentropy" o "binary_crossentropy") la función de pérdida de qué tipo de pérdida cuando es necesario utilizar la función más adecuada, seleccione la función de activación de cada capa de la red, el optimizador elección ...... necesidad de comprender el significado de que se puede utilizar en las circunstancias apropiadas, que no dio ejemplos aquí utilizar super buscar parámetros óptimos de los parámetros del modelo, el siguiente debe escribir sobre un ejemplo estupendo de los parámetros de la búsqueda.

Publicado 85 artículos originales · ganado elogios 55 · Vistas a 20000 +

Supongo que te gusta

Origin blog.csdn.net/shelgi/article/details/103276140
Recomendado
Clasificación