Eliminating TracerWarning in torch.jit.trace

When using torch.jit.trace, you often run into warnings like this one:

TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can’t record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

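Of course, these warnings may not cause any errors when the traced model is called from C++, so removing them is partly a matter of neatness. Still, each one means that a value was frozen into the trace as a constant, so the model can silently return wrong results for inputs unlike the example input (for instance, inputs of another shape).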

This post collects the fixes I have tried for some of these warnings:

1. Use tensor.shape / tensor.size() with care

1.1 Creating a new tensor

For example:

y=x.new(x.size())

produces the following warning:

TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect.

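It can be rewritten as:

y=torch.zeros_like(x)

Note that x.new(x.size()) returns uninitialized memory, so torch.empty_like(x) is the closest drop-in equivalent; torch.zeros_like(x) additionally fills the result with zeros.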
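A minimal sketch, assuming a recent PyTorch version, to check that the zeros_like version generalizes. The module below (ZerosLikeDemo is a hypothetical name) is traced with a 2x3 input and then called with a 4x5 input; because torch.zeros_like(x) is recorded as an op that takes x, the output shape follows the new input instead of being frozen at trace time.

```python
import torch


class ZerosLikeDemo(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        # zeros_like is recorded as an op on x, so its shape is not a constant
        y = torch.zeros_like(x)
        return y + x


traced = torch.jit.trace(ZerosLikeDemo(), torch.randn(2, 3))
out = traced(torch.randn(4, 5))
print(out.shape)  # torch.Size([4, 5])
```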

1.2 Inside if/while statements

An if/while condition sometimes needs information from tensor.shape. Writing:

if x.shape[0]:
    x*=2

triggers:

TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.

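It can be changed to the following (note that x.numel() is nonzero only when every dimension is nonzero, so the condition is not strictly identical to x.shape[0]):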

if x.numel():
    x*=2
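Getting rid of the warning does not make the `if` part of the trace, though: the tracer only records the operations executed for the example input, so whichever branch was taken at trace time is fixed. A minimal sketch (the function `scale` and the threshold 2 are hypothetical) comparing the traced function with plain eager execution on a smaller input:

```python
import warnings

import torch


def scale(x):  # hypothetical shape-dependent function
    # Python control flow: tracing evaluates this condition only once
    if x.shape[0] > 2:
        return x * 2
    return x * 0


with warnings.catch_warnings():
    # silence the TracerWarning this condition may trigger during tracing
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(scale, torch.ones(4))

small = torch.ones(1)
print(traced(small))  # tensor([2.]): the x * 2 branch was baked in
print(scale(small))   # tensor([0.]): eager Python re-evaluates the condition
```

For control flow that really has to depend on the input, torch.jit.script compiles the Python `if` instead of recording a single path.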

One thing I noticed: under torch.jit.trace, a single dimension taken from tensor.shape / tensor.size() is often returned as a tensor, for example:

print(x.size(0))
#tensor(8)

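whereas normally it is just a Python int. This is probably why torch.jit.trace reports warnings like these so often.

I am not sure whether this is a bug or whether I am using it wrong.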
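As far as I can tell this is intentional rather than a bug: the tracer hands sizes back as tensors so that shape arithmetic can be recorded in the graph, and the warning only fires when such a value is converted back into a Python bool, int, or index. A minimal sketch (the Flatten module is a hypothetical example) where x.size(0) is passed straight to view, so the traced module should still accept a different batch size:

```python
import torch


class Flatten(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        # x.size(0) goes straight into view, so the tracer can record it
        # instead of freezing the batch size seen during tracing
        return x.view(x.size(0), -1)


traced = torch.jit.trace(Flatten(), torch.randn(8, 3, 4))
print(traced(torch.randn(2, 3, 4)).shape)  # torch.Size([2, 12])
```

If the batch size had been frozen as 8, the same call would have produced shape (8, 3) instead.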

2. Comparing two single-element tensors

import torch


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.tensor(1)
        self.b = torch.tensor(2)

    def forward(self, x):
        # self.a != self.b is a tensor; `if` converts it to a Python bool
        if self.a != self.b:
            x = x * 2  # out-of-place, so the caller's input is not modified
        return x

triggers:

TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.

Use torch.equal() or tensor_A.equal(tensor_B) instead; both return a Python bool directly, so no tensor has to be converted to a bool:

import torch


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.tensor(1)
        self.b = torch.tensor(2)

    def forward(self, x):
        # torch.equal returns a Python bool (True only if shapes and values match);
        # `not` keeps the logic of the original `!=` check
        if not torch.equal(self.a, self.b):
            x = x * 2
        return x
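torch.equal removes the conversion, but the branch is still chosen once at trace time. If the two values can differ between calls and the decision has to live in the traced graph, one alternative (not from the original post) is to keep the comparison as a tensor and select with torch.where. A minimal sketch, assuming a and b are passed in as inputs rather than stored as attributes; ScaleIfDifferent is a hypothetical name:

```python
import torch


class ScaleIfDifferent(torch.nn.Module):  # hypothetical example module
    def forward(self, x, a, b):
        # a != b stays a bool tensor, and torch.where is recorded as a graph op,
        # so the choice is re-evaluated on every call of the traced module
        return torch.where(a != b, x * 2, x)


traced = torch.jit.trace(
    ScaleIfDifferent(), (torch.ones(3), torch.tensor(1), torch.tensor(2))
)
print(traced(torch.ones(3), torch.tensor(5), torch.tensor(5)))  # tensor([1., 1., 1.])
print(traced(torch.ones(3), torch.tensor(1), torch.tensor(2)))  # tensor([2., 2., 2.])
```

The cost is that both x * 2 and x are computed on every call; for cheap branches that is usually acceptable.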

3. Use strict=False

torch.jit.trace(model,input_imgs,strict=False)

If your model's output is not a tensor but a container such as a list, you can set strict=False to get rid of the warning.

From the documentation [1]:

strict (bool, optional) – run the tracer in a strict mode or not (default: True). Only turn this off when you want the tracer to record our mutable container types (currently list/dict) and you are sure that the container you are using in your problem is a constant structure and does not get used as control flow (if, for) conditions.
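A minimal sketch, assuming a recent PyTorch version, of a module whose output is a dict (TwoHeads and its keys are hypothetical). Its structure never depends on the input, which is the case the documentation describes as safe for strict=False:

```python
import torch


class TwoHeads(torch.nn.Module):  # hypothetical model with a dict output
    def __init__(self):
        super().__init__()
        self.cls = torch.nn.Linear(4, 2)
        self.reg = torch.nn.Linear(4, 1)

    def forward(self, x):
        # the keys are fixed, so the container structure never depends on x
        return {"cls": self.cls(x), "reg": self.reg(x)}


traced = torch.jit.trace(TwoHeads(), torch.randn(3, 4), strict=False)
out = traced(torch.randn(5, 4))
print(out["cls"].shape, out["reg"].shape)  # torch.Size([5, 2]) torch.Size([5, 1])
```

If named outputs are not needed, returning a tuple such as (self.cls(x), self.reg(x)) also works in strict mode, since tuples are not mutable containers.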

4. General guidelines

  • Avoid NumPy;
  • Keep variables as tensors wherever possible.
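A minimal sketch of why both guidelines matter, assuming a recent PyTorch version (the functions center_numpy and center_tensor are hypothetical). Anything computed through NumPy leaves the graph and is frozen as a constant, while the same computation done with tensor ops is recorded and re-run on each call:

```python
import warnings

import torch


def center_numpy(x):  # hypothetical: mean computed through NumPy
    # .numpy() leaves the graph, so the mean becomes a constant in the trace
    m = float(x.numpy().mean())
    return x - m


def center_tensor(x):  # the same computation kept as tensor ops
    return x - x.mean()


example = torch.ones(3)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # the NumPy version triggers a TracerWarning
    traced_np = torch.jit.trace(center_numpy, example)
traced_t = torch.jit.trace(center_tensor, example)

new_input = torch.zeros(3)
print(traced_np(new_input))  # tensor([-1., -1., -1.]): the example's mean was baked in
print(traced_t(new_input))   # tensor([0., 0., 0.])
```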

5. Other

Not thoroughly verified:

Use tensor.index_select.
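Presumably the benefit is that the index stays a tensor, so the selection is recorded as a graph op instead of a Python int baked into the trace. A minimal sketch of that idea (the function pick_rows is hypothetical), traced with one input and index and then reused with different ones:

```python
import torch


def pick_rows(x, idx):  # hypothetical helper
    # idx is a tensor input, so the tracer records index_select on it
    # rather than a fixed Python integer index
    return x.index_select(0, idx)


traced = torch.jit.trace(pick_rows, (torch.arange(6.0).view(3, 2), torch.tensor([0, 2])))
print(traced(torch.arange(8.0).view(4, 2), torch.tensor([3])))  # tensor([[6., 7.]])
```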

References

[1] PyTorch documentation, torch.jit.trace: https://pytorch.org/docs/stable/generated/torch.jit.trace.html

Reposted from: blog.csdn.net/WANGWUSHAN/article/details/118052523