keras poda-cuantización-razonamiento

tensorflow proporciona una herramienta de optimización tensorflow_model_optimization, específicamente para la optimización del modelo de Keras

Principalmente puede realizar poda, cuantificación y agrupamiento de peso.

Los dos primeros se utilizan principalmente aquí

El conjunto de datos utiliza el artículo anterior: modelo mnn de entrenamiento-transformación-predicción

El código de entrenamiento específico es el siguiente

Nota: Debe instalar manualmente tensorflow_model_optimization antes de su uso, solo use pip install tensorflow_model_optimization

import tempfile
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import tensorflow_model_optimization as tfmot



batch_size = 2
img_height = 180
img_width = 180
num_classes = 5
epochs = 50
validation_split=0.2
data_dir='flower_photos'

#数据集准备
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=validation_split,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)
 
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=validation_split,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

AUTOTUNE = tf.data.experimental.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)


model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(img_height, img_height,3)),
  keras.layers.Reshape(target_shape=(img_height, img_height, 3)),
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Dropout(0.2),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(num_classes)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
 
print(model.summary())
model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)

tf.keras.models.save_model(model, 'baseline_model.h5', include_optimizer=False)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open("baseline_model.tflite", "wb").write(tflite_model)


#开始剪枝
print("start pruning")
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
num_images =3670
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=end_step)
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)
model_for_pruning.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model_for_pruning.summary()
logdir = tempfile.mkdtemp()
callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_ds, 
                  batch_size=batch_size, epochs=5, validation_data=val_ds,
                  callbacks=callbacks)
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
#开始量化
print("start quantize")
quantize_model = tfmot.quantization.keras.quantize_model
q_aware_model = quantize_model(model_for_export)
q_aware_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

q_aware_model.summary()
q_aware_model.fit(train_ds,
                  batch_size=batch_size, epochs=5, validation_data=val_ds)
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()
quantized_and_pruned_tflite_file ='pruned_and_quantized.tflite'
with open(quantized_and_pruned_tflite_file, 'wb') as f:
  f.write(quantized_and_pruned_tflite_model)






Una vez finalizada la operación, echemos un vistazo al archivo del modelo:

Puede ver que el archivo se comprime mucho, aproximadamente 4 veces

Pruébelo a continuación, velocidad de inferencia

El razonamiento del archivo de modelo optimizado es el siguiente

import tensorflow as tf
import cv2
import numpy as np
import time

start=time.time()
image = cv2.imread('397.jpg')
image=cv2.resize(image,(180,180))
image=image[np.newaxis,:,:,:].astype(np.float32)
print(image.shape)
interpreter = tf.lite.Interpreter(model_path='pruned_and_quantized.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
for _ in range(10):
    interpreter.set_tensor(input_details[0]['index'],image)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
print('avg infer time is %.6f s'%((time.time()-start)/10.0))

 resultado de la operación:

Razonamiento del modelo original:

import tensorflow as tf
import cv2
import numpy as np
import time

start=time.time()
image = cv2.imread('397.jpg')
image=cv2.resize(image,(180,180))
image=image[np.newaxis,:,:,:].astype(np.float32)
print(image.shape)
interpreter = tf.lite.Interpreter(model_path='baseline_model.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
for _ in range(10):
    interpreter.set_tensor(input_details[0]['index'],image)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
print('avg infer time is %.6f s'%((time.time()-start)/10.0))

resultado de la operación:

 

Inesperadamente, aunque el archivo del modelo se ha vuelto más pequeño, la velocidad es realmente lenta y todavía es docenas de veces más lenta. 

Aquí está el ubuntu usado y también probado con win 10. La brecha de razonamiento es mayor y más lenta.

De todos modos, es solo una oración, en realidad es lento después de la optimización. . . .

Solo porque no puedo encontrar la herramienta de referencia de tflite, solo puedo usar este método para probar, tal vez el tiempo no sea confiable. 

Supongo que te gusta

Origin blog.csdn.net/zhou_438/article/details/113057903
Recomendado
Clasificación