En el artículo anterior , el modelo de secuencia simplemente se usa, dado que el modelo de secuencia solo se puede expresar como una simple pila de capas de redes neuronales, no se puede expresar como un modelo arbitrario. Por lo tanto, aquí se utilizarán fórmulas funcionales API
para construir modelos más complejos.
En el artículo sobre clasificación del iris , simplemente se usa el modelo de secuencia para resolver este problema, aquí usaremos la fórmula funcional API
para resolver este problema y experimentar API
la conveniencia de usar la fórmula funcional .
Primero mira el caso oficial:
import tensorflow as tf
inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
El código modificado es el siguiente:
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf
from sklearn.datasets import load_iris
# 训练集和测试集的划分
from sklearn.model_selection import train_test_split
x_data = load_iris().data # 特征,【花萼长度,花萼宽度,花瓣长度,花瓣宽度】
y_data = load_iris().target # 分类
x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.30, random_state=42)
input_data = tf.keras.Input(shape=(len(x_data[0],)))
h_1 = tf.keras.layers.Dense(4, activation="relu")(input_data)
h_2 = tf.keras.layers.Dense(3, activation="softmax")(h_1)
model = tf.keras.Model(inputs=input_data, outputs=h_2)
model.compile(optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy'])
history = model.fit(x_train, y_train, epochs=300)
for key in history.history.keys():
plt.plot(history.epoch, history.history[key])
De hecho, podemos ver que en el caso pequeño anterior, la diferencia con el artículo de clasificación del iris es que la capa de red neuronal se puede especificar arbitrariamente a qué capa está conectada, y luego tf.keras.Input
se agrega la capa.
Por último, vamos a echar un vistazo a la red, primero tenemos que instalar pydot
, graphviz
y luego python-graphviz
.
En Anaconda
la Navigator
búsqueda directa se puede instalar.
Entonces usamos:
tf.keras.utils.plot_model(model, 'model_info.png', show_shapes=True)
tf.keras.utils.plot_model(model, 'mnist_model.png')