How to change an ONNX model's fixed input and output dimensions


Background

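I came across several ONNX models for artistic style transfer (models/vision/style_transfer/fast_neural_style at main · onnx/models · GitHub). I expected to just run them with onnxruntime and be done, but it turned out their input size is fixed.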


So I searched around and found that onnx provides a function for updating the dimensions of a model's inputs and outputs.

Full code

The code is adapted from onnx/update_model_dims.py at main · onnx/onnx · GitHub.

# SPDX-License-Identifier: Apache-2.0

from typing import Any, List, Dict, Set
from onnx import ModelProto, ValueInfoProto

import onnx.checker


def update_inputs_outputs_dims(model: ModelProto, input_dims: Dict[str, List[Any]], output_dims: Dict[str, List[Any]]) -> ModelProto:
    """
        This function updates the dimension sizes of the model's inputs and outputs to the values
        provided in input_dims and output_dims. If the dim value provided is negative, a unique dim_param
        will be set for that dimension.
        Example: if we have the following shape for inputs and outputs:
                shape(input_1) = ('b', 3, 'w', 'h')
                shape(input_2) = ('b', 4)
                and shape(output)  = ('b', 'd', 5)
            The parameters can be provided as:
                input_dims = {
                    "input_1": ['b', 3, 'w', 'h'],
                    "input_2": ['b', 4],
                }
                output_dims = {
                    "output": ['b', -1, 5]
                }
            Putting it together:
                model = onnx.load('model.onnx')
                updated_model = update_inputs_outputs_dims(model, input_dims, output_dims)
                onnx.save(updated_model, 'model.onnx')
    """
    dim_param_set: Set[str] = set()

    def init_dim_param_set(dim_param_set: Set[str], value_infos: List[ValueInfoProto]) -> None:
        for info in value_infos:
            shape = info.type.tensor_type.shape
            for dim in shape.dim:
                if dim.HasField('dim_param'):
                    dim_param_set.add(dim.dim_param)  # type: ignore

    init_dim_param_set(dim_param_set, model.graph.input)  # type: ignore
    init_dim_param_set(dim_param_set, model.graph.output)  # type: ignore
    init_dim_param_set(dim_param_set, model.graph.value_info)  # type: ignore

    def update_dim(tensor: ValueInfoProto, dim: Any, j: int, name: str) -> None:
        dim_proto = tensor.type.tensor_type.shape.dim[j]
        if isinstance(dim, int):
            if dim >= 0:
                if dim_proto.HasField('dim_value') and dim_proto.dim_value != dim:
                    raise ValueError('Unable to set dimension value to {} for axis {} of {}. Contradicts existing dimension value {}.'
                        .format(dim, j, name, dim_proto.dim_value))
                dim_proto.dim_value = dim
            else:
                generated_dim_param = name + '_' + str(j)
                if generated_dim_param in dim_param_set:
                    raise ValueError('Unable to generate unique dim_param for axis {} of {}. Please manually provide a dim_param value.'
                        .format(j, name))
                dim_proto.dim_param = generated_dim_param
        elif isinstance(dim, str):
            dim_proto.dim_param = dim
        else:
            raise ValueError(f'Only int or str is accepted as dimension value, incorrect type: {type(dim)}')

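    for input in model.graph.input:
        input_name = input.name
        # Change 1: skip graph inputs whose name does not contain "input"; the other
        # entries have no key in input_dims and would otherwise raise a KeyError.
        if "input" not in input_name:
            continue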
        input_dim_arr = input_dims[input_name]
        for j, dim in enumerate(input_dim_arr):
            update_dim(input, dim, j, input_name)

    for output in model.graph.output:
        output_name = output.name
        output_dim_arr = output_dims[output_name]
        for j, dim in enumerate(output_dim_arr):
            update_dim(output, dim, j, output_name)

    onnx.checker.check_model(model)
    return model
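The three branches of `update_dim` are easier to follow without protobuf in the way. Below is a hypothetical, dependency-free `plan_dim` (not part of onnx) that mirrors that logic and only reports which field would be written: a non-negative int becomes a fixed `dim_value` and must not contradict one already stored, a negative int becomes a generated `dim_param` named `<tensor>_<axis>`, and a string is used as the `dim_param` as-is. Here `existing_value=None` stands for "no `dim_value` set on that axis".

```python
from typing import Optional, Set, Tuple, Union

def plan_dim(
    dim: Union[int, str],
    axis: int,
    tensor_name: str,
    existing_value: Optional[int],
    used_params: Set[str],
) -> Tuple[str, Union[int, str]]:
    """Hypothetical mirror of update_dim: return the (field, value) it would write."""
    if isinstance(dim, int):
        if dim >= 0:
            # A fixed size must not contradict a size already stored on this axis.
            if existing_value is not None and existing_value != dim:
                raise ValueError(f"axis {axis} of {tensor_name} is already {existing_value}")
            return ("dim_value", dim)
        # A negative size requests an auto-generated symbolic name "<tensor>_<axis>".
        generated = f"{tensor_name}_{axis}"
        if generated in used_params:
            raise ValueError(f"dim_param {generated!r} is already in use")
        return ("dim_param", generated)
    if isinstance(dim, str):
        # A string is used verbatim as the symbolic name.
        return ("dim_param", dim)
    raise ValueError(f"Only int or str is accepted as dimension value, got {type(dim)}")

print(plan_dim(3, 1, "input1", 3, set()))      # ('dim_value', 3)
print(plan_dim("b", 0, "input1", 1, set()))    # ('dim_param', 'b')
print(plan_dim(-1, 2, "output1", 224, {"b"}))  # ('dim_param', 'output1_2')
```

As in the original, a string or a negative int simply overwrites whatever fixed size the axis had; only a non-negative int is checked against the existing value.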

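I made two small changes:

  1. Skip updating the dims of any input whose name does not contain "input". While using the function, I found that some later nodes were also treated as inputs; since they have no entry in input_dims, the lookup input_dims[input_name] raised an error.
  2. Have the function return the model.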

Usage example

if __name__ == "__main__":
    # Strings become symbolic dims; only the channel axis (3) stays fixed.
    input_dims = {
        "input1": ['b', 3, 'w', 'h'],
    }
    output_dims = {
        "output1": ['b', 3, 'dw', 'dh']
    }
    model = onnx.load(R'h:\AIProjects\AIImageTools\algs\neural_style\onnx_models\mosaic-9.onnx')
    updated_model = update_inputs_outputs_dims(model, input_dims, output_dims)
    onnx.save(updated_model, R'h:\AIProjects\AIImageTools\algs\neural_style\onnx_models\mosaic-9-dynamic.onnx')

After the change, I opened the model in Netron again to check it.


Result

Finally, I ran an image through the converted ONNX model just for fun.

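A minimal sketch of that inference step, assuming the pre/post-processing described in the onnx/models README for these style-transfer models: RGB pixels as float32 in the 0..255 range, NCHW layout, and an output clipped back to 0..255. Since the height and width are now symbolic, the image no longer has to be resized to 224x224 first, although the network may still produce an output size that differs slightly from the input for some image sizes (an assumption not verified here). The function names and file paths are hypothetical; onnxruntime and Pillow are imported inside `stylize` so the NumPy helpers can be used on their own.

```python
import numpy as np

def to_model_input(rgb: np.ndarray) -> np.ndarray:
    """HxWx3 uint8 RGB image -> 1x3xHxW float32 tensor with values in 0..255."""
    x = rgb.astype(np.float32)
    x = np.transpose(x, (2, 0, 1))    # HWC -> CHW
    return np.expand_dims(x, axis=0)  # CHW -> NCHW, batch of one

def to_image(y: np.ndarray) -> np.ndarray:
    """1x3xHxW model output -> HxWx3 uint8 RGB image."""
    y = np.clip(y[0], 0, 255)                           # drop batch axis, clamp to pixel range
    return np.transpose(y, (1, 2, 0)).astype(np.uint8)  # CHW -> HWC

def stylize(model_path: str, image_path: str, out_path: str) -> None:
    """Run one image of any size through a style-transfer model with dynamic H/W."""
    # Imported here so the pure-NumPy helpers above work without these packages.
    import onnxruntime as ort
    from PIL import Image

    rgb = np.asarray(Image.open(image_path).convert("RGB"))
    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name  # "input1" for these models
    y = session.run(None, {input_name: to_model_input(rgb)})[0]
    Image.fromarray(to_image(y)).save(out_path)
```

Usage would look like `stylize("mosaic-9-dynamic.onnx", "photo.jpg", "photo-mosaic.jpg")`, with your own paths.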


Reposted from juejin.im/post/7102446629399199780