The Functional API

一、The Functional API

tf.keras.Sequential 模型是层的简单堆叠,无法表示任意模型。使用 Keras 函数式 API 可以构建复杂的模型拓扑,例如:

  • 多输入模型
  • 多输出模型
  • 具有共享层的模型(同一层被调用多次),
  • 具有非序列数据流的模型(例如,残差连接)

使用函数式 API 构建的模型具有以下特征:

  • 层实例可调用并返回张量。
  • 输入张量和输出张量用于定义 tf.keras.Model实例。
  • 此模型的训练方式和 Sequential Model一样。

下面的示例使用functional API构建一个简单的全连接网络

def buildComplexModel():
    print("The Functional API")

    # layer实列作用于一个tensor, 并返回一个tensor
    input = tf.keras.Input(shape=(32,))
    print(type(input))  # <class 'tensorflow.python.framework.ops.Tensor'>
    x = layers.Dense(64, activation='relu')(input)
    print(type(x))      # <class 'tensorflow.python.framework.ops.Tensor'>
    x = layers.Dense(64, activation='relu')(x)
    print(type(x))      # <class 'tensorflow.python.framework.ops.Tensor'>
    predictions = layers.Dense(10, activation='softmax')(x)
    print(type(predictions))    # <class 'tensorflow.python.framework.ops.Tensor'>
    print("predictions: ", predictions) # predictions:  Tensor("dense_2/Identity:0", shape=(None, 10), dtype=float32)

    # 构建模型
    model = tf.keras.Model(inputs=input, outputs=predictions)

    # 编译模型
    model.compile(
        optimizer=tf.keras.optimizers.RMSprop(0.001),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    # 训练模型
    # With Numpy arrays
    data = np.random.random((1000, 32))
    labels = np.random.random((1000, 10))
    model.fit(data, labels, batch_size=32, epochs=5)

1.1、简单实现

def TestFunctionalAPI():
    inputs = tf.keras.Input(shape=(784,)) # 784维的向量
    print(inputs.shape, inputs.dtype)

    img_inputs = tf.keras.Input(shape=(32, 32, 3))

    # layer on input
    from tensorflow.keras import layers
    devse = layers.Dense(64, activation='relu')

    # 添加更多的层
    x = devse(inputs) # layer call inputs
    x = layers.Dense(64, activation='relu')(x)
    outputs = layers.Dense(10, activation='softmax')(x)

    # 创建模型
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    print(model.summary())

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 784)]             0         
_________________________________________________________________
dense (Dense)                (None, 64)                50240     
_________________________________________________________________
dense_1 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_2 (Dense)              (None, 10)                650       
=================================================================
Total params: 55,050
Trainable params: 55,050
Non-trainable params: 0
_________________________________________________________________
None

未完待续。。。

发布了784 篇原创文章 · 获赞 90 · 访问量 44万+

猜你喜欢

转载自blog.csdn.net/wuxintdrh/article/details/103577167