Common issues when converting PyTorch models to ONNX

1. Common errors

1.1 A variable is converted from a tensor to an int

Error message: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values.

Code that triggers the warning:
```python
import numpy as np
import torch
import torch.nn as nn

x0 = torch.randn(1, 3, 640, 480)
h, w = x0.size()[-2:]  # while tracing, h and w are 0-dim tensors, e.g. tensor(640)
paddingBottom = int(np.ceil(h / 64) * 64 - h)  # int, 0
paddingRight = int(np.ceil(w / 64) * 64 - w)  # int, 32
x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)  # padded shape: (1, 3, 640, 512)
```

Corrected code:
```python
x0 = torch.randn(1, 3, 640, 480)
h, w = x0.size()[-2:]
paddingBottom = np.ceil(h / 64) * 64 - h  # change 1: drop int() so the value stays a tensor
paddingRight = np.ceil(w / 64) * 64 - w  # change 2: drop int() so the value stays a tensor
x0 = nn.ReplicationPad2d((0, paddingRight.numel(), 0, paddingBottom.numel()))(x0)  # change 3: add .numel()
```

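Reference fix: following the approach in the comment section of "torch.jit.trace 消除TracerWarning" [1], keep the variable as a tensor and call .numel() to obtain an int. Note that .numel() returns the number of elements in a tensor, which is 1 for a 0-dim tensor rather than the value it holds, so it is worth checking that the padded output still has the expected shape (1, 3, 640, 512).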
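The padding values quoted in the comments above (0 and 32) can be checked independently of tracing. The following is a minimal sketch using plain Python integers; `pad_to_multiple` is a hypothetical helper name introduced here for illustration, not part of PyTorch or the original post.

```python
def pad_to_multiple(size: int, multiple: int = 64) -> int:
    """Return how many elements must be appended so that `size`
    becomes a multiple of `multiple` (0 if it already is one)."""
    return (-size) % multiple


# Same spatial size as the example input torch.randn(1, 3, 640, 480).
h, w = 640, 480
padding_bottom = pad_to_multiple(h)  # 640 is already a multiple of 64
padding_right = pad_to_multiple(w)   # 480 needs padding up to 512

# Shape the padded tensor is expected to have.
padded_shape = (1, 3, h + padding_bottom, w + padding_right)
print(padding_bottom, padding_right, padded_shape)
```

Comparing the shape of the exported model's padded tensor against `padded_shape` is a quick way to confirm that a warning workaround has not changed the padding amounts.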

1.2 Warning: floordiv is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor').

Source module: einops/einops.py

```python
inferred_length: int = length // known_product  # line that raises the warning
# Changed to:
inferred_length: int = torch.div(length, known_product, rounding_mode='floor')  # use torch.div instead
```

Reference fix: see [2] and [3].
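The warning is about the difference between truncating and flooring, which correspond to the `'trunc'` and `'floor'` values of `rounding_mode` in `torch.div`. A minimal sketch of that difference in plain Python follows; `div_trunc` and `div_floor` are hypothetical helper names used only for illustration.

```python
import math


def div_trunc(a: int, b: int) -> int:
    """Divide and round toward zero (what 'trunc' means)."""
    return math.trunc(a / b)


def div_floor(a: int, b: int) -> int:
    """Divide and round toward negative infinity (what 'floor' means)."""
    return a // b


# For non-negative operands the two agree.
same_for_positive = div_trunc(7, 2) == div_floor(7, 2)

# For a negative quotient they differ by one.
trunc_negative = div_trunc(-7, 2)
floor_negative = div_floor(-7, 2)
print(same_for_positive, trunc_negative, floor_negative)
```

Axis lengths are non-negative, so both rounding modes should give the same `inferred_length` here; choosing one explicitly simply states the intent instead of relying on the deprecated behavior.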

1.3 Replacing einops.Rearrange with torch.transpose

Using einops.Rearrange produces the following warning:
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can’t record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
known: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length}

```python
from einops.layers.torch import Rearrange

trans_x = Rearrange('b h w c -> b c h w')(trans_x)  # original code
# ---------------------- changed to ----------------------
trans_x = trans_x.transpose(3, 2).transpose(2, 1)  # modified code; gives a slight, negligible speed-up
```

```python
from einops import rearrange

attn_mask0 = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')  # original code
# ---------------------- changed to ----------------------
w1, w2, p1, p2, p3, p4 = attn_mask.size()  # modified code
attn_mask = attn_mask.reshape(1, 1, w1 * w2, p1 * p2, p3 * p4)  # modified code
```
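Both replacements can be sanity-checked without tensors by tracking axis names and sizes only. The sketch below is plain Python under that assumption; `apply_transpose` and `merged_mask_shape` are hypothetical helpers written for this check, not einops or PyTorch functions.

```python
def apply_transpose(axes, dim0, dim1):
    """Return a copy of `axes` with positions dim0 and dim1 swapped,
    mirroring what Tensor.transpose(dim0, dim1) does to the axis order."""
    swapped = list(axes)
    swapped[dim0], swapped[dim1] = swapped[dim1], swapped[dim0]
    return swapped


# 'b h w c -> b c h w' expressed as two successive transposes.
axes = ["b", "h", "w", "c"]
after_first = apply_transpose(axes, 3, 2)          # b h c w
after_second = apply_transpose(after_first, 2, 1)  # b c h w


def merged_mask_shape(w1, w2, p1, p2, p3, p4):
    """Output shape of 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)'."""
    return (1, 1, w1 * w2, p1 * p2, p3 * p4)


# Arbitrary example sizes, chosen only for illustration.
example_shape = merged_mask_shape(2, 3, 4, 5, 6, 7)
print(after_second, example_shape)
```

Because the pattern only merges adjacent axes and never reorders them, a plain `reshape` yields the same element layout as the `rearrange` call. The two transposes could also be written as the single call `trans_x.permute(0, 3, 1, 2)`.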

References

[1] torch.jit.trace: eliminating TracerWarning (torch.jit.trace 消除TracerWarning), https://blog.csdn.net/WANGWUSHAN/article/details/118052523
[2] Problems encountered when using Python (help wanted) (python使用时出现的问题(求解决)), https://zhuanlan.zhihu.com/p/562922076
[3] UserWarning: floordiv is deprecated, and its behavior will change in a future version of pytorch, https://blog.csdn.net/weixin_43564920/article/details/127004030

Reposted from blog.csdn.net/sanxiaw/article/details/132839440