Tricks for silencing torch.jit.trace TracerWarnings
When using torch.jit.trace, you will often run into warnings like this one:
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
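Admittedly, these warnings may not cause any errors when the traced model is later called from C++; silencing them is partly just a matter of tidiness.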
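Before muting anything, it helps to see what the warning means in practice. The sketch below (the function name is hypothetical) traces a function whose `if` depends on tensor values: the condition is converted to a Python bool once, at trace time, and only the branch taken for the example input ends up in the graph.

```python
import torch


def double_or_negate(x):
    # bool(x.sum() > 0) is evaluated once during tracing -> TracerWarning
    if x.sum() > 0:
        return x * 2
    return x * -1


# The example input is positive, so only the `x * 2` branch is recorded
traced = torch.jit.trace(double_or_negate, torch.ones(3))

neg = -torch.ones(3)
print(double_or_negate(neg))  # eager:  tensor([1., 1., 1.])
print(traced(neg))            # traced: tensor([-2., -2., -2.])
```

So this warning marks a place where the traced model can silently diverge from the Python code on other inputs.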
This post collects the workarounds I have tried for some of these warnings:
1. Use tensor.shape / tensor.size() with care
1.1 Creating a new tensor
For example:

```python
y = x.new(x.size())
```

produces the following warning:
TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect.
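It can be changed to:

```python
y = torch.zeros_like(x)
```

Strictly speaking, x.new(x.size()) returns uninitialized memory, so torch.empty_like(x) is the closer drop-in replacement when y is fully overwritten afterwards; torch.zeros_like(x) additionally fills it with zeros.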
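A quick way to check that the replacement generalizes is to trace with one shape and call with another. A minimal sketch (fresh_like is a hypothetical name): because torch.zeros_like(x) takes x itself as its input, the tracer records an aten::zeros_like node, so the new tensor follows x's shape at run time rather than a shape frozen from the example.

```python
import torch


def fresh_like(x):
    # Recorded as aten::zeros_like(x): shape, dtype and device follow x
    y = torch.zeros_like(x)
    return y + x


traced = torch.jit.trace(fresh_like, torch.ones(2, 3))

out = traced(torch.ones(4, 5))  # a different shape than the example input
print(out.shape)                # torch.Size([4, 5])
print(traced.graph)             # contains an aten::zeros_like node
```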
1.2 Inside if/while statements
Inside an if/while statement you sometimes need the information in tensor.shape. If you write:

```python
if x.shape[0]:
    x *= 2
```

you get:
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.
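It can be changed to:

```python
if x.numel():
    x *= 2
```

Two caveats: x.numel() is zero whenever any dimension is zero, so it only matches x.shape[0] when the remaining dimensions are non-zero; and in both versions the condition is evaluated once at trace time, so only the branch taken for the example input is recorded.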
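If the shape check has to stay dynamic rather than merely quiet, one option is the pattern the TorchScript documentation describes for mixing tracing and scripting: move the branch into a function compiled with torch.jit.script and call it from the traced code, where its control flow is preserved. A sketch, in which maybe_double and its threshold of 2 are hypothetical:

```python
import torch


@torch.jit.script
def maybe_double(x: torch.Tensor) -> torch.Tensor:
    # Compiled by the script compiler, so this if stays in the graph
    if x.size(0) > 2:
        return x * 2
    return x


def f(x):
    # The traced code simply calls the scripted function
    return maybe_double(x) + 1


traced = torch.jit.trace(f, torch.ones(3))

print(traced(torch.ones(3)))  # tensor([3., 3., 3.])
print(traced(torch.ones(1)))  # tensor([2.])
```

The trade-off is that the scripted function has to be written in the TorchScript subset of Python (type annotations, no arbitrary Python objects).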
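One thing I noticed: under torch.jit.trace, a single dimension taken from tensor.shape / tensor.size() is often returned as a tensor, e.g.:

```python
print(x.size(0))
# tensor(8)
```

whereas normally it is just a Python int. This is probably the key reason torch.jit.trace reports this kind of warning so often.
I am not yet sure whether this is a bug or whether I am using it wrong.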
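As far as I can tell, this is intentional rather than a bug: returning sizes as tensors lets the tracer record shape arithmetic in the graph instead of baking in the example's numbers, and it is converting such a value back into a Python bool or index that triggers the warnings above. A sketch of the upside, using the common flatten pattern (the function name is hypothetical):

```python
import torch


def flatten(x):
    # x.size(0) is recorded as a size op, so the batch dimension is not fixed to 2
    return x.view(x.size(0), -1)


traced = torch.jit.trace(flatten, torch.randn(2, 3, 4))

out = traced(torch.randn(5, 3, 4))  # a different batch size
print(out.shape)                    # torch.Size([5, 12])
```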
2. Comparing two single-element tensors
```python
import torch
import torch.nn


class MyModule(torch.nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.a = torch.tensor(1)
        self.b = torch.tensor(2)

    def forward(self, x):
        # self.a != self.b is a tensor; `if` converts it to a Python bool
        if self.a != self.b:
            x *= 2
        return x
```

Tracing it gives:
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.
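Use torch.equal() or tensor_A.equal(tensor_B) here instead. Mind the negation: the original branch runs when the two values differ, while equal() returns True when they match.

```python
import torch
import torch.nn


class MyModule(torch.nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.a = torch.tensor(1)
        self.b = torch.tensor(2)

    def forward(self, x):
        # `not ...equal(...)` mirrors the original `!=` condition
        if not self.a.equal(self.b):
            x *= 2
        return x
```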
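Keep in mind that torch.equal() returns a plain Python bool, so the tracer does not record the comparison at all: the warning goes away, but the branch is still decided once at trace time, which is fine only while a and b are fixed. If the values can change between calls (for example, because they are inputs), one alternative, swapped in here rather than taken from the approach above, is torch.where, which keeps both the comparison and the selection in the graph. A sketch with hypothetical names:

```python
import torch


def double_if_different(x, a, b):
    # a != b is a 0-dim bool tensor; torch.where broadcasts it over x,
    # so both the comparison and the choice are recorded in the graph
    return torch.where(a != b, x * 2, x)


example = (torch.ones(3), torch.tensor(1), torch.tensor(2))
traced = torch.jit.trace(double_if_different, example)

print(traced(torch.ones(3), torch.tensor(1), torch.tensor(2)))  # tensor([2., 2., 2.])
print(traced(torch.ones(3), torch.tensor(1), torch.tensor(1)))  # tensor([1., 1., 1.])
```

The cost is that x * 2 is computed on every call, whichever value ends up selected.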
3. Use strict=False
```python
torch.jit.trace(model, input_imgs, strict=False)
```

If your model returns a container such as a list instead of a tensor, you can set strict=False to get rid of the warning.
The PyTorch documentation describes the parameter as follows:
strict (bool, optional) – run the tracer in a strict mode or not (default: True). Only turn this off when you want the tracer to record your mutable container types (currently list/dict) and you are sure that the container you are using in your problem is a constant structure and does not get used as control flow (if, for) conditions.
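A minimal sketch of the case described above, assuming a hypothetical model that always returns a list of exactly two tensors, i.e. a constant structure that is never used as an if/for condition:

```python
import torch


class SplitHalves(torch.nn.Module):
    # Hypothetical model whose output is a list rather than a tensor
    def forward(self, x):
        # Always exactly two entries, never used in control flow
        return [x[:, :2], x[:, 2:]]


model = SplitHalves()
traced = torch.jit.trace(model, torch.randn(1, 4), strict=False)

out = traced(torch.randn(3, 4))  # batch size differs from the example input
print(len(out), out[0].shape, out[1].shape)
```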
4. General approaches
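- Avoid numpy: numpy calls are invisible to the tracer, so anything computed with them is recorded as a constant;
- Keep values in tensor form wherever possible.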
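The difference between the two bullets can be sketched by tracing the same computation twice, once through numpy and once with tensor ops (both function names are hypothetical):

```python
import torch


def scale_numpy(x):
    # x.numpy() leaves the traced graph: the result is stored as a constant
    return torch.from_numpy(x.numpy() * 2)


def scale_torch(x):
    # Pure tensor op: recorded as aten::mul and re-evaluated on every call
    return x * 2


example = torch.ones(3)
traced_np = torch.jit.trace(scale_numpy, example)  # emits a TracerWarning
traced_pt = torch.jit.trace(scale_torch, example)

new_input = torch.full((3,), 5.0)
print(traced_np(new_input))  # tensor([2., 2., 2.]), frozen from the example input
print(traced_pt(new_input))  # tensor([10., 10., 10.])
```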
5. Other
Not fully verified:
Using tensor.index_select.
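I have not verified this thoroughly, but the idea can be sketched as follows, assuming the indices are kept as a tensor: with torch.index_select the index tensor is a graph input rather than a Python value, so the traced function can be called with different indices (pick_rows is a hypothetical name).

```python
import torch


def pick_rows(x, idx):
    # idx stays a tensor, so the indices are an input to the graph, not constants
    return torch.index_select(x, 0, idx)


traced = torch.jit.trace(pick_rows, (torch.arange(10.0), torch.tensor([0, 2])))

print(traced(torch.arange(10.0), torch.tensor([5, 7, 9])))  # tensor([5., 7., 9.])
```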