An onnxsim optimization method for ONNX models larger than 2GB

Running onnxsim on models larger than 2GB is time-consuming, prone to crashing, and requires a very large amount of system memory.

This post proposes a fairly simple method for optimizing ONNX models larger than 2GB:

1. Replace the weights of convolution and matrix-multiplication ops (those whose parameter count exceeds a threshold) with ConstantOfShape nodes, which dramatically shrinks the model size.

2. Use the onnxsim option that skips folding ConstantOfShape ops whose output exceeds a size threshold.

This requires onnxsim >= 0.4.24.

The onnxsim command that keeps constant folding from materializing large tensors for Tile and ConstantOfShape ops:

onnxsim --no-large-tensor size_th in_model.onnx out_model.onnx

where size_th is a size threshold such as 1KB, 1MB, etc.

3. Optimize and constant-fold the compressed model, then delete the ConstantOfShape ops from the result and put the original weights back.

The same trick works for other tooling as well, for example ONNX shape inference and opset conversion: a model larger than 2GB must be written back to disk before the official shape-inference API can be called on it, and there is no official way to convert the opset of a model larger than 2GB.
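For example, a minimal sketch (file names are placeholders) of running the official shape inference and opset converter in memory on the compressed model, which is now far below the 2GB protobuf limit:

import onnx
from onnx import shape_inference, version_converter

model = onnx.load("compressed.onnx")

# In-memory shape inference works again because the compressed model is small.
model = shape_inference.infer_shapes(model)

# Opset conversion; 17 is only an example target version.
model = version_converter.convert_version(model, 17)
onnx.save(model, "compressed.converted.onnx")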

Note that each ConstantOfShape should preferably use a distinct value attribute; otherwise onnxsim will merge ConstantOfShape nodes whose value and shape are identical.

The ConstantOfShape op


Creating a ConstantOfShape node:

import onnx

const_shape = (512, 512)
const_shape_init = onnx.helper.make_tensor(
    name="const_shape", data_type=onnx.TensorProto.INT64,
    dims=[2], vals=const_shape, raw=False)

# "value" must be a one-element tensor; it gives the fill value of the output.
tensor_value_attr = onnx.helper.make_tensor("value", onnx.TensorProto.FLOAT, dims=[1], vals=[1.0])
constantofshape_node = onnx.helper.make_node(op_type="ConstantOfShape",
                                             inputs=["const_shape"],
                                             outputs=["input2"],
                                             value=tensor_value_attr)
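
To check that the node is well formed, it can be wrapped into a minimal one-node model; the graph and tensor names below are illustrative only:

# Wrap the node into a minimal model and validate it with the checker.
output_info = onnx.helper.make_tensor_value_info(
    "input2", onnx.TensorProto.FLOAT, const_shape)
graph = onnx.helper.make_graph(
    nodes=[constantofshape_node], name="const_of_shape_demo",
    inputs=[], outputs=[output_info], initializer=[const_shape_init])
model = onnx.helper.make_model(graph)
onnx.checker.check_model(model)  # output: a 512x512 tensor filled with 1.0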

The complete model compression and decompression code and usage are as follows:

python compress_model.py onnx_model.onnx

onnxsim --no-large-tensor 1MB onnx_model.compressed.onnx onnx_model.compressed.opt.onnx

python uncompress_model.py onnx_model.onnx onnx_model.compressed.opt.onnx onnx_model.compress.cfg
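
After decompression it is worth confirming that the restored weights are bit-identical to the originals. A sketch of such a check (not part of the original scripts; file names follow the commands above):

import onnx
from onnx import numpy_helper

orig = onnx.load("onnx_model.onnx")
restored = onnx.load("onnx_model.compressed.opt.uncompressed.onnx")

orig_inits = {init.name: init for init in orig.graph.initializer}
for init in restored.graph.initializer:
    if init.name in orig_inits:
        # Compare the raw bytes of each restored weight with the original.
        assert numpy_helper.to_array(init).tobytes() == \
            numpy_helper.to_array(orig_inits[init.name]).tobytes()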

This method has been used successfully to optimize the Stable Diffusion UNet.

A few caveats. The Stable Diffusion UNet can be compressed with the method above, after which you set the compressed model's input shapes and run onnxsim. A single pass, however, may not remove all Shape ops and other dynamic-shape leftovers, and a bug can appear where the time_step input changes from [1] to [-1]. Setting the input shapes once more and running onnxsim a second time removes the remaining dynamic-shape ops. Decompress afterwards.
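One way to pin the input shapes between onnxsim passes is to edit the graph inputs directly. A sketch, where the input names and sizes are only illustrative of a Stable Diffusion UNet export:

import onnx

model = onnx.load("onnx_model.compressed.onnx")
static_shapes = {"sample": [2, 4, 64, 64], "timestep": [1],
                 "encoder_hidden_states": [2, 77, 768]}

for graph_input in model.graph.input:
    if graph_input.name not in static_shapes:
        continue
    dims = graph_input.type.tensor_type.shape.dim
    for dim, size in zip(dims, static_shapes[graph_input.name]):
        dim.ClearField("dim_param")  # drop symbolic names such as "batch"
        dim.dim_value = size
onnx.save(model, "onnx_model.compressed.static.onnx")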

If onnxsim still fails to optimize fully even though static shapes are set, delete the stale shape value_info first, rerun ONNX shape inference, and then run onnxsim again; see:

ONNX graph optimization / model modification ("nodes in a graph must be topologically sorted"), Luchang-Li's blog on CSDN
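
A sketch of that workaround (paths are placeholders):

import onnx
from onnx import shape_inference

model = onnx.load("model_in.onnx")
del model.graph.value_info[:]                   # drop stale shape annotations
model = shape_inference.infer_shapes(model)     # re-infer shapes from scratch
onnx.save(model, "model_reinferred.onnx")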

Model compression and decompression code

compress_model.py

import sys
import json
import onnx


SIZE_1MB = 1024 * 1024


def get_onnx_tensor_proto_shape(onnx_tensor_proto):
    shape = [elem for elem in onnx_tensor_proto.dims]
    return shape


def get_onnx_tensor_proto_dtype(onnx_tensor_proto):
    return onnx_tensor_proto.data_type


def shape_elem_num(shape):
    elem_num = 1
    for elem in shape:
        elem_num *= elem
    return elem_num


NODE_INDICES = {}


def create_node_name(node_type):
    global NODE_INDICES
    if node_type not in NODE_INDICES:
        NODE_INDICES[node_type] = 0
    node_id = NODE_INDICES[node_type]
    NODE_INDICES[node_type] += 1

    name = f"{node_type}_{node_id}"
    return name


def create_const_of_shape(shape, dtype=onnx.TensorProto.FLOAT, value=0.0, output_name=None, node_name=None):
    if node_name is None:
        node_name = create_node_name("ConstantOfShape")
    if not output_name:
        output_name = node_name + "_output0"
    const_shape_name = node_name + "_shape"

    shape_dim = [len(shape)]
    shape_initializer = onnx.helper.make_tensor(
        name=const_shape_name, data_type=onnx.TensorProto.INT64, dims=shape_dim, vals=shape, raw=False)

    tensor_value_attr = onnx.helper.make_tensor("value", dtype, dims=[1], vals=[value])

    node = onnx.helper.make_node(op_type="ConstantOfShape",
                                 inputs=[const_shape_name],
                                 outputs=[output_name],
                                 value=tensor_value_attr)
    return node, shape_initializer


def del_onnx_initializers(graph, del_init_names):
    indices = []
    for idx, tensor_proto in enumerate(graph.initializer):
        if tensor_proto.name in del_init_names:
            indices.append(idx)

    indices = sorted(indices, reverse=True)
    for idx in indices:
        del graph.initializer[idx]


def insert_onnx_nodes(graph, idx, new_nodes):
    new_nodes = reversed(new_nodes)
    for node in new_nodes:
        graph.node.insert(idx, node)


def add_onnx_inits(graph, new_inits):
    del_init_names = [init.name for init in new_inits]
    del_onnx_initializers(graph, del_init_names)
    graph.initializer.extend(new_inits)


COMPRESS_NODE_TYPES = ["Conv", "Gemm", "MatMul"]
CONST_OF_SHAPE_VALUE = 0.01


def compress_onnx_model(onnx_model, size_th=SIZE_1MB):
    global CONST_OF_SHAPE_VALUE
    graph = onnx_model.graph
    initializer = graph.initializer

    name_2_init_map = {}
    for init in initializer:
        name_2_init_map[init.name] = init

    replaced_tensor_names = []
    new_nodes = []
    new_inits = []

    for node in graph.node:
        if node.op_type not in COMPRESS_NODE_TYPES:
            continue
        init_name = node.input[1]
        if init_name not in name_2_init_map:
            continue
        # A weight shared by several nodes only needs one replacement node.
        if init_name in replaced_tensor_names:
            continue

        init = name_2_init_map[init_name]
        dtype = get_onnx_tensor_proto_dtype(init)
        shape = get_onnx_tensor_proto_shape(init)

        # Only fp32 weights are replaced; hence 4 bytes per element below.
        if dtype != onnx.TensorProto.FLOAT:
            continue

        shape_elem = shape_elem_num(shape)
        if shape_elem * 4 <= size_th:
            continue

        # Use a slightly different fill value for every ConstantOfShape so
        # onnxsim cannot merge nodes with identical value and shape.
        cos_node, shape_init = create_const_of_shape(
            shape=shape, dtype=onnx.TensorProto.FLOAT, value=CONST_OF_SHAPE_VALUE, output_name=init.name)
        CONST_OF_SHAPE_VALUE += 0.003

        replaced_tensor_names.append(init.name)
        new_nodes.append(cos_node)
        new_inits.append(shape_init)

    del_onnx_initializers(graph, replaced_tensor_names)
    insert_onnx_nodes(graph, 0, new_nodes)
    add_onnx_inits(graph, new_inits)
    return onnx_model, replaced_tensor_names


if __name__ == "__main__":
    model_path = sys.argv[1]
    compressed_model_path = model_path[:-5] + ".compressed.onnx"
    compress_cfg = model_path[:-5] + ".compress.cfg"

    onnx_model = onnx.load(model_path)

    onnx_model, replaced_tensor_names = compress_onnx_model(onnx_model)
    onnx.save(onnx_model, compressed_model_path)

    replace_cfg = {"replaced_tensor_names": replaced_tensor_names}
    print("replaced_tensor_names:", json.dumps(replace_cfg))

    with open(compress_cfg, "w") as f:
        f.write(json.dumps(replace_cfg))

uncompress_model.py

import sys
import logging
import json
import onnx


def del_onnx_initializers(graph, del_init_names):
    indices = []
    for idx, tensor_proto in enumerate(graph.initializer):
        if tensor_proto.name in del_init_names:
            indices.append(idx)

    indices = sorted(indices, reverse=True)
    for idx in indices:
        del graph.initializer[idx]


def del_onnx_nodes(graph, nodes, del_node_init=False):
    unused_init_names = []
    if del_node_init:
        init_names = [init.name for init in graph.initializer]
        for node in nodes:
            for in_name in node.input:
                if in_name in init_names:
                    unused_init_names.append(in_name)

    indices = []
    for idx, node in enumerate(graph.node):
        if node in nodes:
            indices.append(idx)
    indices = sorted(indices, reverse=True)
    for idx in indices:
        del graph.node[idx]

    if del_node_init:
        del_onnx_initializers(graph, unused_init_names)


def add_onnx_inits(graph, new_inits):
    del_init_names = [init.name for init in new_inits]
    del_onnx_initializers(graph, del_init_names)
    graph.initializer.extend(new_inits)


def uncompress_onnx_model(onnx_model_orig, onnx_model_compressed, replaced_tensor_names):
    initializer_orig = onnx_model_orig.graph.initializer

    # Find the ConstantOfShape nodes that stand in for the replaced weights.
    del_nodes = []
    for node in onnx_model_compressed.graph.node:
        if node.output[0] in replaced_tensor_names:
            del_nodes.append(node)
    valid_replace_names = [_node.output[0] for _node in del_nodes]

    # Collect the original weights that will be restored.
    new_inits = []
    for init in initializer_orig:
        if init.name in valid_replace_names:
            new_inits.append(init)

    print(f"del node num: {len(del_nodes)}, replaced tensor number in orig model: {len(replaced_tensor_names)}")

    # A mismatch usually means onnxsim folded or renamed some placeholders.
    if len(del_nodes) != len(replaced_tensor_names):
        logging.warning("const of shape nodes number != replaced tensor names number")

    # del_node_init=False keeps the small shape initializers; unused
    # initializers are harmless and may have been shared with other nodes.
    del_onnx_nodes(onnx_model_compressed.graph, del_nodes, del_node_init=False)
    add_onnx_inits(onnx_model_compressed.graph, new_inits)
    return onnx_model_compressed


if __name__ == "__main__":
    ref_model_path = sys.argv[1]  # unoptimized model
    compressed_model_path = sys.argv[2]  # compressed and optimized model
    cfg_file_path = sys.argv[3]  # compress config file that store the replaced init info

    out_model_path = compressed_model_path[:-5] + ".uncompressed.onnx"

    with open(cfg_file_path, "r") as f:
        txt = f.read()
        replace_cfg = json.loads(txt)
    print("replaced_tensor_names:", replace_cfg)

    replaced_tensor_names = replace_cfg["replaced_tensor_names"]

    onnx_model_orig = onnx.load(ref_model_path)
    onnx_model_compressed = onnx.load(compressed_model_path)

    onnx_model_uncompressed = uncompress_onnx_model(onnx_model_orig, onnx_model_compressed, replaced_tensor_names)
    onnx.save(onnx_model_uncompressed, out_model_path, save_as_external_data=True)
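
As a final sanity check, the ONNX checker can be run on the result; for models over 2GB it must be given a file path rather than an in-memory proto (a sketch; the path follows the commands above):

import onnx

onnx.checker.check_model("onnx_model.compressed.opt.uncompressed.onnx")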


Reposted from blog.csdn.net/u013701860/article/details/130337446