This tutorial is an introduction to TorchScript, an intermediate representation of a PyTorch model (a subclass of nn.Module) that can be run in a high-performance environment such as C++.
In this tutorial, we'll cover:
- The basics of model authoring in PyTorch, including:
  - Modules
  - Defining forward functions
  - Composing modules into a hierarchy of modules
- Specific methods for converting PyTorch modules to TorchScript, our high-performance deployment runtime:
  - Tracing an existing module
  - Using scripting to compile a module directly
  - How to compose both approaches
  - Saving and loading TorchScript modules
We hope that after completing this tutorial you will go on to the follow-up tutorial, which walks through an example of actually calling a TorchScript model from C++.
import torch  # This is all you need to use both PyTorch and TorchScript!
print(torch.__version__)
- Output
1.3.0
1. Basics of PyTorch model authoring
Let's start by defining a simple Module. A Module is the basic unit of composition in PyTorch. It contains:
- A constructor, which prepares the module for invocation.
- A set of parameters and submodules. These are initialized by the constructor and can be used by the module during invocation.
- A forward function. This is the code that runs when the module is invoked.
Let's look at a small example:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()

    def forward(self, x, h):
        # Add the input to the previous hidden state, then squash with tanh.
        new_h = torch.tanh(x + h)
        return new_h, new_h

my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))
- Output
(tensor([[0.5139, 0.6451, 0.3697, 0.7738],
[0.7936, 0.5864, 0.8063, 0.9324],
[0.6479, 0.8408, 0.8062, 0.7263]]), tensor([[0.5139, 0.6451, 0.3697, 0.7738],
[0.7936, 0.5864, 0.8063, 0.9324],
[0.6479, 0.8408, 0.8062, 0.7263]]))
So we have:
- Created a class that subclasses torch.nn.Module.
- Defined a constructor. The constructor does not do much; it only calls the constructor of super.
- Defined a forward function, which takes two inputs and returns two outputs. The actual contents of the forward function are not important, but it is a sort of fake RNN cell, i.e., a function that is applied in a loop.
We instantiated the module and created x and h, which are just 3x4 matrices of random values. Then we invoked the cell with my_cell(x, h). This in turn calls our forward function.
Let's do something a little more interesting:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        # Apply the linear layer to the input before adding the hidden state.
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))
- Output
MyCell(
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.3941, 0.4160, -0.1086, 0.8432],
[ 0.5604, 0.4003, 0.5009, 0.6842],
[ 0.7084, 0.7147, 0.1818, 0.8296]], grad_fn=<TanhBackward>), tensor([[ 0.3941, 0.4160, -0.1086, 0.8432],
[ 0.5604, 0.4003, 0.5009, 0.6842],
[ 0.7084, 0.7147, 0.1818, 0.8296]], grad_fn=<TanhBackward>))
We have redefined our module MyCell, but this time we have added a self.linear attribute, and we invoke self.linear in the forward function.
What exactly is happening here? torch.nn.Linear is a Module from the PyTorch standard library. Just like MyCell, it can be invoked using the call syntax. We are building a hierarchy of Modules.
Calling print on a Module gives a visual representation of its submodule hierarchy. In our example, we can see our Linear submodule and its parameters.
By composing Modules in this way, we can write models that are concise, readable, and built from reusable components.
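The hierarchy that print displays can also be inspected programmatically. The following is a minimal sketch, assuming the second version of MyCell from above; it uses the standard nn.Module methods named_children() and named_parameters(), and the variable names child_names and param_shapes are only illustrative.

```python
import torch


class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h


my_cell = MyCell()

# Direct submodules, keyed by the attribute name used in __init__.
child_names = [name for name, _ in my_cell.named_children()]

# Parameters of the whole hierarchy; dotted names show which submodule owns them.
param_shapes = {name: tuple(p.shape) for name, p in my_cell.named_parameters()}

print(child_names)
print(param_shapes)
```

The dotted parameter names (such as linear.weight) mirror the nesting of the modules, which is the same structure TorchScript preserves later in this tutorial.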
You may have noticed grad_fn on the outputs. This is a detail of PyTorch's method of automatic differentiation, called autograd. In short, this system allows us to compute derivatives through potentially complex programs. The design allows for a great deal of flexibility in model authoring.
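To make the role of grad_fn concrete, here is a small sketch of autograd computing a derivative. The function y = x^2 + 3x is an arbitrary choice for illustration and does not come from the tutorial's model.

```python
import torch

# A scalar input that autograd should track.
x = torch.tensor(2.0, requires_grad=True)

# y = x^2 + 3x, so dy/dx = 2x + 3.
y = x ** 2 + 3 * x

# Walk the recorded operations backwards to fill in x.grad.
y.backward()

print(y.grad_fn)  # a backward node, because y was computed from x
print(x.grad)     # dy/dx at x = 2, i.e. 2 * 2 + 3 = 7
```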
Now let's examine that flexibility:
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        # Data-dependent control flow: the branch depends on the tensor's values.
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.dg = MyDecisionGate()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))
- Output
MyCell(
  (dg): MyDecisionGate()
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[0.0850, 0.2812, 0.5188, 0.8523],
[0.1233, 0.3948, 0.6615, 0.7466],
[0.7072, 0.6103, 0.6953, 0.7047]], grad_fn=<TanhBackward>), tensor([[0.0850, 0.2812, 0.5188, 0.8523],
[0.1233, 0.3948, 0.6615, 0.7466],
[0.7072, 0.6103, 0.6953, 0.7047]], grad_fn=<TanhBackward>))
We have once again redefined our MyCell class, but here we have also defined MyDecisionGate. This module uses control flow. Control flow consists of things like loops and if statements.
Many frameworks take the approach of computing symbolic derivatives given a full program representation. In PyTorch, however, we use a gradient tape. We record operations as they occur and replay them backwards when computing derivatives. In this way, the framework does not have to explicitly define derivatives for every construct in the language.
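The record-and-replay idea can be sketched in a few lines of plain Python. This is a toy illustration only, not how PyTorch implements autograd; the names Var, gate, and gradient are hypothetical and exist just for this example.

```python
class Var:
    """A scalar that records how it was computed (toy example, not PyTorch)."""

    def __init__(self, value, tape):
        self.value = value
        self.grad = 0.0
        self.tape = tape

    def __mul__(self, other):
        out = Var(self.value * other.value, self.tape)

        def backward():
            # d(a*b)/da = b and d(a*b)/db = a
            self.grad += other.value * out.grad
            other.grad += self.value * out.grad

        self.tape.append(backward)
        return out

    def __neg__(self):
        out = Var(-self.value, self.tape)

        def backward():
            # d(-a)/da = -1
            self.grad -= out.grad

        self.tape.append(backward)
        return out


def gate(v):
    # Ordinary Python control flow: only the branch that runs lands on the tape.
    return v if v.value > 0 else -v


def gradient(x0):
    """Return d/dx of gate(3 * x) at x = x0 by replaying the tape backwards."""
    tape = []
    x = Var(x0, tape)
    c = Var(3.0, tape)
    y = gate(x * c)
    y.grad = 1.0
    for backward in reversed(tape):
        backward()
    return x.grad


print(gradient(2.0))   # positive branch: derivative is 3
print(gradient(-2.0))  # negated branch: derivative is -3
```

Because the tape only ever sees the operations that actually ran, the if statement itself never needs a derivative rule.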
2. Basics of TorchScript
Now let's take our running example and see how we can apply TorchScript.
In short, TorchScript provides tools to capture the definition of your model, even in light of the flexible and dynamic nature of PyTorch. Let's begin by examining what we call tracing.
2.1 Tracing modules
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
# Run the module once on example inputs and record the operations.
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)
- Output
TracedModule[MyCell](
  original_name=MyCell
  (linear): TracedModule[Linear](original_name=Linear)
)
We have rewound a bit and taken the second version of our MyCell class. As before, we have instantiated it, but this time we have called torch.jit.trace, passing in the Module along with example inputs the network might see.
What exactly has this done? It has invoked the Module, recorded the operations that occurred when the Module was run, and created an instance of torch.jit.ScriptModule (of which TracedModule is an instance).
TorchScript records its definitions in an intermediate representation (or IR), commonly referred to in deep learning as a graph. We can examine the graph with the .graph property:
print(traced_cell.graph)
- Output
graph(%self : ClassType<MyCell>,
      %input : Float(3, 4),
      %h : Float(3, 4)):
  %1 : ClassType<Linear> = prim::GetAttr[name="linear"](%self)
  %weight : Tensor = prim::GetAttr[name="weight"](%1)
  %bias : Tensor = prim::GetAttr[name="bias"](%1)
  %6 : Float(4, 4) = aten::t(%weight), scope: MyCell/Linear[linear] # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %7 : int = prim::Constant[value=1](), scope: MyCell/Linear[linear] # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %8 : int = prim::Constant[value=1](), scope: MyCell/Linear[linear] # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %9 : Float(3, 4) = aten::addmm(%bias, %input, %6, %7, %8), scope: MyCell/Linear[linear] # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %10 : int = prim::Constant[value=1](), scope: MyCell # /var/lib/jenkins/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:188:0
  %11 : Float(3, 4) = aten::add(%9, %h, %10), scope: MyCell # /var/lib/jenkins/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:188:0
  %12 : Float(3, 4) = aten::tanh(%11), scope: MyCell # /var/lib/jenkins/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:188:0
  %13 : (Float(3, 4), Float(3, 4)) = prim::TupleConstruct(%12, %12)
  return (%13)
However, this is a very low-level representation, and most of the information contained in the graph is not useful for end users. Instead, we can use the .code property to get a Python-syntax interpretation of the code:
print(traced_cell.code)
- Output
import __torch__
import __torch__.torch.nn.modules.linear
def forward(self,
    input: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = self.linear
  weight = _0.weight
  bias = _0.bias
  _1 = torch.addmm(bias, input, torch.t(weight), beta=1, alpha=1)
  _2 = torch.tanh(torch.add(_1, h, alpha=1))
  return (_2, _2)
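Besides printing .graph and .code, the recorded graph can be walked programmatically. The sketch below assumes the traced MyCell from this section and lists the operator name of each node; the exact list depends on the PyTorch version (newer releases may show a call into the Linear submodule rather than its inlined operators), so treat the printed result as illustrative.

```python
import torch


class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h


traced_cell = torch.jit.trace(MyCell(), (torch.rand(3, 4), torch.rand(3, 4)))

# Each node is one recorded operation; kind() returns its operator name.
node_kinds = [node.kind() for node in traced_cell.graph.nodes()]
print(node_kinds)
```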
So why did we do all this? There are several reasons:
- TorchScript code can be invoked in its own interpreter, which is basically a restricted Python interpreter. This interpreter does not acquire the Global Interpreter Lock, so many requests can be processed on the same instance simultaneously.
- This format allows us to save the whole model to disk and load it into another environment, such as a server written in a language other than Python.
- TorchScript gives us a representation in which we can apply compiler optimizations to the code to provide more efficient execution.
- TorchScript allows us to interface with many backend/device runtimes that require a broader view of the program than individual operators.
We can see that invoking traced_cell produces the same results as the Python module:
print(my_cell(x, h))
print(traced_cell(x, h))
- Output
(tensor([[-0.3983, 0.5954, 0.2587, -0.3748],
[-0.5033, 0.4471, 0.8264, 0.2135],
[ 0.3430, 0.5561, 0.6794, -0.2273]], grad_fn=<TanhBackward>), tensor([[-0.3983, 0.5954, 0.2587, -0.3748],
[-0.5033, 0.4471, 0.8264, 0.2135],
[ 0.3430, 0.5561, 0.6794, -0.2273]], grad_fn=<TanhBackward>))
(tensor([[-0.3983, 0.5954, 0.2587, -0.3748],
[-0.5033, 0.4471, 0.8264, 0.2135],
[ 0.3430, 0.5561, 0.6794, -0.2273]],
grad_fn=<DifferentiableGraphBackward>), tensor([[-0.3983, 0.5954, 0.2587, -0.3748],
[-0.5033, 0.4471, 0.8264, 0.2135],
[ 0.3430, 0.5561, 0.6794, -0.2273]],
grad_fn=<DifferentiableGraphBackward>))
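Rather than comparing the printed tensors by eye, the agreement can be checked in code. This is a minimal sketch, assuming the same MyCell as above; torch.allclose is used because it tolerates tiny floating-point differences, and the variable name outputs_match is only illustrative.

```python
import torch


class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h


my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))

# Run both versions on the same inputs and compare the first output.
eager_out, _ = my_cell(x, h)
traced_out, _ = traced_cell(x, h)
outputs_match = torch.allclose(eager_out, traced_out)
print(outputs_match)
```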
3. Using scripting to convert modules
There is a reason we used version two of our module, and not the one with the control-flow-laden submodule. Let's examine that now:
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        # The decision gate is now passed in, so it can be swapped out later.
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell.code)
- Output
import __torch__.___torch_mangle_0
import __torch__
import __torch__.torch.nn.modules.linear.___torch_mangle_1
def forward(self,
    input: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = self.linear
  weight = _0.weight
  bias = _0.bias
  x = torch.addmm(bias, input, torch.t(weight), beta=1, alpha=1)
  _1 = torch.tanh(torch.add(x, h, alpha=1))
  return (_1, _1)
Looking at the .code output, we can see that the if-else branch is nowhere to be found! Why? Tracing does exactly what we said it would: it runs the code, records the operations that happen, and constructs a ScriptModule that does exactly that. Unfortunately, things like control flow are erased.
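The practical consequence is that a traced module replays whichever branch ran during tracing, regardless of later inputs. The sketch below traces MyDecisionGate on its own with an input that takes the else branch, then feeds it an input that should take the if branch; the names positive, negative, eager_result, and traced_result are only illustrative.

```python
import warnings

import torch


class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


gate = MyDecisionGate()
positive = torch.ones(3, 4)
negative = -torch.ones(3, 4)

with warnings.catch_warnings():
    # Tracing a data-dependent branch emits a TracerWarning; silence it here.
    warnings.simplefilter("ignore")
    # The example input has a negative sum, so only the negation is recorded.
    traced_gate = torch.jit.trace(gate, (negative,))

eager_result = gate(positive)          # takes the if branch: returns x
traced_result = traced_gate(positive)  # replays the recorded negation: returns -x

print(torch.equal(eager_result, positive))
print(torch.equal(traced_result, negative))
```

The warning that tracing raises in this situation is worth heeding in real code; it is suppressed here only to keep the output short.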
How can we faithfully represent this module in TorchScript? We provide a script compiler, which analyzes your Python source code directly and transforms it into TorchScript. Let's convert MyDecisionGate using the script compiler:
scripted_gate = torch.jit.script(MyDecisionGate())
my_cell = MyCell(scripted_gate)
traced_cell = torch.jit.script(my_cell)
print(traced_cell.code)
- Output
import __torch__.___torch_mangle_3
import __torch__.___torch_mangle_2
import __torch__.torch.nn.modules.linear.___torch_mangle_4
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = self.linear
  _1 = _0.weight
  _2 = _0.bias
  if torch.eq(torch.dim(x), 2):
    _3 = torch.__isnot__(_2, None)
  else:
    _3 = False
  if _3:
    bias = ops.prim.unchecked_unwrap_optional(_2)
    ret = torch.addmm(bias, x, torch.t(_1), beta=1, alpha=1)
  else:
    output = torch.matmul(x, torch.t(_1))
    if torch.__isnot__(_2, None):
      bias0 = ops.prim.unchecked_unwrap_optional(_2)
      output0 = torch.add_(output, bias0, alpha=1)
    else:
      output0 = output
    ret = output0
  _4 = torch.gt(torch.sum(ret, dtype=None), 0)
  if bool(_4):
    _5 = ret
  else:
    _5 = torch.neg(ret)
  new_h = torch.tanh(torch.add(_5, h, alpha=1))
  return (new_h, new_h)
We have now faithfully captured the behavior of our program in TorchScript. Let's try running the program:
# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell(x, h)
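To confirm that scripting keeps both branches, the scripted gate can be called with inputs on either side of the condition. This is a minimal sketch, assuming the code is run from a file so that torch.jit.script can read the class source; the names kept and flipped are only illustrative.

```python
import torch


class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


scripted_gate = torch.jit.script(MyDecisionGate())

positive = torch.ones(3, 4)
negative = -torch.ones(3, 4)

kept = scripted_gate(positive)     # sum > 0: returned unchanged
flipped = scripted_gate(negative)  # sum <= 0: negated

print(torch.equal(kept, positive))
print(torch.equal(flipped, positive))
```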
3.1 Mixing scripting and tracing
Some situations call for tracing rather than scripting (for example, a module makes many architectural decisions based on constant Python values that we would like not to appear in TorchScript). In this case, scripting can be composed with tracing: torch.jit.script will inline the code of a traced module, and tracing will inline the code of a scripted module.
- An example of the first case:
class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        # The cell is traced; the surrounding loop will be scripted.
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
- Output
import __torch__
import __torch__.___torch_mangle_5
import __torch__.___torch_mangle_2
import __torch__.torch.nn.modules.linear.___torch_mangle_6
def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  y = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    _0 = self.cell
    _1 = torch.select(xs, 0, i)
    _2 = _0.linear
    weight = _2.weight
    bias = _2.bias
    _3 = torch.addmm(bias, _1, torch.t(weight), beta=1, alpha=1)
    _4 = torch.gt(torch.sum(_3, dtype=None), 0)
    if bool(_4):
      _5 = _3
    else:
      _5 = torch.neg(_3)
    _6 = torch.tanh(torch.add(_5, h0, alpha=1))
    y0, h0 = _6, _6
  return (y0, h0)
- And an example of the second case:
class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        # The loop is scripted; the wrapper around it will be traced.
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)
- Output
import __torch__
import __torch__.___torch_mangle_9
import __torch__.___torch_mangle_7
import __torch__.___torch_mangle_2
import __torch__.torch.nn.modules.linear.___torch_mangle_8
def forward(self,
    argument_1: Tensor) -> Tensor:
  _0 = self.loop
  h = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  h0 = h
  for i in range(torch.size(argument_1, 0)):
    _1 = _0.cell
    _2 = torch.select(argument_1, 0, i)
    _3 = _1.linear
    weight = _3.weight
    bias = _3.bias
    _4 = torch.addmm(bias, _2, torch.t(weight), beta=1, alpha=1)
    _5 = torch.gt(torch.sum(_4, dtype=None), 0)
    if bool(_5):
      _6 = _4
    else:
      _6 = torch.neg(_4)
    h0 = torch.tanh(torch.add(_6, h0, alpha=1))
  return torch.relu(h0)
In this way, scripting and tracing can each be used where the situation calls for it, and they can be used together.
4. Saving and loading models
We provide APIs to save TorchScript modules to disk and load them from disk in an archive format. This format includes code, parameters, attributes, and debug information, meaning that the archive is a freestanding representation of the model that can be loaded in an entirely separate process. Let's save and load our wrapped RNN module:
traced.save('wrapped_rnn.zip')
loaded = torch.jit.load('wrapped_rnn.zip')
print(loaded)
print(loaded.code)
- Output
ScriptModule(
  original_name=WrapRNN
  (loop): ScriptModule(
    original_name=MyRNNLoop
    (cell): ScriptModule(
      original_name=MyCell
      (dg): ScriptModule(original_name=MyDecisionGate)
      (linear): ScriptModule(original_name=Linear)
    )
  )
)
import __torch__
import __torch__.___torch_mangle_9
import __torch__.___torch_mangle_7
import __torch__.___torch_mangle_2
import __torch__.torch.nn.modules.linear.___torch_mangle_8
def forward(self,
    argument_1: Tensor) -> Tensor:
  _0 = self.loop
  h = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  h0 = h
  for i in range(torch.size(argument_1, 0)):
    _1 = _0.cell
    _2 = torch.select(argument_1, 0, i)
    _3 = _1.linear
    weight = _3.weight
    bias = _3.bias
    _4 = torch.addmm(bias, _2, torch.t(weight), beta=1, alpha=1)
    _5 = torch.gt(torch.sum(_4, dtype=None), 0)
    if bool(_5):
      _6 = _4
    else:
      _6 = torch.neg(_4)
    h0 = torch.tanh(torch.add(_6, h0, alpha=1))
  return torch.relu(h0)
As you can see, serialization preserves the module hierarchy and the code we have been examining throughout. The model can also be loaded, for example, into C++ for Python-free execution.
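A quick way to gain confidence in a saved archive is to check that the reloaded module reproduces the original outputs. The sketch below assumes the simple traced MyCell rather than the wrapped RNN, and writes to a temporary directory; the file name my_cell.pt is hypothetical.

```python
import os
import tempfile

import torch


class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h


x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(MyCell(), (x, h))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "my_cell.pt")  # hypothetical file name
    traced_cell.save(path)
    loaded_cell = torch.jit.load(path)

    # The archive carries the weights, so outputs should agree after reloading.
    original_out, _ = traced_cell(x, h)
    loaded_out, _ = loaded_cell(x, h)
    round_trip_ok = torch.allclose(original_out, loaded_out)

print(round_trip_ok)
```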
Further reading
We have completed our tutorial! For a more involved demonstration, check out the NeurIPS demo on converting machine translation models with TorchScript: https://colab.research.google.com/drive/1HiICg6jRkBnr5hvK2-VnMi88Vi9pUzEJ