Tensorflow1.x でモデルの FLOP を計算する

ここにカスタム ディレクトリのタイトルを書き込みます

前に書く

Tensorflow2.x を使用する場合、keras-flops パッケージを使用してモデルの FLOP 計算を実装できますが、keras- flops パッケージは Tensorflow2 .2 以降のバージョンのみをサポートします。インターネット上のコードを参照した結果、Tensorflow 1.15.0 で計算モデルの FLOP を正常に実装できました。

コードと実行結果

1. コード

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
import tensorflow.compat.v1.keras.backend as K


def stats_graph():
    # 必须要下面这行代码
    tf.compat.v1.disable_eager_execution()
    print(tf.__version__)
    sess = tf.compat.v1.Session()
    graph = sess.graph
    flops = tf.compat.v1.profiler.profile(graph, options=tf.compat.v1.profiler.ProfileOptionBuilder.float_operation())
    params = tf.compat.v1.profiler.profile(graph, options=tf.compat.v1.profiler.ProfileOptionBuilder.trainable_variables_parameter())
    print('GFLOPs: {};    Trainable params: {}'.format(flops.total_float_ops/1000000000.0, params.total_parameters))


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.compat.v1.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.compat.v1.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph


def load_pb(pb):
    with tf.io.gfile.GFile(pb, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph


def create_model(num_class):
    #测试FLOPs时使用下面的模型定义方式
    pretrained_model = keras.applications.MobileNetV2(input_tensor = tf.compat.v1.placeholder('float32', shape=(1, 224, 224, 3)),
                                                      include_top=False,  # 去掉网络结构的最后一层
                                                      weights='imagenet',          # 参数为None从头开始训练,为'imagenet'从已训练好的模型开始训练
                                                      pooling='avg')
    pretrained_model.trainable = True  # 是否冻结网络(冻结网络参数不更新), 为trainable=True时,参数更新      
    model = keras.Model(inputs=pretrained_model.inputs, outputs=pretrained_model.outputs)
    model.summary()
    return model


def main():
    run_meta = tf.RunMetadata()
    with tf.Session(graph=tf.Graph()) as sess:
        K.set_session(sess)
        model =create_model(5)
        frozen_graph = freeze_session(K.get_session(), output_names=[out.op.name for out in model.outputs])

        with tf.io.gfile.GFile('graph.pb', "wb") as f:
            f.write(frozen_graph.SerializeToString())

        g2 = load_pb('./graph.pb')
        with g2.as_default():
            flops = tf.compat.v1.profiler.profile(g2, options=tf.compat.v1.profiler.ProfileOptionBuilder.float_operation())
            print('FLOP after freezing {} GFLOPS'.format(float(flops.total_float_ops) * 1e-9))

def main2():
    model = create_model(5)
    stats_graph()

if __name__ == "__main__":
    main()

上記のコードには 2 つのメソッドが含まれており、どちらもモデルの FLOP の計算を実現できます。 モデルを定義するときは、input_shape ではなく input_tensor を使用するのが最善であることに注意してください。

2. 走行結果

main1 に対応する実行結果は次のとおりです。
main1の実行結果

main2 の対応する実行結果は次のとおりです:main2の実行結果
上記がこの記事の全内容です。間違いがある場合は、コメント エリアで批判して修正してください。

参考リンク

https://stackoverflow.com/questions/62283556/calculated-flops-of-a-keras-model-returns-ops-with-no-flops-due-to-incomplete-s
https://blog.csdn.net/deephacking/article/details/107873881

おすすめ

転載: blog.csdn.net/orangeboss/article/details/131932597