Estructura básica de la red Tensorflow

El contenedor de red
encapsula múltiples capas de red en un gran modelo de red a través del contenedor de red Sequential proporcionado por Keras, y solo necesita llamar a la instancia del modelo de red una vez para completar la operación de propagación secuencial de datos desde la primera capa hasta la última capa.
(1) El contenedor secuencial se encapsula como una red:

import tensorflow as tf
from tensorflow.keras import  layers,Sequential
model = Sequential([#封装一个网络
    layers.Dense(3,activation=None),#全连接层,不使用激活函数
    layers.ReLU(),#激活函数层
    layers.Dense(2,activation=None),
    layers.ReLU()
])
x = tf.random.normal([4,3])
out = model(x)
out
'''
<tf.Tensor: shape=(4, 2), dtype=float32, numpy=
array([[0.02129143, 0.        ],
       [0.00915629, 0.        ],
       [0.        , 0.9278006 ],
       [1.190076  , 0.        ]], dtype=float3
'''

(2) El contenedor secuencial pasa el método add():

layers_num = 3#堆叠三层
model = Sequential([])#创建空间的网络容器
for _ in range(layers_num):
    model.add(layers.Dense(3))#添加全连接层
    model.add(layers.ReLU())#添加激活层
model.build(input_shape=(4,4))#创建网络参数
model.summary()#打印出网络结构和参数量
'''
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_9 (Dense)              (4, 3)                    15        
_________________________________________________________________
re_lu_9 (ReLU)               (4, 3)                    0         
_________________________________________________________________
dense_10 (Dense)             (4, 3)                    12        
_________________________________________________________________
re_lu_10 (ReLU)              (4, 3)                    0         
_________________________________________________________________
dense_11 (Dense)             (4, 3)                    12        
_________________________________________________________________
re_lu_11 (ReLU)              (4, 3)                    0         
=================================================================
Total params: 39
Trainable params: 39
Non-trainable params: 0
_________________________________________________________________
'''

(3) La lista de tensores a optimizar y la lista de todos los tensores de todas las capas:

for p in model.trainable_variables:
    print(p.name,p.shape)
'''
dense_9/kernel:0 (4, 3)
dense_9/bias:0 (3,)
dense_10/kernel:0 (3, 3)
dense_10/bias:0 (3,)
dense_11/kernel:0 (3, 3)
dense_11/bias:0 (3,)
'''

La
clase de capa del modelo de configuración es la clase principal de la capa de red, que define algunas funciones comunes de la capa de red, como agregar pesos y administrar listas de pesos.
La clase Model es la clase principal de la red. Además de las funciones de la clase Layer, también agrega funciones convenientes como guardar modelos, cargar modelos, entrenar y probar modelos. Sequential también es una subclase de Model, por lo que tiene todas las funciones de la clase Model.

model = Sequential([#封装一个网络
    layers.Dense(256,activation='relu'),
    layers.Dense(128,activation='relu'),
    layers.Dense(56,activation='relu'),
    layers.Dense(28,activation='relu'),
    layers.Dense(10),
])
model.build(input_shape=(4,28,28))
model.summary()
#compile()函数中指定的优化器、损失函数等参数也是自行训练时需要设置的参数
from tensorflow.keras import  optimizers,losses
from tensorflow import  metrics
model.compile(optimizer=optimizers.Adam(lr=0.01),
             loss = losses.CategoricalCrossentropy(from_logits=True),
             metrics=['accuracy'])#设置测量指标为准确率

Entrenamiento del modelo
Envíe el conjunto de datos a entrenar y el conjunto de datos para su verificación a través de la función fit(), este paso se llama entrenamiento del modelo

#epochs参数指定训练迭代的Epoch数量;validation_data参数指定用于验证(测试)的数据集和验证的频率validation_freq

history = model.fit(train_x,epochs=10,validation_data=train_val,validation_freq=5)

La función fit() devolverá el historial de registro de datos del proceso de entrenamiento, donde history.history es un objeto de diccionario, que contiene elementos de registro como indicadores de pérdida y medición durante el proceso de entrenamiento.

history.history()

Prueba del modelo
La predicción del modelo se puede completar a través del método model.predict(x).

#模型预测
model.predict(x_test)
#测试模型性能
model.evaluate(x_test)

Guarde el modelo
(1) El estado de la
red de tensores se refleja principalmente en la estructura de la red y los datos de tensores dentro de la capa de red. Por lo tanto, bajo la condición de tener el archivo fuente de la estructura de la red, es más fácil directamente guardar los parámetros del tensor de red en el sistema de archivos.Una forma de magnitud.

model.save_weights('weights.ckpt')#保存模型的所有张量数据

Cuando sea necesario, primero cree el objeto de red y luego llame al método load_weights(path) del objeto de red para escribir los valores de tensor guardados en el archivo de modelo especificado en los parámetros de red actuales.

model.load_weights('weights.ckpt')

Este método de guardar y cargar la red es el más ligero.Solo los valores de los parámetros del tensor se guardan en el archivo, y no hay otros parámetros estructurales adicionales. Pero necesita usar la misma estructura de red para restaurar el estado de la red correctamente, por lo que generalmente se usa cuando hay archivos fuente de red.

(2) Método de red Este
método no requiere archivos fuente de red, pero solo necesita archivos de parámetros del modelo para restaurar el modelo de red. La estructura del modelo y los parámetros del modelo se pueden guardar en el archivo de ruta a través de la función Model.save(ruta), y la estructura de la red y los parámetros de la red se pueden restaurar a través de keras.models.load_model(ruta) sin necesidad de archivos fuente de red.

#保存模型结构与模型参数到文件
model.save('model.h5')
#从文件恢复网络结构与网络参数
model = tf.keras.model.load_model('model.h5')

Además de guardar los parámetros del modelo, el archivo model.h5 también debe guardar la información de la estructura de la red, y el objeto del modelo de red se puede restaurar directamente desde el archivo sin crear un modelo por adelantado.

(3) Método SavedModel
A través de tf.saved_model.save(network, path), el modelo se puede guardar en el directorio de la ruta en el método SavedModel.

#保存模型结构与模型参数到文件
tf.saved_model.save(model,'model_savedmodel')

El usuario no necesita preocuparse por el formato de guardado del archivo y solo necesita restaurar el objeto modelo a través de la función tf.saved_model.load.

#从文件恢复网络结构与网络参数
model = tf.saved_model.load('model_savedmodel')

Supongo que te gusta

Origin blog.csdn.net/weixin_56260304/article/details/128279218
Recomendado
Clasificación