大于2GB模型onnxsim优化很耗时,容易挂掉,而且需要特别大的系统内存。
这里提出一种比较简单的优化大于2GB ONNX模型的方法:
1. 把卷积和矩阵乘的权重(参数量大于某个阈值)替换为ConstantOfShape,从而显著缩小模型大小。
2. 利用onnxsim特性避免折叠(参数量大于某个阈值)ConstantOfShape算子。
需要onnxsim>=0.4.24
避免常量折叠产生大tensor的tile和ConstantOfShape算子的onnxsim命令:
onnxsim --no-large-tensor size_th in_model.onnx out_model.onnx
size_th的取值形如1KB、1MB等
3. 对压缩后的模型完成优化和常量折叠之后,删除其中的ConstantOfShape算子,并替换回原来的权重。
该方法也可以用于其他优化,例如onnx infer shape和opset转换。因为大于2GB模型需要写回到文件才能调用官方infershape,而大于2GB模型opset转换无官方方法。
注意每个ConstantOfShape的value最好不一样,否则onnxsim会合并value和shape相同的ConstantOfShape。
ConstantOfShape算子
创建一个ConstantOfShape算子
# Example: build a ConstantOfShape node whose output is a (512, 512) float tensor.
const_shape = (512, 512)
# The target shape is passed to the node as a 1-D INT64 initializer.
const_shape_init = onnx.helper.make_tensor(
    name="const_shape",
    data_type=onnx.TensorProto.INT64,
    dims=[len(const_shape)],
    vals=const_shape,
    raw=False,
)
# The fill value is carried by a one-element tensor attribute named "value".
tensor_value_attr = onnx.helper.make_tensor(
    "value", onnx.TensorProto.FLOAT, dims=[1], vals=[1]
)
constantofshape_node = onnx.helper.make_node(
    op_type="ConstantOfShape",
    inputs=["const_shape"],
    outputs=["input2"],
    value=tensor_value_attr,
)
完整的模型压缩解压缩代码和使用方法如下:
python compress_model.py onnx_model.onnx
onnxsim --no-large-tensor 1MB onnx_model.compressed.onnx onnx_model.compressed.opt.onnx
python uncompress_model.py onnx_model.onnx onnx_model.compressed.opt.onnx onnx_model.compress.cfg
该方法可以成功用于stable diffusion unet的优化。
需要注意几点:stable diffusion unet可以采用上面的方法压缩,然后设置压缩模型的输入shape并进行onnxsim优化。但是一次优化可能无法消除所有的shape算子等动态shape相关算子,并且可能出现time_step从[1]变成[-1]的bug。此时重新设置一次输入shape,再用onnxsim优化一次,即可消除所有动态shape算子。最后再进行解压缩。
如果明明设置了静态shape,onnxsim还是优化不完全,可以先删除旧的shape value info,再调用onnx infer shape工具infer shape,再使用onnxsim, 参考:
onnx模型图优化/模型修改_nodes in a graph must be topologically sorted_Luchang-Li的博客-CSDN博客
模型压缩和解压缩代码
compress_model.py
import sys
import json
import onnx
# Default size threshold in bytes (1 MiB) above which a weight gets replaced.
SIZE_1MB = 1 << 20
def get_onnx_tensor_proto_shape(onnx_tensor_proto):
    """Return the entries of ``onnx_tensor_proto.dims`` as a new Python list."""
    return list(onnx_tensor_proto.dims)
def get_onnx_tensor_proto_dtype(onnx_tensor_proto):
    """Return the ``data_type`` field of the given tensor proto unchanged."""
    dtype = onnx_tensor_proto.data_type
    return dtype
def shape_elem_num(shape):
    """Return the product of all dimensions in ``shape`` (1 for an empty shape)."""
    count = 1
    for dim in shape:
        count = count * dim
    return count
# Next free numeric suffix per node type, shared by every create_node_name call.
NODE_INDICES = {}


def create_node_name(node_type):
    """Return a fresh name ``"<node_type>_<n>"``, where n counts up from 0 per node type."""
    next_id = NODE_INDICES.get(node_type, 0)
    # Record the following id so the next call for this type gets a different name.
    NODE_INDICES[node_type] = next_id + 1
    return f"{node_type}_{next_id}"
def create_const_of_shape(shape, dtype=onnx.TensorProto.FLOAT, value=0.0, output_name=None, node_name=None):
    """Build a ConstantOfShape node plus the INT64 initializer holding its target shape.

    Args:
        shape: Sequence of dimensions of the tensor the node should produce.
        dtype: ONNX element type used for the fill-value tensor.
        value: Scalar stored in the node's ``value`` attribute.
        output_name: Name of the node output; when empty or ``None``,
            ``<node_name>_output0`` is used.
        node_name: Base for the derived tensor names; when ``None`` a fresh
            ``ConstantOfShape_<n>`` name is generated. It is not set as the
            node's own ``name``.

    Returns:
        ``(node, shape_initializer)``. The initializer is the node's only input,
        so the caller has to add it to the graph together with the node.
    """
    base_name = create_node_name("ConstantOfShape") if node_name is None else node_name
    out_name = output_name if output_name else base_name + "_output0"
    shape_input_name = base_name + "_shape"

    shape_initializer = onnx.helper.make_tensor(
        name=shape_input_name,
        data_type=onnx.TensorProto.INT64,
        dims=[len(shape)],
        vals=shape,
        raw=False,
    )
    fill_value = onnx.helper.make_tensor("value", dtype, dims=[1], vals=[value])
    const_node = onnx.helper.make_node(
        op_type="ConstantOfShape",
        inputs=[shape_input_name],
        outputs=[out_name],
        value=fill_value,
    )
    return const_node, shape_initializer
def del_onnx_initializers(graph, del_init_names):
    """Remove every initializer of ``graph`` whose name is in ``del_init_names``.

    Args:
        graph: Object with a mutable ``initializer`` sequence whose items have
            a ``name`` attribute; it is modified in place.
        del_init_names: Iterable of initializer names to delete. Names that do
            not occur in the graph are ignored.
    """
    # A set gives O(1) membership tests and is safe for one-shot iterables,
    # which repeated ``in`` tests would otherwise consume.
    names = set(del_init_names)
    if not names:
        return
    doomed = [idx for idx, tensor_proto in enumerate(graph.initializer) if tensor_proto.name in names]
    # Delete from the back so the remaining indices stay valid.
    for idx in reversed(doomed):
        del graph.initializer[idx]
def insert_onnx_nodes(graph, idx, new_nodes):
    """Insert ``new_nodes`` into ``graph.node`` starting at position ``idx``."""
    node_list = graph.node
    # Inserting back-to-front at a fixed, in-range index leaves the new nodes
    # in the order they were given.
    for pending in reversed(new_nodes):
        node_list.insert(idx, pending)
def add_onnx_inits(graph, new_inits):
    """Append ``new_inits`` to ``graph.initializer``, replacing same-named entries."""
    incoming_names = []
    for tensor in new_inits:
        incoming_names.append(tensor.name)
    # Drop existing initializers with these names first so the new tensors
    # replace them instead of sitting next to them.
    del_onnx_initializers(graph, incoming_names)
    graph.initializer.extend(new_inits)
# Operator types whose second input (the weight) is a candidate for replacement.
COMPRESS_NODE_TYPES = ["Conv", "Gemm", "MatMul"]
# Fill value for the next placeholder. It is bumped after every replacement so
# that no two placeholders carry the same value.
CONST_OF_SHAPE_VALUE = 0.01


def compress_onnx_model(onnx_model, size_th=SIZE_1MB):
    """Replace large float32 weights with ConstantOfShape placeholder nodes.

    A weight is replaced when it is the second input of a Conv/Gemm/MatMul
    node, is stored as a graph initializer, has FLOAT dtype and occupies more
    than ``size_th`` bytes (element count * 4). The placeholder keeps the
    weight's name as its output name and reads its shape from a new INT64
    initializer.

    Args:
        onnx_model: Model to compress; its graph is modified in place.
        size_th: Size threshold in bytes. Weights at or below it are kept.

    Returns:
        ``(onnx_model, replaced_tensor_names)`` where the second item lists
        each replaced weight name once, in the order it was first encountered.
    """
    global CONST_OF_SHAPE_VALUE

    graph = onnx_model.graph
    name_2_init_map = {init.name: init for init in graph.initializer}

    replaced_tensor_names = []
    already_replaced = set()
    new_nodes = []
    new_inits = []
    for node in graph.node:
        if node.op_type not in COMPRESS_NODE_TYPES:
            continue
        if len(node.input) < 2:
            continue
        init_name = node.input[1]
        # A weight shared by several nodes must get exactly one placeholder;
        # a second node with the same output name would make the graph invalid.
        if init_name in already_replaced:
            continue
        init = name_2_init_map.get(init_name)
        if init is None:
            # The second input is produced by another node, not stored as a weight.
            continue
        if get_onnx_tensor_proto_dtype(init) != onnx.TensorProto.FLOAT:
            continue
        shape = get_onnx_tensor_proto_shape(init)
        # 4 bytes per float32 element.
        if shape_elem_num(shape) * 4 <= size_th:
            continue

        const_node, shape_init = create_const_of_shape(
            shape=shape, dtype=onnx.TensorProto.FLOAT, value=CONST_OF_SHAPE_VALUE, output_name=init.name)
        # Give every placeholder a different value so that placeholders of the
        # same shape are not identical and do not get merged by the optimizer.
        CONST_OF_SHAPE_VALUE += 0.003

        already_replaced.add(init_name)
        replaced_tensor_names.append(init_name)
        new_nodes.append(const_node)
        new_inits.append(shape_init)

    del_onnx_initializers(graph, replaced_tensor_names)
    # Placeholders go to the front so they precede every node that consumes them.
    insert_onnx_nodes(graph, 0, new_nodes)
    add_onnx_inits(graph, new_inits)
    return onnx_model, replaced_tensor_names
if __name__ == "__main__":
    # Usage: python compress_model.py <model.onnx>
    # Writes <model>.compressed.onnx and <model>.compress.cfg next to the input.
    import os

    if len(sys.argv) < 2:
        sys.exit("usage: python compress_model.py <model.onnx>")
    model_path = sys.argv[1]
    # Strip the actual extension instead of a fixed five characters, so paths
    # that do not end in ".onnx" are not mangled.
    model_stem = os.path.splitext(model_path)[0]
    compressed_model_path = model_stem + ".compressed.onnx"
    compress_cfg = model_stem + ".compress.cfg"

    onnx_model = onnx.load(model_path)
    onnx_model, replaced_tensor_names = compress_onnx_model(onnx_model)
    onnx.save(onnx_model, compressed_model_path)

    # The names of the replaced weights are needed later to restore them.
    replace_cfg = {"replaced_tensor_names": replaced_tensor_names}
    print("replaced_tensor_names:", json.dumps(replace_cfg))
    with open(compress_cfg, "w") as f:
        f.write(json.dumps(replace_cfg))
uncompress_model.py
import sys
import logging
import json
import onnx
def del_onnx_initializers(graph, del_init_names):
    """Remove every initializer of ``graph`` whose name is in ``del_init_names``.

    Args:
        graph: Object with a mutable ``initializer`` sequence whose items have
            a ``name`` attribute; it is modified in place.
        del_init_names: Iterable of initializer names to delete. Names that do
            not occur in the graph are ignored.
    """
    # A set gives O(1) membership tests and is safe for one-shot iterables,
    # which repeated ``in`` tests would otherwise consume.
    names = set(del_init_names)
    if not names:
        return
    doomed = [idx for idx, tensor_proto in enumerate(graph.initializer) if tensor_proto.name in names]
    # Delete from the back so the remaining indices stay valid.
    for idx in reversed(doomed):
        del graph.initializer[idx]
def del_onnx_nodes(graph, nodes, del_node_init=False):
    """Delete ``nodes`` from ``graph.node`` and optionally their orphaned initializers.

    Args:
        graph: Graph to edit in place.
        nodes: Collection of nodes to remove. Graph nodes are matched against
            it with ``in``, i.e. by equality.
        del_node_init: When true, also delete initializers that fed the removed
            nodes, provided no remaining top-level node input or graph output
            still refers to them.
    """
    candidate_init_names = set()
    if del_node_init:
        init_names = {init.name for init in graph.initializer}
        for node in nodes:
            candidate_init_names.update(name for name in node.input if name in init_names)

    doomed = [idx for idx, node in enumerate(graph.node) if node in nodes]
    # Delete from the back so the remaining indices stay valid.
    for idx in reversed(doomed):
        del graph.node[idx]

    if del_node_init and candidate_init_names:
        # An initializer can be shared with nodes that stay in the graph;
        # deleting it would leave those nodes with a dangling input.
        still_used = {name for node in graph.node for name in node.input}
        still_used.update(out.name for out in graph.output)
        # NOTE(review): references from nested subgraphs (e.g. If/Loop bodies)
        # are not scanned here.
        unused_init_names = candidate_init_names - still_used
        if unused_init_names:
            del_onnx_initializers(graph, unused_init_names)
def add_onnx_inits(graph, new_inits):
    """Append ``new_inits`` to ``graph.initializer``, replacing same-named entries."""
    incoming_names = []
    for tensor in new_inits:
        incoming_names.append(tensor.name)
    # Drop existing initializers with these names first so the new tensors
    # replace them instead of sitting next to them.
    del_onnx_initializers(graph, incoming_names)
    graph.initializer.extend(new_inits)
def uncompress_onnx_model(onnx_model_orig, onnx_model_compressed, replaced_tensor_names):
    """Swap the placeholder nodes of the compressed model back for the original weights.

    Every node of the compressed graph whose first output is named in
    ``replaced_tensor_names`` is deleted, and the same-named initializer from
    the original model is added in its place.

    Args:
        onnx_model_orig: Model that still holds the original initializers.
        onnx_model_compressed: Compressed (and possibly optimized) model;
            its graph is modified in place.
        replaced_tensor_names: Names of the weights that were replaced during
            compression.

    Returns:
        ``onnx_model_compressed`` after the weights have been restored.
    """
    wanted_names = set(replaced_tensor_names)
    graph = onnx_model_compressed.graph

    # Nodes without outputs cannot be placeholders; skip them instead of indexing.
    del_nodes = [node for node in graph.node if len(node.output) > 0 and node.output[0] in wanted_names]
    valid_replace_names = {node.output[0] for node in del_nodes}
    new_inits = [init for init in onnx_model_orig.graph.initializer if init.name in valid_replace_names]

    print(f"del node num: {len(del_nodes)}, replaced tensor number in orig model: {len(replaced_tensor_names)}")
    if len(del_nodes) != len(replaced_tensor_names):
        logging.warning("const of shape nodes number != replaced tensor names number")
    # A placeholder without a matching original weight would leave its
    # consumers with a dangling input, so report it.
    missing_names = valid_replace_names - {init.name for init in new_inits}
    if missing_names:
        logging.warning("no initializer found in orig model for: %s", sorted(missing_names))

    del_onnx_nodes(graph, del_nodes, del_node_init=False)
    add_onnx_inits(graph, new_inits)
    return onnx_model_compressed
if __name__ == "__main__":
    # Usage: python uncompress_model.py <orig.onnx> <compressed_opt.onnx> <compress.cfg>
    # Writes <compressed_opt>.uncompressed.onnx next to the compressed model.
    import os

    if len(sys.argv) < 4:
        sys.exit("usage: python uncompress_model.py <orig.onnx> <compressed_opt.onnx> <compress.cfg>")
    ref_model_path = sys.argv[1]  # unoptimized model
    compressed_model_path = sys.argv[2]  # compressed and optimized model
    cfg_file_path = sys.argv[3]  # compress config file that stores the replaced init info
    # Strip the actual extension instead of a fixed five characters, so paths
    # that do not end in ".onnx" are not mangled.
    out_model_path = os.path.splitext(compressed_model_path)[0] + ".uncompressed.onnx"

    with open(cfg_file_path, "r") as f:
        replace_cfg = json.load(f)
    print("replaced_tensor_names:", replace_cfg)
    replaced_tensor_names = replace_cfg["replaced_tensor_names"]

    onnx_model_orig = onnx.load(ref_model_path)
    onnx_model_compressed = onnx.load(compressed_model_path)
    onnx_model_uncompressed = uncompress_onnx_model(onnx_model_orig, onnx_model_compressed, replaced_tensor_names)
    # The restored weights are written as external data rather than inside the .onnx file.
    onnx.save(onnx_model_uncompressed, out_model_path, save_as_external_data=True)