Computing model FLOPs in TensorFlow 1.x

Preface

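With TensorFlow 2.x, the keras-flops package can compute a model's FLOPs, but it only supports TensorFlow 2.2 and later. After going through code posted online, I got FLOPs calculation working for a model under TensorFlow 1.15.0.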

Code & results

1. Code

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
import tensorflow.compat.v1.keras.backend as K


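def stats_graph():
    # Profile the default graph directly (no freezing): prints GFLOPs and trainable parameter count.
    # The following line is required.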
    tf.compat.v1.disable_eager_execution()
    print(tf.__version__)
    sess = tf.compat.v1.Session()
    graph = sess.graph
    flops = tf.compat.v1.profiler.profile(graph, options=tf.compat.v1.profiler.ProfileOptionBuilder.float_operation())
    params = tf.compat.v1.profiler.profile(graph, options=tf.compat.v1.profiler.ProfileOptionBuilder.trainable_variables_parameter())
    print('GFLOPs: {};    Trainable params: {}'.format(flops.total_float_ops/1000000000.0, params.total_parameters))


# Convert the variables in the session's graph to constants and return the frozen GraphDef.
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.compat.v1.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.compat.v1.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph


# Load a serialized GraphDef from a .pb file into a new tf.Graph.
def load_pb(pb):
    with tf.io.gfile.GFile(pb, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph


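def create_model(num_class):
    # Use this model definition when measuring FLOPs (num_class is unused here).
    # A placeholder with a fixed batch size of 1 serves as the model input.
    pretrained_model = keras.applications.MobileNetV2(input_tensor=tf.compat.v1.placeholder('float32', shape=(1, 224, 224, 3)),
                                                      include_top=False,  # drop the final layer of the network
                                                      weights='imagenet',  # None trains from scratch; 'imagenet' starts from pretrained weights
                                                      pooling='avg')
    pretrained_model.trainable = True  # True: weights are updated during training; False freezes the network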
    model = keras.Model(inputs=pretrained_model.inputs, outputs=pretrained_model.outputs)
    model.summary()
    return model


def main():
    # Build the model inside a fresh graph so that only this model's ops are frozen and profiled.
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
        K.set_session(sess)
        model = create_model(5)
        frozen_graph = freeze_session(K.get_session(), output_names=[out.op.name for out in model.outputs])

        with tf.io.gfile.GFile('graph.pb', "wb") as f:
            f.write(frozen_graph.SerializeToString())

        g2 = load_pb('./graph.pb')
        with g2.as_default():
            flops = tf.compat.v1.profiler.profile(g2, options=tf.compat.v1.profiler.ProfileOptionBuilder.float_operation())
            print('FLOP after freezing {} GFLOPS'.format(float(flops.total_float_ops) * 1e-9))

def main2():
    model = create_model(5)
    stats_graph()

if __name__ == "__main__":
    main()

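The code above contains two approaches, main() (called main1 in the results section) and main2(), and both compute the model's FLOPs. Note that when defining the model it is better to pass input_tensor than input_shape: with input_shape the batch dimension stays undefined, and the profiler cannot count FLOPs for ops whose shapes are incomplete (see the first reference link).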
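To illustrate why a fully defined input shape matters, here is a minimal pure-Python sketch. It is not the TensorFlow profiler's implementation; the helper name `matmul_flops` and the 1280-feature, 5-unit layer are hypothetical and chosen only for illustration. It assumes one multiply-add is counted as two floating-point operations.

```python
from typing import Optional, Sequence


def matmul_flops(input_shape: Sequence[Optional[int]], units: int) -> Optional[int]:
    """Estimate FLOPs of x @ W, with x of shape (batch, features) and W of shape (features, units).

    Returns None when any input dimension is unknown, because the count
    cannot be derived from the static shape alone.
    """
    if any(dim is None for dim in input_shape):
        return None
    batch, features = input_shape
    # One multiply and one add for every (batch, feature, unit) combination.
    return 2 * batch * features * units


# Static shape with an unknown batch dimension, as input_shape would leave it.
print(matmul_flops((None, 1280), 5))
# Static shape with a fixed batch of 1, as a placeholder passed via input_tensor gives.
print(matmul_flops((1, 1280), 5))
```

The first call returns `None` and the second returns 12800, which mirrors the difference between an op the profiler has to skip and one it can count.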

2. Results

The output of main1 (main() in the code above) is shown below:
[Figure: main1 output]

The output of main2 is shown below:
[Figure: main2 output]
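As a rough sanity check on a profiler total, the cost of a single convolution can be counted by hand. The sketch below is an assumption-laden illustration, not part of the original code: the helper name `conv2d_flops` is hypothetical, a multiply-add is counted as two floating-point operations, and the example layer is taken to be MobileNetV2's first convolution (3x3 kernel, 3 input channels, 32 filters, 112x112 output for a 224x224 input).

```python
def conv2d_flops(out_h, out_w, kernel_h, kernel_w, in_channels, out_channels, batch=1):
    """Estimate FLOPs of a standard 2D convolution from its output and kernel sizes."""
    # Each output element needs kernel_h * kernel_w * in_channels multiply-adds.
    macs = batch * out_h * out_w * out_channels * kernel_h * kernel_w * in_channels
    # Count every multiply-add as two floating-point operations.
    return 2 * macs


first_conv = conv2d_flops(out_h=112, out_w=112, kernel_h=3, kernel_w=3,
                          in_channels=3, out_channels=32)
print(first_conv)  # 21676032
print('{:.4f} GFLOPs'.format(first_conv / 1e9))
```

Summing such per-layer estimates should land in the same range as the profiler's total; exact agreement is not expected, since the profiler also counts other op types.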
That is all for this post. If anything is incorrect, please point it out in the comments.

References

https://stackoverflow.com/questions/62283556/calculating-flops-of-a-keras-model-returns-ops-with-no-flops-due-to-incomplete-s
https://blog.csdn.net/deephacking/article/details/107873881

Reposted from blog.csdn.net/orangeboss/article/details/131932597