菜鸟落泪:Failed to export an ONNX attribute ‘onnx::Gather‘ 报错

一、前言

最近在转 mobilenet v3 (pytorch -> onnx)的时候,遇见报错:

RuntimeError: Failed to export an ONNX attribute 'onnx::Gather', since it's not constant, please try to make things (e.g., kernel size) static if possible

网上搜了一下,发现要么很麻烦,要么不适用,看报错内容,大致就是说,有个op的属性值不是常量。

二、办法

经过思考,解决措施如下,是avg_pool2d的问题。

def forward(self, x):
    batch, channels, height, width = x.size()
    out = F.avg_pool2d(x, kernel_size=[height, width]).view(batch, -1)
    return out 

因为用avg_pool2d实现全局平均池化的效果(建议用这种实现方式,因为在onnx -> caffe中,nn.AdaptiveAvgPool2d()转换起来很麻烦),在运行下面这行代码的时候

# pytorch -> onnx 代码
torch.onnx.export(model, img, f, verbose=False, opset_version=11, input_names=['images'],
                      output_names=['output'])  # output_names=['classes', 'boxes']

heightwidth的类型会变成 torch.tensor, 使得转换报错,所以需要在转换前加上两行代码:

def forward(self, x):
        batch, channels, height, width = x.size()
        if torch.is_tensor(height):
        	height = height.item()  # 这里是修正代码
        	width = width.item()  # 这里是修正代码
        out = F.avg_pool2d(x, kernel_size=[height, width]).view(batch, -1)
        return out 

之后就转换成功了。

猜你喜欢

转载自blog.csdn.net/tangshopping/article/details/113336472