转换Onnx过程中
1 pytorch转onnx模型多输入问题(如:Bert)
Bert模型有三个输入,因此就要创建三个dummy_input,然后利用一个tuple,传入函数中。
dummy_input0 = torch.LongTensor(Batch_size, seg_length).to(torch.device("cuda"))
dummy_input1 = torch.LongTensor(Batch_size, seg_length).to(torch.device("cuda"))
dummy_input2 = torch.LongTensor(Batch_size, seg_length).to(torch.device("cuda"))
torch.onnx.export(model. (dummy_input0, dummy_input1, dummy_input2), filepath)
2
PyTorch v1.0.1 Reshape不支持报错 [Solution]
PyTorch v1.2.0 需要升级cuda10.0以上
3 像data[index] = new_data这样的张量就地索引分配目前在导出中不受支持。解决这类问题的一种方法是使用算子散点,显式地更新原始张量。
就是像tensorflow的静态图,不能随便改变tensor的值,可以用torch的scatter_方法解决
错误的方式
# def forward(self, data, index, new_data):
# data[index] = new_data # 重新赋值
# return data
正确的方式
class InPlaceIndexedAssignmentONNX(torch.nn.Module):
def forward(self, data, index, new_data):
new_data = new_data.unsqueeze(0)
index = index.expand(1, new_data.size(1))
data.scatter_(0, index, new_data)
return data
4. ONNX export failed on ATen operator group_norm because torch.onnx.symbolic.group_norm does not exist
解决:~/anaconda3/envs/py36/lib/python3.6/site-packages/torch/onnx/symbolic.py
@parse_args('v', 'i', 'v', 'v', 'f', 'i')
def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
return g.op("ATen", input, weight, bias, num_groups_i=num_groups,
eps_f=eps, cudnn_enabled_i=cudnn_enabled, operator_s="group_norm")
5 ‘RuntimeError: ONNX export failed: Couldn’t export operator aten::upsample_bilinear2d’
解决方案
略略略
6 RuntimeError: ONNX export failed: Couldn’t export operator aten::avg_pool2d
#Error
self.global_average = nn.AdaptiveAvgPool2d((1,1))//就是这一行的问题是用的AdaptiveAvgPool2d
#OK
self.global_average = nn.AvgPool2d(kernel_size = (7,7),stride=(7,7),ceil_mode=False)
以后遇到别人代码使用Adaptive Pooling,可以通过这两个公式转换为标准的Max/AvgPooling:
#只需要知道输入的input_size ,就可以推导出stride 与kernel_size ,从而替换为标准的Max/AvgPooling
stride = floor ( (input_size / (output_size) )
kernel_size = input_size − (output_size−1) * stride
padding = 0
PyTorch–>ONNX
这一部分比较简单,大致照着PyTorch官网的例程走即可。
pytorch导出onnx并检查
onnx_model = onnx.load(output_onnx)
onnx.checker.check_model(onnx_model) # assuming throw on error
print("==> Passed")
[Solution]
解决:~/anaconda3/envs/py36/lib/python3.6/site-packages/torch/onnx/symbolic.py
在该文件中添加代码
def reshape(g, self, shape):
return view(g, self, shape)
def reshape_as(g, self, other):
shape = g.op('Shape', other)
return reshape(g, self, shape)
Reference
4 RuntimeError: ONNX export failed: Couldn't export operator aten::avg_pool2d