Introduction to TorchScript

This tutorial is an introduction to TorchScript, an intermediate representation of a PyTorch model (a subclass of nn.Module) that can then be run in a high-performance environment such as C++.

In this tutorial, we'll cover:

  1. The basics of authoring a model in PyTorch, including:
    • Modules
    • Defining forward functions
    • Composing modules into a hierarchy of modules
  2. Specific methods for converting PyTorch modules to TorchScript, our high-performance deployment runtime, including:
    • Tracing an existing module
    • Using scripting to directly compile a module
    • How to compose both approaches
    • Saving and loading TorchScript modules

We hope that after you complete this tutorial, you will proceed to the follow-up tutorial, which will walk you through an example of actually calling a TorchScript model from C++.

import torch  # This is all the import you need to use both PyTorch and TorchScript!
print(torch.__version__)

  • Output
1.3.0

1. Basics of PyTorch Model Authoring

Let's start out by defining a simple Module. A Module is the basic unit of composition in PyTorch. It contains:

  • A constructor, which prepares the module for invocation
  • A set of Parameters and sub-Modules. These are initialized by the constructor and may be used by the module during invocation.
  • A forward function. This is the code that is run when the module is invoked.

Let's examine a small example:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()

    def forward(self, x, h):
        new_h = torch.tanh(x + h)
        return new_h, new_h

my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))

  • Output
(tensor([[0.5139, 0.6451, 0.3697, 0.7738],
        [0.7936, 0.5864, 0.8063, 0.9324],
        [0.6479, 0.8408, 0.8062, 0.7263]]), tensor([[0.5139, 0.6451, 0.3697, 0.7738],
        [0.7936, 0.5864, 0.8063, 0.9324],
        [0.6479, 0.8408, 0.8062, 0.7263]]))

So we have:

  1. Created a class that subclasses torch.nn.Module.
  2. Defined a constructor. The constructor doesn't do much, but it does call the constructor of super.
  3. Defined a forward function, which takes two inputs and returns two outputs. The actual contents of the forward function are not really important, but it's a sort of fake RNN cell, that is, a function that is applied in a loop.

We instantiated the module and made x and h, which are just 3x4 matrices of random values. Then we invoked the cell with my_cell(x, h). This in turn calls our forward function.

Let's do some more interesting things:

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

  • Output
MyCell(
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.3941,  0.4160, -0.1086,  0.8432],
        [ 0.5604,  0.4003,  0.5009,  0.6842],
        [ 0.7084,  0.7147,  0.1818,  0.8296]], grad_fn=<TanhBackward>), tensor([[ 0.3941,  0.4160, -0.1086,  0.8432],
        [ 0.5604,  0.4003,  0.5009,  0.6842],
        [ 0.7084,  0.7147,  0.1818,  0.8296]], grad_fn=<TanhBackward>))

We've redefined our module MyCell, but this time we've added a self.linear attribute, and we invoke self.linear in the forward function.

What exactly is happening here? torch.nn.Linear is a Module from the PyTorch standard library. Just like MyCell, it can be invoked using the call syntax. We are building a hierarchy of Modules.
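
For instance, we can invoke a Linear module directly with the same call syntax (a minimal sketch, not part of the original tutorial):

lin = torch.nn.Linear(4, 4)   # a standalone Module from the standard library
out = lin(torch.rand(3, 4))   # the call syntax runs the module's forward function
print(out.shape)              # torch.Size([3, 4])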

Printing a module gives a visual representation of the module's subclass hierarchy. In our example, we can see our Linear subclass and its parameters.

By composing modules in this way, we can succinctly and readably author models with reusable components.

You may have noticed grad_fn in the output. This is a detail of PyTorch's method of automatic differentiation, called autograd. In short, this system allows us to compute derivatives through potentially complex programs. The design allows for a massive amount of flexibility in model authoring.
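
To see autograd at work, here is a minimal sketch (not part of the original tutorial) that differentiates through our cell:

x2 = torch.rand(3, 4, requires_grad=True)  # mark the input as needing gradients
h2 = torch.rand(3, 4)
new_h, _ = my_cell(x2, h2)
new_h.sum().backward()   # autograd replays the recorded operations backwards
print(x2.grad.shape)     # torch.Size([3, 4]): gradient of the sum w.r.t. x2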

Now, let's examine its flexibility:

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.dg = MyDecisionGate()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

  • Output
MyCell(
  (dg): MyDecisionGate()
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[0.0850, 0.2812, 0.5188, 0.8523],
        [0.1233, 0.3948, 0.6615, 0.7466],
        [0.7072, 0.6103, 0.6953, 0.7047]], grad_fn=<TanhBackward>), tensor([[0.0850, 0.2812, 0.5188, 0.8523],
        [0.1233, 0.3948, 0.6615, 0.7466],
        [0.7072, 0.6103, 0.6953, 0.7047]], grad_fn=<TanhBackward>))

We've once again redefined our MyCell class, but here we've also defined MyDecisionGate. This module utilizes control flow. Control flow consists of things like loops and if-statements.

Many frameworks take the approach of computing symbolic derivatives given a full program representation. However, in PyTorch we use a gradient tape. We record operations as they occur, and replay them backwards when computing derivatives. In this way, the framework does not have to explicitly define derivatives for all constructs in the language.
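
Because the tape records only the operations that actually ran, gradients flow through data-dependent control flow for free. A minimal sketch illustrating this (not part of the original tutorial):

gate = MyDecisionGate()
inp = torch.randn(3, 4, requires_grad=True)
out = gate(inp)          # which branch executes depends on inp.sum()
out.sum().backward()     # only the branch that ran is replayed backwards
print(inp.grad)          # all +1s if the x branch ran, all -1s otherwise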

2. Basics of TorchScript

Now let's take our running example and see how we can apply TorchScript.

In short, TorchScript provides tools to capture the definition of your model, even in light of the flexible and dynamic nature of PyTorch. Let's begin by examining what we call tracing.

2.1 Tracing Modules

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)

  • Output
TracedModule[MyCell](
  original_name=MyCell
  (linear): TracedModule[Linear](original_name=Linear)
)

We've rewound a bit and taken the second version of our MyCell class. As before, we've instantiated it, but this time we've called torch.jit.trace, passed in the Module, and passed in example inputs the network might see.

What exactly has this done? It has invoked the Module, recorded the operations that occurred when the Module was run, and created an instance of torch.jit.ScriptModule (of which TracedModule is an instance).
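
We can quickly confirm this (a sanity check, not part of the original tutorial):

print(isinstance(traced_cell, torch.jit.ScriptModule))  # True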

TorchScript records its definitions in an Intermediate Representation (or IR), commonly referred to in deep learning as a graph. We can examine the graph with the .graph property:

print(traced_cell.graph)

  • Output
graph(%self : ClassType<MyCell>,
      %input : Float(3, 4),
      %h : Float(3, 4)):
  %1 : ClassType<Linear> = prim::GetAttr[name="linear"](%self)
  %weight : Tensor = prim::GetAttr[name="weight"](%1)
  %bias : Tensor = prim::GetAttr[name="bias"](%1)
  %6 : Float(4, 4) = aten::t(%weight), scope: MyCell/Linear[linear] # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %7 : int = prim::Constant[value=1](), scope: MyCell/Linear[linear] # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %8 : int = prim::Constant[value=1](), scope: MyCell/Linear[linear] # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %9 : Float(3, 4) = aten::addmm(%bias, %input, %6, %7, %8), scope: MyCell/Linear[linear] # /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1370:0
  %10 : int = prim::Constant[value=1](), scope: MyCell # /var/lib/jenkins/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:188:0
  %11 : Float(3, 4) = aten::add(%9, %h, %10), scope: MyCell # /var/lib/jenkins/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:188:0
  %12 : Float(3, 4) = aten::tanh(%11), scope: MyCell # /var/lib/jenkins/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:188:0
  %13 : (Float(3, 4), Float(3, 4)) = prim::TupleConstruct(%12, %12)
  return (%13)

However, this is a very low-level representation, and most of the information contained in the graph is not useful for end users. Instead, we can use the .code property to give a Python-syntax interpretation of the code:

print(traced_cell.code)

  • Output
import __torch__
import __torch__.torch.nn.modules.linear
def forward(self,
    input: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = self.linear
  weight = _0.weight
  bias = _0.bias
  _1 = torch.addmm(bias, input, torch.t(weight), beta=1, alpha=1)
  _2 = torch.tanh(torch.add(_1, h, alpha=1))
  return (_2, _2)

So why did we do all this? There are several reasons:

  1. TorchScript code can be invoked in its own interpreter, which is basically a restricted Python interpreter. This interpreter does not acquire the Global Interpreter Lock, so many requests can be processed on the same instance simultaneously.
  2. This format allows us to save the whole model to disk and load it into another environment, such as a server written in a language other than Python.
  3. TorchScript gives us a representation in which we can do compiler optimizations on the code to provide more efficient execution.
  4. TorchScript allows us to interface with many backend/device runtimes that require a broader view of the program than individual operators.

We can see that invoking traced_cell produces the same results as the Python module:

print(my_cell(x, h))
print(traced_cell(x, h))

  • Output
(tensor([[-0.3983,  0.5954,  0.2587, -0.3748],
        [-0.5033,  0.4471,  0.8264,  0.2135],
        [ 0.3430,  0.5561,  0.6794, -0.2273]], grad_fn=<TanhBackward>), tensor([[-0.3983,  0.5954,  0.2587, -0.3748],
        [-0.5033,  0.4471,  0.8264,  0.2135],
        [ 0.3430,  0.5561,  0.6794, -0.2273]], grad_fn=<TanhBackward>))
(tensor([[-0.3983,  0.5954,  0.2587, -0.3748],
        [-0.5033,  0.4471,  0.8264,  0.2135],
        [ 0.3430,  0.5561,  0.6794, -0.2273]],
       grad_fn=<DifferentiableGraphBackward>), tensor([[-0.3983,  0.5954,  0.2587, -0.3748],
        [-0.5033,  0.4471,  0.8264,  0.2135],
        [ 0.3430,  0.5561,  0.6794, -0.2273]],
       grad_fn=<DifferentiableGraphBackward>))

3. Using Scripting to Convert Modules

There's a reason we used version two of our module, and not the one with the control-flow-laden submodule. Let's examine that now:

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell.code)

  • Output
import __torch__.___torch_mangle_0
import __torch__
import __torch__.torch.nn.modules.linear.___torch_mangle_1
def forward(self,
    input: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = self.linear
  weight = _0.weight
  bias = _0.bias
  x = torch.addmm(bias, input, torch.t(weight), beta=1, alpha=1)
  _1 = torch.tanh(torch.add(x, h, alpha=1))
  return (_1, _1)

Looking at the .code output, we can see that the if-else branch is nowhere to be found! Why? Tracing does exactly what we said it would: runs the code, records the operations that happen, and constructs a ScriptModule that does exactly that. Unfortunately, things like control flow are erased.
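
To make the erasure concrete, here is a minimal sketch (not part of the original tutorial) that traces the gate on one input and then calls it with an input that should take the other branch (tracing also emits a TracerWarning about the data-dependent condition):

gate = MyDecisionGate()
pos = torch.ones(3, 4)                       # sum > 0, so the trace records the x branch
traced_gate = torch.jit.trace(gate, (pos,))

neg = -torch.ones(3, 4)                      # sum < 0, so -x should be returned
print(gate(neg))          # the eager module returns -neg (all ones)
print(traced_gate(neg))   # the trace replays the recorded branch: neg, unchanged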

How can we faithfully represent this module in TorchScript? We provide a script compiler, which performs direct analysis of your Python source code to transform it into TorchScript. Let's convert MyDecisionGate using the script compiler:

scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
traced_cell = torch.jit.script(my_cell)
print(traced_cell.code)

  • Output
import __torch__.___torch_mangle_3
import __torch__.___torch_mangle_2
import __torch__.torch.nn.modules.linear.___torch_mangle_4
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = self.linear
  _1 = _0.weight
  _2 = _0.bias
  if torch.eq(torch.dim(x), 2):
    _3 = torch.__isnot__(_2, None)
  else:
    _3 = False
  if _3:
    bias = ops.prim.unchecked_unwrap_optional(_2)
    ret = torch.addmm(bias, x, torch.t(_1), beta=1, alpha=1)
  else:
    output = torch.matmul(x, torch.t(_1))
    if torch.__isnot__(_2, None):
      bias0 = ops.prim.unchecked_unwrap_optional(_2)
      output0 = torch.add_(output, bias0, alpha=1)
    else:
      output0 = output
    ret = output0
  _4 = torch.gt(torch.sum(ret, dtype=None), 0)
  if bool(_4):
    _5 = ret
  else:
    _5 = torch.neg(ret)
  new_h = torch.tanh(torch.add(_5, h, alpha=1))
  return (new_h, new_h)

Now we've faithfully captured the behavior of our program in TorchScript. Let's try running the program:

# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell(x, h)
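
As with tracing, we can sanity-check that the scripted module agrees with the eager one (a quick check, not part of the original tutorial):

print(torch.allclose(my_cell(x, h)[0], traced_cell(x, h)[0]))  # True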

3.1 Mixing Scripting and Tracing

Some situations call for using tracing rather than scripting (e.g., when a module has many architectural decisions that are made based on constant Python values that we would like not to appear in TorchScript). In this case, scripting can be composed with tracing: torch.jit.script will inline the code of a traced module, and tracing will inline the code of a scripted module.

  • Example of the first case:
class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)

  • Output
import __torch__
import __torch__.___torch_mangle_5
import __torch__.___torch_mangle_2
import __torch__.torch.nn.modules.linear.___torch_mangle_6
def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  y = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    _0 = self.cell
    _1 = torch.select(xs, 0, i)
    _2 = _0.linear
    weight = _2.weight
    bias = _2.bias
    _3 = torch.addmm(bias, _1, torch.t(weight), beta=1, alpha=1)
    _4 = torch.gt(torch.sum(_3, dtype=None), 0)
    if bool(_4):
      _5 = _3
    else:
      _5 = torch.neg(_3)
    _6 = torch.tanh(torch.add(_5, h0, alpha=1))
    y0, h0 = _6, _6
  return (y0, h0)

  • Example of the second case:
class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)

  • Output
import __torch__
import __torch__.___torch_mangle_9
import __torch__.___torch_mangle_7
import __torch__.___torch_mangle_2
import __torch__.torch.nn.modules.linear.___torch_mangle_8
def forward(self,
    argument_1: Tensor) -> Tensor:
  _0 = self.loop
  h = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  h0 = h
  for i in range(torch.size(argument_1, 0)):
    _1 = _0.cell
    _2 = torch.select(argument_1, 0, i)
    _3 = _1.linear
    weight = _3.weight
    bias = _3.bias
    _4 = torch.addmm(bias, _2, torch.t(weight), beta=1, alpha=1)
    _5 = torch.gt(torch.sum(_4, dtype=None), 0)
    if bool(_5):
      _6 = _4
    else:
      _6 = torch.neg(_4)
    h0 = torch.tanh(torch.add(_6, h0, alpha=1))
  return torch.relu(h0)

This way, scripting and tracing can be used together when the situation calls for each of them.

4. Saving and Loading Models

We provide APIs to save and load TorchScript modules to/from disk in an archive format. This format includes code, parameters, attributes, and debug information, which means the archive is a freestanding representation of the model that can be loaded in an entirely separate process. Let's save and load our wrapped RNN module:

traced.save('wrapped_rnn.zip')

loaded = torch.jit.load('wrapped_rnn.zip')

print(loaded)
print(loaded.code)

  • Output
ScriptModule(
  original_name=WrapRNN
  (loop): ScriptModule(
    original_name=MyRNNLoop
    (cell): ScriptModule(
      original_name=MyCell
      (dg): ScriptModule(original_name=MyDecisionGate)
      (linear): ScriptModule(original_name=Linear)
    )
  )
)
import __torch__
import __torch__.___torch_mangle_9
import __torch__.___torch_mangle_7
import __torch__.___torch_mangle_2
import __torch__.torch.nn.modules.linear.___torch_mangle_8
def forward(self,
    argument_1: Tensor) -> Tensor:
  _0 = self.loop
  h = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  h0 = h
  for i in range(torch.size(argument_1, 0)):
    _1 = _0.cell
    _2 = torch.select(argument_1, 0, i)
    _3 = _1.linear
    weight = _3.weight
    bias = _3.bias
    _4 = torch.addmm(bias, _2, torch.t(weight), beta=1, alpha=1)
    _5 = torch.gt(torch.sum(_4, dtype=None), 0)
    if bool(_5):
      _6 = _4
    else:
      _6 = torch.neg(_4)
    h0 = torch.tanh(torch.add(_6, h0, alpha=1))
  return torch.relu(h0)

As you can see, the serialized module preserves the hierarchy and the code we've been examining throughout. The model can also be loaded, for example, into C++ for execution that does not depend on Python.
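
As a quick check (not part of the original tutorial), we can confirm that the loaded module computes the same thing as the original:

xs = torch.rand(10, 3, 4)
print(torch.allclose(traced(xs), loaded(xs)))  # True: same code, same parameters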

Further reading

We've completed our tutorial! For a more involved demonstration, check out the NeurIPS demo on converting machine translation models using TorchScript: https://colab.research.google.com/drive/1HiICg6jRkBnr5hvK2-VnMi88Vi9pUzEJ

The total running time of the script: (0 minutes 0.247 seconds)
