pytorch JIT浅解析

概要
  Torch Script中的核心数据结构是ScriptModule。 它是Torch的nn.Module的类似物,代表整个模型作为子模块树。 与普通模块一样,ScriptModule中的每个单独模块都可以包含子模块,参数和方法。 在nn.Modules中,方法是作为Python函数实现的,但在ScriptModules方法中通常实现为Torch Script函数,这是一个静态类型的Python子集,包含PyTorch的所有内置Tensor操作。 这种差异允许您运行ScriptModules代码而无需Python解释器。

ScriptModules和Torch Script函数可以通过两种方式创建:
Tracing:
  使用torch.jit.trace,您可以获取现有模块或python函数,提供示例输入,然后运行该函数,记录在所有张量上执行的操作。 我们将生成的记录转换为Torch Script方法,该方法作为ScriptModule的正向方法安装。 该模块还包含原始模块所具有的任何参数。
Example:

import torch
def foo(x, y):
return 2*x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
1
2
3
4
注意:

  由于跟踪仅记录张量上的操作,因此它不会记录任何控制流操作,如if语句或循环。 当这个控制流在你的模块中保持不变时,这很好,它通常只是内联配置决策。 但有时控制流实际上是模型本身的一部分。 例如,序列到序列转换中的波束搜索是输入的(变化的)序列长度上的循环。 在这种情况下,跟踪不合适,并且应使用脚本编写波束搜索。


Scripting:
  您可以使用Python语法直接编写Torch Script代码。 您可以在ScriptModule的子类上使用torch.jit.script批注(对于函数)或torch.jit.script_method批注(对于方法)来执行此操作。 使用此注释,注释函数的主体将直接转换为Torch脚本。 Torch脚本本身是Python语言的一个子集,因此并非python中的所有功能都可以工作,但我们提供了足够的功能来计算张量并执行与控制相关的操作。
实例:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import ScriptModule, script_method, trace

class MyScriptModule(ScriptModule):
def __init__(self):
super(MyScriptModule, self).__init__()
# trace produces a ScriptModule's conv1 and conv2
self.conv1 = trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
self.conv2 = trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

@script_method
def forward(self, input):
input = F.relu(self.conv1(input))
input = F.relu(self.conv2(input))
return input
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
  用于将JIT模式PyTorch程序转换为Torch脚本的API可在torch.jit模块中找到。该模块有两种核心模式,用于将JIT模式模型转换为Torch Script图形表示:Tracing:和Scripting:。
  torch.jit.trace函数接受一个模块或函数以及一组示例输入。然后,它在跟踪遇到的计算步骤时通过函数或模块运行示例输入,并输出执行Tracing操作的基于图形的函数。Tracing非常适用于不涉及数据相关控制流的简单模块和功能,例如标准卷积神经网络。但是,如果Tracing具有依赖于数据的if语句和循环的函数,则仅记录由示例输入执行的执行路径调用的操作。换句话说,不捕获控制流本身。 为了转换包含依赖于数据的控制流的模块和函数,提供了一种 Script机制。
  Script显式将模块或功能代码转换为Torch Script,包括所有可能的控制流路径。 要使用脚本模式,请确保从torch.jit.ScriptModule基类(而不是torch.nn.Module)继承,并将torch.jit.script装饰器添加到Python函数或torch.jit.script_method装饰器中。你的模块的方法。使用脚本的一个警告是它只支持Python的受限子集。下面会描述当前pytorch JIT支持的功能的所有详细信息。为了提供最大的灵活性,可以组合Torch脚本的模式来表示整个程序,并且可以逐步应用这些技术。

TORCH SCRIPT LANGUAGE REFERENCE
Torch Script是Python的一个子集,可以直接编写(使用@script注释),也可以通过跟踪从Python代码自动生成。 使用跟踪时,代码会自动转换为Python的这个子集,方法是仅记录张量上的实际运算符,并简单地执行和丢弃其他周围的Python代码。

使用@script注释直接编写Torch脚本时,程序员必须只使用Torch脚本支持的Python子集。 本节介绍了Torch Script支持的内容,就好像它是独立语言的语言参考一样。 本参考中未提及的Python的任何功能都不是Torch脚本的一部分。

作为Python的一个子集,任何有效的Torch Script函数也是一个有效的Python函数。 这样就可以删除@script注释并使用标准Python工具(如pdb)调试函数。 反之亦然:有许多有效的python程序不是有效的Torch Script程序。 相反,Torch Script专注于在Torch中表示神经网络模型所需的Python特性。

PYTORCH_JIT= 1
设置环境变量PYTORCH_JIT = 0将禁用所有脚本和跟踪注释。 如果其中一个ScriptModule中存在难以调试的错误,则可以使用此标志强制所有内容都使用本机Python运行。 这允许使用像pdb这样的工具来调试代码。

1.Types,支持的类型
Torch Script与完整Python语言之间的最大区别在于Torch Script仅支持表达神经网络模型所需的一小部分类型。 特别是Torch Script支持:

Tensor
  任何dtype,dimension或backend的PyTorch Tensor。

Tuple[T0, T1, …]
  包含子类型T0,T1等的元组(例如Tuple[Tensor,Tensor])

int
  int标量

float
  float 标量

List[T]
  所有成员都是T类的列表与Python不同,Torch Script函数中的每个变量都必须具有单个静态类型。 这样可以更轻松地优化Torch Script功能。
Example,下面这种情况应该避免,返回类型不一致:

@torch.jit.script
def an_error(x):
if x:
r = torch.rand(1)
else:
r = 4
return r # Type mismatch: r is set to type Tensor in the true branch
# and type int in the false branch
1
2
3
4
5
6
7
8
默认情况下,假定Torch脚本函数的所有参数都是Tensor,因为这是模块中最常用的类型。 要指定Torch脚本函数的参数是另一种类型,可以使用上面列出的类型使用MyPy样式类型注释:

@torch.jit.script
def foo(x, tup):
# type: (int, Tuple[Tensor, Tensor]) -> Tensor
t0, t1 = tup
return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))
1
2
3
4
5
6
7
Tips:也可以使用Python 3类型注释来注释类型。 在我们的示例中,我们使用基于注释的注释来确保Python 2的兼容性。
1
2.Expressions,表示
支持以下Python表达式

Literals,常量:
True, False, None, 'string literals', "string literals", number literals 3 (interpreted as int) 3.4 (interpreter as a float)
1
Variables,变量:
a

Variable Resolution,变量分辨能力
  Torch Script支持Python的可变分辨率(即范围)规则的子集。 局部变量的行为与Python中的相同,除了变量必须在函数的所有路径中具有相同类型的限制。 如果变量在if语句的不同侧具有不同的类型,则在if语句结束后使用它是错误的。

类似地,如果仅沿着函数的某些路径定义变量,则不允许使用该变量。

@torch.jit.script
def foo(x):
if x < 0:
y = 4
print(y) # Error: undefined value y
1
2
3
4
5
定义函数时,非局部变量在编译时解析为Python值。 然后,使用“使用Python值”中描述的规则将这些值转换为Torch Script值。

Tuple Construction
(3, 4), (3,)

List Construction
[3, 4], [], [torch.rand(3), torch.rand(4)]

假设空列表具有类型List [Tensor]。 其他列表文字的类型是从成员的类型派生的。

Arithmetic Operators
a + b a - b a * b a / b a ^ b a @ b

Comparison Operators
a == b a != b a < b a > b a <= b a >= b

Logical Operators
a and b a or b not b

Subscripts
t[0] t[-1] t[0:2] t[1:] t[:1] t[:] t[0, 1] t[0, 1:2] t[0, :1] t[-1, 1:, 0] t[1:, -1, 0] t[i:j, i]

Torch Script目前不支持变异张量,因此任何张量索引只能出现在表达式的右侧size上。

Function calls
调用内置函数: torch.rand(3, dtype=torch.int)
调用其他script函数:

import torch

@torch.jit.script
def foo(x):
return x + 1

@torch.jit.script
def bar(x):
return foo(x)
1
2
3
4
5
6
7
8
9
Method calls
调用内置类型的方法,如Tensor:x.mm(y)

在ScriptModule中定义Script方法时,使用@script_method批注。 在这些方法中,可以调用此类的其他方法或访问子模块上的方法。

直接调用子模块(例如self.resnet(输入))等同于调用其正向方法(例如self.resnet.forward(input))

import torch

class MyScriptModule(torch.jit.ScriptModule):
def __init__(self):
super(MyScriptModule, self).__init__()
self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
.resize_(1, 3, 1, 1))
self.resnet = torch.jit.trace(torchvision.models.resnet18(),
torch.rand(1, 3, 224, 224))

@torch.jit.script_method
def helper(self, input):
return self.resnet(input - self.means)

@torch.jit.script_method
def forward(self, input):
return self.helper(input)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
If expressions
x if x > y else y

Casts
float(ten), int(3.5), bool(ten)

Accessing Module Parameters
self.my_parameter self.my_submodule.my_parameter

3.Statements
Torch Script支持以下类型的语句:

Simple Assignments 简单的赋值
a = b
a += b # short-hand for a = a + b, does not operate in-place on a
a -= b

Pattern Matching Assignments
a, b = tuple_or_list
a, b, *c = a_tuple

Print Statements
print(“the result of an add:”, a + b)

If Statements
if a < 4:
r = -a
elif a < 3:
r = a + a
else:
r = 3 * a
1
2
3
4
5
6
While Loops
a = 0
while a < 4:
print(a)
a += 1
1
2
3
4
For loops with range
x = 0
for i in range(10):
x *= i
1
2
3
NOTE:Script当前不支持迭代通用可迭代对象,如list或tensor。 脚本当前不支持启动或增加范围的参数。 这些将在未来版本中添加。

For loops over tuples:
tup = (3, torch.rand(4))
for x in tup:
print(x)
1
2
3
Note:对于tuples的循环将展开循环,为tuples的每个成员生成一个主体。 正文必须为每个成员正确地进行类型检查。

For loops over constant torch.nn.ModuleList
class SubModule(torch.jit.ScriptModule):
def __init__(self):
super(Sub, self).__init__()
self.weight = nn.Parameter(torch.randn(2))

@torch.jit.script_method
def forward(self, input):
return self.weight + input

class MyModule(torch.jit.ScriptModule):
__constants__ = ['mods']

def __init__(self):
super(MyModule, self).__init__()
self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])

@torch.jit.script_method
def forward(self, v):
for module in self.mods:
v = m(v)
return v
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
要在@script_method中使用ModuleList,必须通过将属性的名称添加到该类型的__constants__列表来将其标记为常量。 对于ModuleList上的循环,将在编译时使用常量模块列表的每个成员展开循环体。

Return
return a, b

Note:必须有一个return语句作为函数的最后一个成员,并且return语句不能出现在函数的任何其他位置。 此限制将在以后删除。

4.Debugging
Disable JIT for Debugging
如果要禁用所有JIT模式(跟踪和脚本),以便可以在原始Python中调试程序,则可以使用PYTORCH_JIT环境变量。 PYTORCH_JIT可以通过将其值设置为0来全局禁用JIT。给出一个示例脚本:

@torch.jit.script
def scripted_fn(x : torch.Tensor):
for i in range(12):
x = x + x
return x


def fn(x):
x = torch.neg(x)
import pdb; pdb.set_trace()
return scripted_fn(x)

traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))

traced_fn(torch.rand(3, 4))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
除了调用@script函数之外,使用PDB调试此脚本的工作原理除外。 我们可以全局禁用JIT,这样我们就可以将@script函数作为普通的python函数调用而不是编译它。 如果上面的脚本名为disable_jit_example.py,我们可以像这样调用它:

$ PYTORCH_JIT=0 python disable_jit_example.py
1
我们将能够作为普通的Python函数进入@script函数。

Interpreting Graphs,解释图表
TorchScript使用静态单一赋值(SSA)中间表示(IR)来表示计算。 这种格式的指令包括ATen(PyTorch的C ++后端)运算符和其他原始运算符,包括循环和条件的控制流运算符。 举个例子:

@torch.jit.script
def foo(len):
# type: (int) -> torch.Tensor
rv = torch.zeros(3, 4)
for i in range(len):
if i < 10:
rv = rv - 1.0
else:
rv = rv + 1.0
return rv

print(foo.graph)
1
2
3
4
5
6
7
8
9
10
11
12
具有单个forward方法的ScriptModule将具有属性图,您可以使用该图来检查表示计算的IR。 如果ScriptModule有多个方法,则需要访问方法本身的.graph而不是模块。 我们可以通过访问.bar.graph来检查ScriptModule上名为bar的方法的图形。
上面的示例脚本生成图形:

graph(%len : int) {
%13 : float = prim::Constant[value=1]()
%10 : int = prim::Constant[value=10]()
%2 : int = prim::Constant[value=4]()
%1 : int = prim::Constant[value=3]()
%3 : int[] = prim::ListConstruct(%1, %2)
%4 : int = prim::Constant[value=6]()
%5 : int = prim::Constant[value=0]()
%6 : int[] = prim::Constant[value=[0, -1]]()
%rv.1 : Dynamic = aten::zeros(%3, %4, %5, %6)
%8 : int = prim::Constant[value=1]()
%rv : Dynamic = prim::Loop(%len, %8, %rv.1)
block0(%i : int, %12 : Dynamic) {
%11 : int = aten::lt(%i, %10)
%rv.4 : Dynamic = prim::If(%11)
block0() {
%14 : int = prim::Constant[value=1]()
%rv.2 : Dynamic = aten::sub(%12, %13, %14)
-> (%rv.2)
}
block1() {
%16 : int = prim::Constant[value=1]()
%rv.3 : Dynamic = aten::add(%12, %13, %16)
-> (%rv.3)
}
%19 : int = prim::Constant[value=1]()
-> (%19, %rv.4)
}
return (%rv);
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
以指令%rv.1:Dynamic = aten :: zeros(%3,%4,%5,%6)为例。 %rv.1:动态意味着我们将输出分配给名为rv.1的(唯一)值,并且该值是动态类型,即我们不知道其具体形状。 aten :: zeros是运算符(相当于torch.zeros),输入列表(%3,%4,%5,%6)指定范围中的哪些值应作为输入传递。 内置函数(如aten :: zeros)的模式可以在Builtin Functions中找到。

Builtin Functions:
  Torch Script支持PyTorch提供的内置张量和神经网络函数的子集。 Tensor上的大多数方法以及torch命名空间中的函数都可用。 torch.nn.functional中的许多功能也是可用的。
  我们目前不提供任何内置的ScriptModule,例如Linear或Conv模块。 此功能将在未来开发。 目前我们建议使用torch.jit.trace将标准的torch.nn模块转换为构造中的ScriptModules。

请注意,运算符也可以有关联的块,即prim :: Loop和prim :: If运算符。 在图形打印输出中,这些运算符被格式化以反映其等效的源代码形式,以便于调试。
可以如图所示检查图形以确认ScriptModule描述的计算以自动和手动方式是正确的,如下所述。

Tracing Edge Cases
存在一些边缘情况,其中给定Python函数/模块的跟踪将不代表底层代码。 这些案件可包括:

Tracing依赖于输入的控制流(例如tensor的shapes)
Tracing Tensor视图的就地操作(例如,在左侧索引的赋值)
请注意,这些情况实际上可能在将来可被Trace。
Automatic Trace Checking
自动捕获跟踪中的许多错误的一种方法是使用torch.jit.trace()API上的check_inputs。 check_inputs获取一系列输入元组列表,这些元组将用于重新跟踪计算并验证结果。 例如:

def loop_in_traced_fn(x):
result = x[0]
for i in range(x.size(0)):
result = result * x[i]
return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)
1
2
3
4
5
6
7
8
9
10
提供以下诊断信息:

ERROR: Graphs differed across invocations!
Graph diff:
graph(%0 : Dynamic) {
%1 : int = prim::Constant[value=0]()
%2 : int = prim::Constant[value=0]()
%3 : Dynamic = aten::select(%0, %1, %2)
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=0]()
%6 : Dynamic = aten::select(%0, %4, %5)
%7 : Dynamic = aten::mul(%3, %6)
%8 : int = prim::Constant[value=0]()
%9 : int = prim::Constant[value=1]()
%10 : Dynamic = aten::select(%0, %8, %9)
%11 : Dynamic = aten::mul(%7, %10)
%12 : int = prim::Constant[value=0]()
%13 : int = prim::Constant[value=2]()
%14 : Dynamic = aten::select(%0, %12, %13)
%15 : Dynamic = aten::mul(%11, %14)
+ %16 : int = prim::Constant[value=0]()
+ %17 : int = prim::Constant[value=3]()
+ %18 : Dynamic = aten::select(%0, %16, %17)
+ %19 : Dynamic = aten::mul(%15, %18)
- return (%15);
? ^
+ return (%19);
? ^
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
此消息向我们表明,在我们第一次跟踪它和使用check_inputs跟踪它时,计算之间存在差异。 实际上,loop_in_traced_fn体内的循环取决于输入x的形状,因此当我们尝试另一个具有不同形状的x时,迹线会有所不同。

在这种情况下,可以使用脚本来捕获这样的数据相关控制流:

def fn(x):
result = x[0]
for i in range(x.size(0)):
result = result * x[i]
return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)

for input_tuple in [inputs] + check_inputs:
torch.testing.assert_allclose(fn(*input_tuple), scripted_fn(*input_tuple))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
那就会产生:

graph(%x : Dynamic) {
%1 : int = prim::Constant[value=0]()
%2 : int = prim::Constant[value=0]()
%result.1 : Dynamic = aten::select(%x, %2, %1)
%4 : int = aten::size(%x, %1)
%5 : int = prim::Constant[value=1]()
%result : Dynamic = prim::Loop(%4, %5, %result.1)
block0(%i : int, %7 : Dynamic) {
%9 : int = prim::Constant[value=0]()
%10 : Dynamic = aten::select(%x, %9, %i)
%result.2 : Dynamic = aten::mul(%7, %10)
%12 : int = prim::Constant[value=1]()
-> (%12, %result.2)
}
return (%result);
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
Tracer Warnings
跟踪器在跟踪计算中为几个有问题的模式生成警告。 例如,在Tensor的切片(视图)上跟踪包含就地赋值的函数:

def fill_row_zero(x):
x[0] = torch.rand(*x.shape[1:2])
return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
1
2
3
4
5
6
生成几个警告和一个只返回输入的图表:

fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
x[0] = torch.rand(*x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
return (%0);
}
1
2
3
4
5
6
7
8
我们可以通过修改代码以不使用就地更新来修复此问题,而是使用torch.cat构建结果张量:

def fill_row_zero(x):
x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0)
return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
---------------------
作者:丶Shining
来源:CSDN
原文:https://blog.csdn.net/xxradon/article/details/86504906
版权声明:本文为博主原创文章,转载请附上博文链接!

猜你喜欢

转载自www.cnblogs.com/jfdwd/p/11232802.html