【Swin-T onnx】swin transformer 转 onnx Error解决

1 文件和环境准备

使用的是 霹雳吧啦 大神的GitHub代码:swin_transformer

推荐使用 torch 1.10.1 及以上的版本(torch 1.9.1 在导出时会触发下文的 Error_3)。

2 Error_1

Exporting the operator roll to ONNX opset version 11 is not supported.

错误原因:导出到 ONNX opset 11 时,torch.roll 算子不受支持。
解决方案:将 model.py 中的 torch.roll 改写为"切片 + torch.cat"的等价形式,代码如下:

        # reverse cyclic shift
        if self.shift_size > 0:
            # ONNX-exportable replacement for:
            #   x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            # Roll along dim 1 first, then roll THAT result along dim 2. The
            # second cat must slice x (not shifted_x); otherwise the dim-1
            # roll is discarded and only dim 2 ends up shifted.
            x = torch.cat((shifted_x[:, -self.shift_size:, :, :], shifted_x[:, :-self.shift_size, :, :]), dim=1)
            x = torch.cat((x[:, :, -self.shift_size:, :], x[:, :, :-self.shift_size, :]), dim=2)

        # cyclic shift
        if self.shift_size > 0:
            # ONNX-exportable replacement for:
            #   shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            # Roll along dim 1 first, then roll THAT result along dim 2. The
            # second cat must slice shifted_x (not x); otherwise the dim-1
            # roll is discarded and only dim 2 ends up shifted.
            shifted_x = torch.cat((x[:, self.shift_size:, :, :], x[:, :self.shift_size, :, :]), dim=1)
            shifted_x = torch.cat((shifted_x[:, :, self.shift_size:, :], shifted_x[:, :, :self.shift_size, :]), dim=2)

3 Error_2

  File "/swin_transformer/model.py", line 365, in forward
    x = torch.cat((shifted_x[:,-self.shift_size:,:,:], shifted_x[:,:-self.shift_size,:,:]), dims=1)
TypeError: cat() received an invalid combination of arguments - got (tuple, dims=int), but expected one of:
 * (tuple of Tensors tensors, int dim, *, Tensor out)

错误原因:torch.cat 的关键字参数名写错了(把 torch.roll 的 dims 照搬了过来)。
解决方案:torch.cat 的关键字参数是 dim,不是 dims;把 dims=1 改为 dim=1 即可。

4 Error_3

  File "/home/users/env/env1/lib64/python3.6/site-packages/torch/onnx/utils.py", line 890, in _graph_op
    torch._C._jit_pass_onnx_node_shape_type_inference(n, _params_dict, opset_version)
RuntimeError: input_shape_value == reshape_value || input_shape_value == 1 || reshape_value == 1INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/onnx/shape_type_inference.cpp":520, please report a bug to PyTorch. ONNX Expand input shape constraint not satisfied.

错误原因:pytorch版本问题,报错的这个版本是1.9.1
解决方案:pytorch版本升级为1.10.1即可。

5 导出onnx模型、推理一张图片

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import swin_tiny_patch4_window7_224 as create_model


def model_convert_onnx(model, input_shape, output_path):
    """Export ``model`` to an ONNX file written to ``output_path``.

    A random tensor of shape ``(1, 3, input_shape[0], input_shape[1])`` is
    handed to ``torch.onnx.export`` as the example input. The exported graph
    names its input ``"input1"`` and its output ``"output1"``.
    """
    height, width = input_shape[0], input_shape[1]
    example = torch.randn(1, 3, height, width)

    export_options = dict(
        verbose=True,
        keep_initializers_as_inputs=True,
        opset_version=11,  # usually 10 or 11
        input_names=["input1"],
        output_names=["output1"],
    )
    torch.onnx.export(model, example, output_path, **export_options)




def main():
    """Classify one image with the exported ONNX model and the PyTorch model.

    Loads ``../flower_data/tulip.jpg``, applies the resize / center-crop /
    normalize pipeline, runs it through the ONNX file at
    ``./weights/swin_transformer.onnx`` (via onnxruntime) and through the
    PyTorch model restored from ``./weights/model-9.pth``, prints both
    softmax vectors and their top-1 classes, and shows the image titled with
    the PyTorch prediction.

    Raises:
        FileNotFoundError: if the image, the class-index JSON, or the ONNX
            file is missing.
    """
    # Everything below runs on the CPU.
    device = "cpu"

    img_size = 224
    data_transform = transforms.Compose(
        [transforms.Resize(int(img_size * 1.14)),
         transforms.CenterCrop(img_size),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = "../flower_data/tulip.jpg"
    # Raise instead of assert: asserts are stripped when Python runs with -O.
    if not os.path.exists(img_path):
        raise FileNotFoundError("file: '{}' dose not exist.".format(img_path))
    # Force 3 channels: Normalize above is configured with 3-channel
    # statistics, so a grayscale or RGBA file would otherwise fail there.
    img = Image.open(img_path).convert("RGB")
    plt.imshow(img)
    # [C, H, W] after the transform
    img = data_transform(img)
    # expand batch dimension -> [N, C, H, W]
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    if not os.path.exists(json_path):
        raise FileNotFoundError("file: '{}' dose not exist.".format(json_path))

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = create_model(num_classes=5).to(device)
    # load model weights
    model_weight_path = "./weights/model-9.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()

    # First run only: uncomment to export the ONNX file that is loaded below.
    # The export input size must match the PyTorch model's input size.
    # input_shape = (224, 224)
    # output_path = './weights/swin_transformer.onnx'
    # model_convert_onnx(model, input_shape, output_path)
    # print("model convert onnx finsh, onnx model location:", output_path)

    onnx_path = './weights/swin_transformer.onnx'
    if not os.path.exists(onnx_path):
        raise FileNotFoundError(
            "file: '{}' does not exist; export it with "
            "model_convert_onnx() first.".format(onnx_path))

    # ---------------------------------------------------------
    #   Inference with onnxruntime
    # ---------------------------------------------------------
    import onnxruntime
    # Name the provider explicitly: onnxruntime >= 1.9 refuses to create a
    # session without `providers` on builds that ship non-CPU providers.
    ort_session = onnxruntime.InferenceSession(
        onnx_path, providers=["CPUExecutionProvider"])
    # The key must match the input name given in model_convert_onnx().
    ort_inputs = {"input1": img.numpy()}
    onnx_outputs = ort_session.run(None, ort_inputs)

    onnx_output = torch.squeeze(torch.from_numpy(onnx_outputs[0]))
    onnx_predict = torch.softmax(onnx_output, dim=0)
    print("onnx_predict:", onnx_predict)
    # Kept under its own name so the PyTorch result below does not
    # overwrite it before it is reported.
    onnx_cla = int(torch.argmax(onnx_predict))
    print("onnx class: {}   prob: {:.3}".format(class_indict[str(onnx_cla)],
                                                onnx_predict[onnx_cla].item()))

    # ---------------------------------------------------------
    #   Inference with the PyTorch model
    # ---------------------------------------------------------
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        print("pt_predict:", predict)
        predict_cla = int(torch.argmax(predict))

    # Quick sanity check that the exported graph agrees with the original.
    max_diff = (onnx_predict - predict).abs().max().item()
    print("max abs diff (onnx vs pt): {:.3e}".format(max_diff))

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].item())
    plt.title(print_res)
    for i, prob in enumerate(predict):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  prob.item()))
    plt.show()


# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

参考链接

https://github.com/pytorch/pytorch/issues/78348
https://blog.csdn.net/blueblood7/article/details/121034635

猜你喜欢

转载自blog.csdn.net/weixin_45377629/article/details/127413972