Calculating model FLOPs under TensorFlow 1.x

Preface

Under TensorFlow 2.x we can use the keras-flops package to compute a model's FLOPs, but keras-flops only supports TensorFlow 2.2 and above. After consulting code found online, I managed to compute model FLOPs under TensorFlow 1.15.0.
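For reference, under TensorFlow 2.2+ the keras-flops route looks roughly like the minimal sketch below (assuming the keras_flops package is installed; this snippet is not part of the original post):

import tensorflow as tf
from keras_flops import get_flops  # pip install keras-flops, requires TensorFlow 2.2+

# Build any Keras model; MobileNetV2 mirrors the TF 1.x example below.
model = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), weights=None)

# get_flops profiles the model; batch_size=1 counts FLOPs for a single sample.
flops = get_flops(model, batch_size=1)
print('GFLOPs: {:.3f}'.format(flops / 1e9))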

Code & running results

1. Code

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
import tensorflow.compat.v1.keras.backend as K


def stats_graph():
    # The following line is required (disables eager execution so the graph can be profiled)
    tf.compat.v1.disable_eager_execution()
    print(tf.__version__)
    sess = tf.compat.v1.Session()
    graph = sess.graph
    flops = tf.compat.v1.profiler.profile(graph, options=tf.compat.v1.profiler.ProfileOptionBuilder.float_operation())
    params = tf.compat.v1.profiler.profile(graph, options=tf.compat.v1.profiler.ProfileOptionBuilder.trainable_variables_parameter())
    print('GFLOPs: {};    Trainable params: {}'.format(flops.total_float_ops/1000000000.0, params.total_parameters))


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.compat.v1.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.compat.v1.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph


def load_pb(pb):
    with tf.io.gfile.GFile(pb, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph


def create_model(num_class):
    # Use the model definition below when measuring FLOPs
    pretrained_model = keras.applications.MobileNetV2(input_tensor=tf.compat.v1.placeholder('float32', shape=(1, 224, 224, 3)),
                                                      include_top=False,  # drop the final classification layer
                                                      weights='imagenet',  # None trains from scratch; 'imagenet' loads pretrained weights
                                                      pooling='avg')
    pretrained_model.trainable = True  # whether to freeze the network (frozen parameters are not updated); True means parameters are updated
    model = keras.Model(inputs=pretrained_model.inputs, outputs=pretrained_model.outputs)
    model.summary()
    return model


def main():
    run_meta = tf.RunMetadata()
    with tf.Session(graph=tf.Graph()) as sess:
        K.set_session(sess)
        model = create_model(5)
        frozen_graph = freeze_session(K.get_session(), output_names=[out.op.name for out in model.outputs])

        with tf.io.gfile.GFile('graph.pb', "wb") as f:
            f.write(frozen_graph.SerializeToString())

        g2 = load_pb('./graph.pb')
        with g2.as_default():
            flops = tf.compat.v1.profiler.profile(g2, options=tf.compat.v1.profiler.ProfileOptionBuilder.float_operation())
            print('FLOP after freezing {} GFLOPS'.format(float(flops.total_float_ops) * 1e-9))

def main2():
    model = create_model(5)
    stats_graph()

if __name__ == "__main__":
    main()

The code above contains two methods, main() and main2(), either of which can compute the model's FLOPs. Note that when defining the model it is best to pass input_tensor rather than input_shape: with input_shape the batch dimension stays undefined (None), so the profiler encounters ops with incomplete shapes and undercounts FLOPs.
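A minimal sketch of the difference (illustrative only, with weights=None to avoid downloading pretrained weights): passing a placeholder with a fully defined shape, batch dimension included, lets the profiler resolve every op's shape.

import tensorflow as tf
import tensorflow.keras as keras

# Profiler-friendly: input_tensor is a placeholder with a fully defined shape, batch size included.
good_input = tf.compat.v1.placeholder('float32', shape=(1, 224, 224, 3))
model = keras.applications.MobileNetV2(input_tensor=good_input, include_top=False,
                                       weights=None, pooling='avg')

# Problematic for profiling: input_shape leaves the batch dimension as None, so many ops
# end up with incomplete shapes and are reported with zero FLOPs by tf.profiler.
# model = keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False,
#                                        weights=None, pooling='avg')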

2. Running results

The running results of main() are as follows:
[Image: main() running results]

The running results of main2() are as follows:
[Image: main2() running results]
That is all for this article. If there are any errors, please point them out in the comments.

Reference links

https://stackoverflow.com/questions/62283556/calculating-flops-of-a-keras-model-returns-ops-with-no-flops-due-to-incomplete-s
https://blog.csdn.net/deephacking/article/details/107873881
