pytorch モデルを onnx モデルに変換する際の問題の概要

1. よくあるエラー

1.1 変数がtensor型からint型に変わる

説明: TracerWarning: テンソルを Python 整数に変換すると、トレースが不正確になる可能性があります。Python 値のデータ フローを記録することはできません。

错误代码:
```python
x0 = torch.randn(1, 3, 640, 480)
h, w = x0.size()[-2:]   #{tensor:()} tenrsor(640)
paddingBottom = int(np.ceil(h/64)*64-h)  #{int} 0
paddingRight = int(np.ceil(w/64)*64-w)  # {int} 32
x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)  ##{tensor:(tenrsor(1),tenrsor(3),tenrsor(640),tenrsor(512))} 
正确代码
x0 = torch.randn(1, 3, 640, 480)
h, w = x0.size()[-2:]
paddingBottom = np.ceil(h/64)*64-h   # 修改第一处:去掉int,保持tensor类型
paddingRight = np.ceil(w/64)*64-w   # 修改第二处:去掉int,保持tensor类型
x0 = nn.ReplicationPad2d((0, paddingRight.numel(), 0, paddingBottom.numel()))(x0)  # 修改第三处:增加.numel()

参考解決策: torch.jit.traceコメントエリアにあるようにTracerWarning [1]を削除し、変数をtensor型のままにして、.numel()でint値を取得します。
ここに画像の説明を挿入します

1.2 警告:Floordivは非推奨であり、その動作は pytorch の将来のバージョンで変更される予定です。現在、0 に向かって丸められます (「floor」ではなく「trunc」関数と同様)。

モジュールから: einops/einops.py

inferred_length: int = length // known_product  #报错代码
# 修改为:
inferred_length: int = torch.div(length, known_product, rounding_mode='floor')  # 修改:引入 torch.div实现

参考ソリューション:ドキュメント[2] およびドキュメント[3]
ここに画像の説明を挿入します
ここに画像の説明を挿入します

1.3 einops.Rearrange は torch.transpose モジュールに置き換えられました

einops.Rearrange を採用:
TracerWarning: テンソルを Python ブール値に変換すると、トレースが不正確になる可能性があります。Python 値のデータ フローを記録することはできないため、この値は将来定数として扱われる予定です。これは、トレースが他の入力に一般化されない可能性があることを意味します。
既知: Set[str] = {composite_axis の軸の軸 if axis_name2known_length[axis] != _unknown_axis_length}

from einops.layers.torch import Rearrange

trans_x = Rearrange('b h w c -> b c h w')(trans_x)  #原始代码
----------------------修改为--------------------
trans_x = trans_x.transpose(3, 2).transpose(2, 1)  #修改代码,运行速度上有轻微的可忽略不计的提升
from einops import rearrange 

attn_mask0 = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)') #原始代码
----------------------修改为--------------------
w1, w2, p1, p2, p3, p4 = attn_mask.size()  #修改代码
attn_mask = attn_mask.reshape(1, 1, w1 * w2, p1 * p2, p3 * p4) #修改代码

参考文献

[1] torch.jit.trace で TracerWarning が解消されるhttps://blog.csdn.net/WANGWUSHAN/article/details/118052523
[2] Python 使用時に発生する問題(解決策を求める)https://zhuanlan.zhihu.com /p/562922076
[3] UserWarning: Floordiv は非推奨であり、その動作は pytorch の将来のバージョンで変更される予定です。https://blog.csdn.net/weixin_43564920/article/details/127004030

おすすめ

転載: blog.csdn.net/sanxiaw/article/details/132839440