tensorflow2.x的模型输入从NHWC格式转换为NCHW格式

tensorflow2.x的模型输入默认格式是NHWC格式,但是某些应用场景下需要在模型中将NHWC输入格式转换为输入NCHW格式,操作代码如下

import tensorflow as tf


model_path = "./xxx.h5"
output_path = "./yyy.h5"
model.load_model(model_path)  # 当前输入尺寸是128*128
new_model = tf.keras.models.Sequential([Input((3, 128, 128)), tf.keras.layers.Lambda(lambda x: tf.transpose(x,[0,2,3,1])), model])
new_model.summary()
new_model.save(output_path)

猜你喜欢

转载自blog.csdn.net/BIT_Legend/article/details/122270071