pytorch导出rot90算子至onnx

1 背景描述

在部署模型时,如果某些模型中或者前后处理中含有rot90算子,但又希望一起和模型导出onnx时,可能会遇到如下错误(当前使用环境pytorch2.0.1opset_version为17):

import torch
import torch.nn as nn


class RotModel(nn.Module):
    def forward(self, x: torch.Tensor):
        x = torch.rot90(x, k=1, dims=(2, 3))
        return x


def main():
    print("pytorch version:", torch.__version__)

    model = RotModel()
    with torch.inference_mode():
        x = torch.randn(size=(1, 3, 224, 224))

        torch.onnx.export(model,
                          args=(x,),
                          f="rot90_counterclockwise.onnx",
                          opset_version=17)


if __name__ == '__main__':
    main()

torch.onnx.errors.UnsupportedOperatorError: Exporting the operator ‘aten::rot90’ to ONNX opset version 17 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.

简单的说就是不支持导出该算子,包括在onnx支持的算子文档中也找不到rot90算子,onnx官方github链接:
https://github.com/onnx/onnx


2 等价替换

导不出咋办,那就想想旋转矩阵的原理,以及如何使用现有支持的算子替换。

2.1 rot90替换(NCHW)

废话不多说,rot90度(以逆时针为例)可以使用翻转和转置实现。具体代码如下,使用torch自带的rot90与自己实现的对比,通过torch.equal()来对比两个Tensor是否一致,结果一致,不信自己试试。

import torch


def self_rot90_counterclockwise(x: torch.Tensor):
    x = x.flip(dims=[3]).permute([0, 1, 3, 2])
    return x


def main():
    print("pytorch version:", torch.__version__)

    with torch.inference_mode():
        x = torch.randn(size=(1, 3, 224, 224))

        y0 = torch.rot90(x, k=1, dims=[2, 3])
        y1 = self_rot90_counterclockwise(x)
        print(torch.equal(y0, y1))


if __name__ == '__main__':
    main()

2.2 rot180替换(NCHW)

rot180度(以逆时针为例)可以使用翻转实现。具体代码如下:

import torch


def self_rot180_counterclockwise(x: torch.Tensor):
    x = x.flip(dims=[2, 3])
    return x


def main():
    print("pytorch version:", torch.__version__)

    with torch.inference_mode():
        x = torch.randn(size=(1, 3, 224, 224))

        y0 = torch.rot90(x, k=2, dims=[2, 3])
        y1 = self_rot180_counterclockwise(x)
        print(torch.equal(y0, y1))


if __name__ == '__main__':
    main()

2.3 rot270替换(NCHW)

rot270度(以逆时针为例)可以使用翻转和转置实现。具体代码如下:

import torch


def self_rot270_counterclockwise(x: torch.Tensor):
    x = x.flip(dims=[2]).permute([0, 1, 3, 2])
    return x


def main():
    print("pytorch version:", torch.__version__)

    with torch.inference_mode():
        x = torch.randn(size=(1, 3, 224, 224))

        y0 = torch.rot90(x, k=3, dims=[2, 3])
        y1 = self_rot270_counterclockwise(x)
        print(torch.equal(y0, y1))


if __name__ == '__main__':
    main()


3 rot导出ONNX

这里以rot90度(以逆时针为例)结合刚刚的等价实现来导出ONNX:

import torch
import torch.nn as nn


class RotModel(nn.Module):
    def forward(self, x: torch.Tensor):
        # x = torch.rot90(x, k=1, dims=(2, 3))
        x = x.flip(dims=[3]).permute([0, 1, 3, 2])
        return x


def main():
    print("pytorch version:", torch.__version__)

    model = RotModel()
    with torch.inference_mode():
        x = torch.randn(size=(1, 3, 224, 224))

        torch.onnx.export(model,
                          args=(x,),
                          f="rot90_counterclockwise.onnx",
                          opset_version=17)


if __name__ == '__main__':
    main()

使用netron打开生成的rot90_counterclockwise.onnx文件,如下所示:

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/qq_37541097/article/details/134624876