Inicio rápido con la API funcional tf.keras

En el artículo anterior , el modelo de secuencia simplemente se usa, dado que el modelo de secuencia solo se puede expresar como una simple pila de capas de redes neuronales, no se puede expresar como un modelo arbitrario. Por lo tanto, aquí se utilizarán fórmulas funcionales APIpara construir modelos más complejos.
En el artículo sobre clasificación del iris , simplemente se usa el modelo de secuencia para resolver este problema, aquí usaremos la fórmula funcional APIpara resolver este problema y experimentar APIla conveniencia de usar la fórmula funcional .
Primero mira el caso oficial:

import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

El código modificado es el siguiente:

import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf
from sklearn.datasets import load_iris
# 训练集和测试集的划分
from sklearn.model_selection import train_test_split

x_data = load_iris().data  # 特征,【花萼长度,花萼宽度,花瓣长度,花瓣宽度】
y_data = load_iris().target # 分类
x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.30, random_state=42)

input_data = tf.keras.Input(shape=(len(x_data[0],)))
h_1 = tf.keras.layers.Dense(4, activation="relu")(input_data)
h_2 = tf.keras.layers.Dense(3, activation="softmax")(h_1)

model = tf.keras.Model(inputs=input_data, outputs=h_2)
      
model.compile(optimizer=tf.keras.optimizers.Adam(), 
             loss=tf.keras.losses.sparse_categorical_crossentropy, 
             metrics=['accuracy'])

history = model.fit(x_train, y_train, epochs=300)


for key in history.history.keys():
    plt.plot(history.epoch, history.history[key])

Inserte la descripción de la imagen aquí

De hecho, podemos ver que en el caso pequeño anterior, la diferencia con el artículo de clasificación del iris es que la capa de red neuronal se puede especificar arbitrariamente a qué capa está conectada, y luego tf.keras.Inputse agrega la capa.

Por último, vamos a echar un vistazo a la red, primero tenemos que instalar pydot, graphvizy luego python-graphviz.
En Anacondala Navigatorbúsqueda directa se puede instalar.
Inserte la descripción de la imagen aquí

Entonces usamos:

tf.keras.utils.plot_model(model, 'model_info.png', show_shapes=True)

Inserte la descripción de la imagen aquí

tf.keras.utils.plot_model(model, 'mnist_model.png')

Inserte la descripción de la imagen aquí

Supongo que te gusta

Origin blog.csdn.net/qq_26460841/article/details/113550713
Recomendado
Clasificación