Quick start with the tf.keras functional API

In the previous article, we simply used a Sequential model. A Sequential model can only express a plain stack of layers, one after another, so it cannot represent arbitrary architectures. Here we use the functional API instead, which can build more complex models.
In the iris classification article, we solved the problem with a Sequential model. Here we solve the same problem with the functional API to get a feel for how convenient it is.
First, look at the official example:
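To make the contrast concrete, here is a minimal sketch that builds the same 4 -> 4 -> 3 iris network both ways. The four input features come from the iris data; the variable names seq_model and func_model are just illustrative. For a plain stack like this the two styles are equivalent; the functional style only pays off once layers stop forming a single chain.

```python
import tensorflow as tf

# Sequential version: layers are stacked one after another.
seq_model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(4, activation="relu"),
    tf.keras.layers.Dense(3, activation="softmax"),
])

# Functional version: each layer is called on the tensor it should consume.
inputs = tf.keras.Input(shape=(4,))
hidden = tf.keras.layers.Dense(4, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(3, activation="softmax")(hidden)
func_model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Same architecture, so the parameter counts match:
# (4*4 + 4) + (4*3 + 3) = 20 + 15 = 35
print(seq_model.count_params(), func_model.count_params())  # 35 35
```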

import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
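A functional model is itself callable, so a quick way to check the official example is to feed it a small batch and look at the output. This sketch uses two random samples (made-up data, for illustration only) with 3 features each, matching shape=(3,):

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Two random samples with 3 features each (illustrative data only).
batch = np.random.rand(2, 3).astype("float32")
probs = model(batch).numpy()  # call the model like a function

print(probs.shape)            # (2, 5): 5 class probabilities per sample
print(probs.sum(axis=1))      # each row sums to ~1.0 because of the softmax
print(model.count_params())   # 41 = (3*4 + 4) + (4*5 + 5)
```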

The modified code is as follows:

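import matplotlib.pyplot as plt
# Jupyter magic: delete the next line when running as a plain .py script
%matplotlib inline
import tensorflow as tf
from sklearn.datasets import load_iris
# Split into training and test sets
from sklearn.model_selection import train_test_split

iris = load_iris()
x_data = iris.data    # features: [sepal length, sepal width, petal length, petal width]
y_data = iris.target  # class labels 0, 1, 2
x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.30, random_state=42)

# shape must be a tuple: (4,) for the four iris features
input_data = tf.keras.Input(shape=(len(x_data[0]),))
h_1 = tf.keras.layers.Dense(4, activation="relu")(input_data)  # hidden layer, fed by the input
h_2 = tf.keras.layers.Dense(3, activation="softmax")(h_1)      # output layer, fed by h_1

model = tf.keras.Model(inputs=input_data, outputs=h_2)

# Labels are integers, so use the sparse version of categorical cross-entropy
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])

history = model.fit(x_train, y_train, epochs=300)

# Plot every recorded metric (loss and accuracy) against the epoch number
for key in history.history.keys():
    plt.plot(history.epoch, history.history[key], label=key)
plt.xlabel("epoch")
plt.legend()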
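The code above sets aside a 30% test split but never uses it. A minimal sketch of scoring the model on that split, assuming the same architecture and split as above; it trains for only 50 epochs (instead of 300) to stay quick, so its accuracy will differ from the full run:

```python
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
x_train, x_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.30, random_state=42)

inputs = tf.keras.Input(shape=(x_train.shape[1],))
hidden = tf.keras.layers.Dense(4, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(3, activation="softmax")(hidden)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# Shortened training run, only to keep this sketch fast.
model.fit(x_train, y_train, epochs=50, verbose=0)

# evaluate() returns [loss, accuracy] in the order configured in compile().
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
print(f"test loss={test_loss:.3f}, test accuracy={test_acc:.3f}")

# Predicted class = index of the largest softmax probability.
pred = model.predict(x_test, verbose=0).argmax(axis=1)
```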


As this small example shows, the difference from the iris classification article is that each layer is explicitly called on the layer it should take its input from, so we decide how the layers are connected, and the model starts with a tf.keras.Input layer.
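Because each layer is called on whatever tensor we choose, the graph does not have to be a single chain. The sketch below adds a concatenation-style skip connection: the raw input is joined with the output of the second hidden layer before the final layer. The layer sizes (8 units) are hypothetical and chosen only for illustration; a Sequential model cannot express this wiring.

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(4,))
h1 = tf.keras.layers.Dense(8, activation="relu")(inputs)
h2 = tf.keras.layers.Dense(8, activation="relu")(h1)

# Non-sequential step: merge the raw input with h2 (a skip connection).
merged = tf.keras.layers.Concatenate()([inputs, h2])  # last dimension: 4 + 8 = 12
outputs = tf.keras.layers.Dense(3, activation="softmax")(merged)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

probs = model(np.zeros((1, 4), dtype="float32")).numpy()
# Parameters: (4*8 + 8) + (8*8 + 8) + (12*3 + 3) = 40 + 72 + 39 = 151
print(probs.shape, model.count_params())
```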

Finally, let's visualize the network. This requires pydot, graphviz and python-graphviz; in Anaconda, you can search for and install all three directly from the Navigator.
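Before calling plot_model it can help to check that both pieces are present: the pydot Python package and Graphviz's dot executable on the PATH. This is a minimal sketch; the helper name plot_model_ready is hypothetical, and it assumes your Keras version imports pydot (some versions also accept substitutes such as pydotplus, which this check does not look for).

```python
import importlib.util
import shutil

def plot_model_ready():
    """Report which prerequisites of plot_model appear to be available (hypothetical helper)."""
    return {
        # pydot: the Python package Keras uses to build the graph description
        "pydot": importlib.util.find_spec("pydot") is not None,
        # dot: the Graphviz executable that actually renders the image
        "graphviz_dot": shutil.which("dot") is not None,
    }

status = plot_model_ready()
print(status)
if not all(status.values()):
    print("Install the missing pieces before calling tf.keras.utils.plot_model.")
```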

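Then call plot_model. With show_shapes=True the diagram also shows each layer's input and output shapes; without it, only the layer names and types are drawn: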

tf.keras.utils.plot_model(model, 'model_info.png', show_shapes=True)


tf.keras.utils.plot_model(model, 'mnist_model.png')
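If Graphviz cannot be installed, model.summary() gives a text view of the same network. This sketch rebuilds the iris model and captures the summary into a string through print_fn; the collect helper is illustrative and accepts extra arguments because different Keras versions call print_fn slightly differently (per line in older tf.keras, once with a line_break keyword in Keras 3).

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(4,))
h_1 = tf.keras.layers.Dense(4, activation="relu")(inputs)
h_2 = tf.keras.layers.Dense(3, activation="softmax")(h_1)
model = tf.keras.Model(inputs=inputs, outputs=h_2)

lines = []

def collect(*args, **kwargs):
    # Keep only the text argument; ignore any keyword options Keras passes.
    lines.append(str(args[0]))

model.summary(print_fn=collect)
summary_text = "\n".join(lines)
print(summary_text)  # layer table plus the total parameter count (35)
```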



Source: blog.csdn.net/qq_26460841/article/details/113550713