tensorflow proporciona una herramienta de optimización tensorflow_model_optimization, específicamente para la optimización del modelo de Keras
Principalmente puede realizar poda, cuantificación y agrupamiento de peso.
Los dos primeros se utilizan principalmente aquí
El conjunto de datos utiliza el artículo anterior: modelo mnn de entrenamiento-transformación-predicción
El código de entrenamiento específico es el siguiente
Nota: Debe instalar manualmente tensorflow_model_optimization antes de su uso, solo use pip install tensorflow_model_optimization
import tempfile
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import tensorflow_model_optimization as tfmot
batch_size = 2
img_height = 180
img_width = 180
num_classes = 5
epochs = 50
validation_split=0.2
data_dir='flower_photos'
#数据集准备
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=validation_split,
subset="training",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=validation_split,
subset="validation",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
AUTOTUNE = tf.data.experimental.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
model = keras.Sequential([
keras.layers.InputLayer(input_shape=(img_height, img_height,3)),
keras.layers.Reshape(target_shape=(img_height, img_height, 3)),
layers.Conv2D(16, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(32, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(64, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Dropout(0.2),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(num_classes)
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
print(model.summary())
model.fit(
train_ds,
validation_data=val_ds,
epochs=epochs
)
tf.keras.models.save_model(model, 'baseline_model.h5', include_optimizer=False)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open("baseline_model.tflite", "wb").write(tflite_model)
#开始剪枝
print("start pruning")
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
num_images =3670
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs
pruning_params = {
'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
final_sparsity=0.80,
begin_step=0,
end_step=end_step)
}
model_for_pruning = prune_low_magnitude(model, **pruning_params)
model_for_pruning.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model_for_pruning.summary()
logdir = tempfile.mkdtemp()
callbacks = [
tfmot.sparsity.keras.UpdatePruningStep(),
tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]
model_for_pruning.fit(train_ds,
batch_size=batch_size, epochs=5, validation_data=val_ds,
callbacks=callbacks)
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
#开始量化
print("start quantize")
quantize_model = tfmot.quantization.keras.quantize_model
q_aware_model = quantize_model(model_for_export)
q_aware_model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
q_aware_model.summary()
q_aware_model.fit(train_ds,
batch_size=batch_size, epochs=5, validation_data=val_ds)
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()
quantized_and_pruned_tflite_file ='pruned_and_quantized.tflite'
with open(quantized_and_pruned_tflite_file, 'wb') as f:
f.write(quantized_and_pruned_tflite_model)
Una vez finalizada la operación, echemos un vistazo al archivo del modelo:
Puede ver que el archivo se comprime mucho, aproximadamente 4 veces
Pruébelo a continuación, velocidad de inferencia
El razonamiento del archivo de modelo optimizado es el siguiente
import tensorflow as tf
import cv2
import numpy as np
import time
start=time.time()
image = cv2.imread('397.jpg')
image=cv2.resize(image,(180,180))
image=image[np.newaxis,:,:,:].astype(np.float32)
print(image.shape)
interpreter = tf.lite.Interpreter(model_path='pruned_and_quantized.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
for _ in range(10):
interpreter.set_tensor(input_details[0]['index'],image)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
print('avg infer time is %.6f s'%((time.time()-start)/10.0))
resultado de la operación:
Razonamiento del modelo original:
import tensorflow as tf
import cv2
import numpy as np
import time
start=time.time()
image = cv2.imread('397.jpg')
image=cv2.resize(image,(180,180))
image=image[np.newaxis,:,:,:].astype(np.float32)
print(image.shape)
interpreter = tf.lite.Interpreter(model_path='baseline_model.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
for _ in range(10):
interpreter.set_tensor(input_details[0]['index'],image)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
print('avg infer time is %.6f s'%((time.time()-start)/10.0))
resultado de la operación:
Inesperadamente, aunque el archivo del modelo se ha vuelto más pequeño, la velocidad es realmente lenta y todavía es docenas de veces más lenta.
Aquí está el ubuntu usado y también probado con win 10. La brecha de razonamiento es mayor y más lenta.
De todos modos, es solo una oración, en realidad es lento después de la optimización. . . .
Solo porque no puedo encontrar la herramienta de referencia de tflite, solo puedo usar este método para probar, tal vez el tiempo no sea confiable.