In the previous article, we simply used the Sequential model. A Sequential model can only express a plain linear stack of layers, so it cannot represent arbitrary architectures. Here we therefore use the functional API to build more complex models.
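Here is a minimal sketch of a topology that a Sequential stack cannot express. The input tensor feeds a hidden layer and also, through a skip connection, the final classifier. The layer sizes (4 inputs, 8 hidden units, 3 classes) are illustrative choices, not taken from the original article.

```python
import tensorflow as tf

# Illustrative sizes (assumed): 4 input features, 8 hidden units, 3 classes.
inputs = tf.keras.Input(shape=(4,))
hidden = tf.keras.layers.Dense(8, activation="relu")(inputs)

# Skip connection: the raw input is concatenated with the hidden output.
# The same tensor (inputs) feeds two different layers, which a plain
# Sequential stack cannot describe.
merged = tf.keras.layers.Concatenate()([inputs, hidden])  # shape (None, 12)
outputs = tf.keras.layers.Dense(3, activation="softmax")(merged)

model = tf.keras.Model(inputs=inputs, outputs=outputs)
print(model.count_params())  # (4*8 + 8) + (12*3 + 3) = 79
```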
In the iris classification article, we solved this problem with a Sequential model. Here we solve the same problem with the functional API to get a feel for how convenient it is.
First, look at the official example:
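import tensorflow as tf
# declare the input: each sample is a vector of 3 values
inputs = tf.keras.Input(shape=(3,))
# each layer is called on the previous tensor, which wires up the graph
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
# the Model is defined by its input and output tensors
model = tf.keras.Model(inputs=inputs, outputs=outputs)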
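To see what this model computes, the sketch below calls it on a small batch. The two random input rows are made-up data. Because the last layer is a softmax, each output row is a probability distribution over the 5 classes.

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# A batch of 2 random samples with 3 features each (made-up data).
batch = np.random.rand(2, 3).astype("float32")

# Calling the model runs the forward pass through the wired-up layers.
probs = np.asarray(model(batch))

print(probs.shape)           # (2, 5): 5 class probabilities per sample
print(probs.sum(axis=1))     # each row sums to approximately 1 (softmax)
print(model.count_params())  # (3*4 + 4) + (4*5 + 5) = 41
```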
Adapting it to the iris data gives the following code:
import matplotlib.pyplot as plt
# Jupyter/IPython magic; remove this line when running as a plain .py script
%matplotlib inline
import tensorflow as tf
from sklearn.datasets import load_iris
# used to split the data into training and test sets
from sklearn.model_selection import train_test_split
iris = load_iris()
x_data = iris.data    # features: [sepal length, sepal width, petal length, petal width]
y_data = iris.target  # class labels (0, 1, 2)
x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.30, random_state=42)
# the Input layer declares the shape of one sample: a vector of 4 features
input_data = tf.keras.Input(shape=(len(x_data[0]),))
h_1 = tf.keras.layers.Dense(4, activation="relu")(input_data)
h_2 = tf.keras.layers.Dense(3, activation="softmax")(h_1)
model = tf.keras.Model(inputs=input_data, outputs=h_2)
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])
history = model.fit(x_train, y_train, epochs=300)
# plot every recorded metric (loss and accuracy) against the epoch number
for key in history.history.keys():
    plt.plot(history.epoch, history.history[key], label=key)
plt.legend()
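The code above splits off a test set but never uses it. Here is a minimal sketch of checking the trained model on that held-out 30%. It rebuilds the same model so it runs on its own, and uses `verbose=0` only to keep the output short. Accuracy varies from run to run because the weights are randomly initialized.

```python
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
x_train, x_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.30, random_state=42)

inputs = tf.keras.Input(shape=(x_train.shape[1],))
h = tf.keras.layers.Dense(4, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(3, activation="softmax")(h)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.sparse_categorical_crossentropy,
              metrics=["accuracy"])
model.fit(x_train, y_train, epochs=300, verbose=0)

# With one compiled metric, evaluate() returns [loss, accuracy].
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
print(f"test loss {test_loss:.3f}, test accuracy {test_acc:.3f}")

# predict() returns class probabilities; argmax picks the predicted label.
pred = model.predict(x_test, verbose=0).argmax(axis=1)
```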
Compared with the iris classification article, this small example differs in only two ways. We explicitly specify which layer each layer is connected to, by calling each layer on the previous layer's output. We also add a tf.keras.Input layer at the start.
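As a check that only the wiring style differs, the sketch below builds the same two Dense layers both ways and compares their parameter counts. It assumes the earlier Sequential article used this same 4 → 4 → 3 layout. That article's code is not reproduced here.

```python
import tensorflow as tf

n_features = 4  # iris has 4 features

# Sequential version (assumed to match the earlier iris article's layers)
seq = tf.keras.Sequential([
    tf.keras.Input(shape=(n_features,)),
    tf.keras.layers.Dense(4, activation="relu"),
    tf.keras.layers.Dense(3, activation="softmax"),
])

# Functional version: the same layers, with every connection written out
inp = tf.keras.Input(shape=(n_features,))
h = tf.keras.layers.Dense(4, activation="relu")(inp)
out = tf.keras.layers.Dense(3, activation="softmax")(h)
func = tf.keras.Model(inputs=inp, outputs=out)

# (4*4 + 4) + (4*3 + 3) = 35 parameters in both cases
print(seq.count_params(), func.count_params())
```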
Finally, let's visualize the network. This first requires installing pydot, graphviz and python-graphviz. In Anaconda Navigator you can search for each package and install it directly.
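Before calling `plot_model`, it can help to confirm that its prerequisites are visible from Python. This is a minimal sketch. `plot_prereqs_available` is a hypothetical helper, not part of TensorFlow. It only checks that the `pydot` package can be imported and that Graphviz's `dot` executable is on the PATH.

```python
import importlib.util
import shutil


def plot_prereqs_available():
    """Hypothetical helper: report whether plot_model's dependencies are present."""
    return {
        # the pydot Python package, which builds the graph description
        "pydot": importlib.util.find_spec("pydot") is not None,
        # Graphviz's 'dot' executable, which renders the image file
        "graphviz_dot": shutil.which("dot") is not None,
    }


print(plot_prereqs_available())
```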
Then call:
# draw the graph, including each layer's input/output shapes
tf.keras.utils.plot_model(model, 'model_info.png', show_shapes=True)
# draw the graph with layer names only
tf.keras.utils.plot_model(model, 'mnist_model.png')
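If installing Graphviz is inconvenient, `model.summary()` gives a text view of the same network without pydot or Graphviz. The sketch below rebuilds the iris model and collects the summary through `print_fn` instead of printing it. The lambda accepts extra arguments because some Keras versions pass keyword arguments to `print_fn`.

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(4,))
h = tf.keras.layers.Dense(4, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(3, activation="softmax")(h)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# summary() needs no extra packages; print_fn receives the text it would print
lines = []
model.summary(print_fn=lambda line, *args, **kwargs: lines.append(line))
report = "\n".join(lines)
print(report)  # layer table, output shapes and parameter totals
```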