ONNX 推理：使用 remove_initializer_from_input 将 initializer 从模型的图输入（graph input）中移除

remove_initializer_from_input.py

import argparse

import onnx


def get_args(argv=None):
    """Parse the command-line options for this script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, i.e. the original behaviour; passing a list
            allows calling this from tests or other Python code.

    Returns:
        argparse.Namespace with two string attributes:
        ``input``  -- path of the model that is loaded with ``onnx.load``;
        ``output`` -- path the modified model is saved to.

    Raises:
        SystemExit: raised by argparse when ``--input`` or ``--output`` is
            missing or the arguments are otherwise invalid.
    """
    parser = argparse.ArgumentParser(
        description="Remove graph inputs that duplicate initializers in an ONNX model."
    )
    parser.add_argument("--input", required=True, help="input model")
    parser.add_argument("--output", required=True, help="output model")
    return parser.parse_args(argv)


def remove_initializer_from_input():
    """Drop graph inputs that duplicate initializers, then save the model.

    Reads ``--input`` and ``--output`` from the command line via
    ``get_args``, loads the model from ``--input``, removes every entry of
    ``model.graph.input`` whose name matches an entry of
    ``model.graph.initializer``, and writes the model to ``--output``.

    Models whose ``ir_version`` is below 4 are not modified and nothing is
    written: a message is printed and the function returns, because for
    those IR versions initializers must also be listed as graph inputs.
    """
    args = get_args()

    model = onnx.load(args.input)
    if model.ir_version < 4:
        print("Model with ir_version below 4 requires to include initilizer in graph input")
        return

    inputs = model.graph.input
    # Index graph inputs by name (named ``graph_input`` so the builtin
    # ``input`` is not shadowed).
    name_to_input = {graph_input.name: graph_input for graph_input in inputs}

    for initializer in model.graph.initializer:
        # pop() instead of a plain lookup: if the same initializer name occurs
        # more than once (a malformed model), the matching graph input is
        # removed only once. A second remove() of an entry that is no longer
        # in the repeated field would raise ValueError.
        graph_input = name_to_input.pop(initializer.name, None)
        if graph_input is not None:
            inputs.remove(graph_input)

    onnx.save(model, args.output)


if __name__ == "__main__":
    # Runs only when the file is executed as a script, not when imported:
    # parse the command-line options and rewrite the model.
    remove_initializer_from_input()
用法：
python remove_initializer_from_input.py --input model.onnx --output model_output.onnx

猜你喜欢

转载自：https://blog.csdn.net/qq_16792139/article/details/132412768