A preliminary study on TorchDynamo ②: Torch.FX research and practice


Author|strint

1

Overview

torch.fx is a Python to Python code conversion tool officially released by PyTorch. If you want to do Torch code transformation, torch.fx is the tool of choice.

torch.fx traces Torch code into a graph built from 6 basic node types, on which various transformations can easily be made. The transformed graph can be regenerated into Torch code (an nn.Module) and then executed like any ordinary nn.Module.

The torch.compile API newly released in torch 2.0 (its graph-capture front end is TorchDynamo) converts code into a torch.fx GraphModule by default, which further strengthens the importance of torch.fx. (Related article: A preliminary exploration of TorchDynamo ①: Dynamic modification of Python ByteCode)

Keywords : PyTorch, graph transformation, compilation

2

Minimum example

torch.fx has three basic functions. The first converts a torch nn.Module into an fx.GraphModule, which is called symbolic tracing; the second is the intermediate representation (IR) and graph rewriting; the third is Python code generation.

First define a representative nn.Module, which includes 6 basic operations that fx handles:

import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)


    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)


module = MyModule()

Then use the first basic function of fx, symbolic tracing, which converts the torch Python code into a symbolic representation of type fx.GraphModule:

from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

An fx.GraphModule computes exactly like an nn.Module, but it also embeds a computation graph that can be operated on by graph traversal. The intermediate representation and graph rewriting are both based on this computation graph. Printing the graph of the fx.GraphModule shows the graph IR of the module above:

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

The computation graph contained in an fx.GraphModule can be converted back into torch Python code (it can also be converted into custom code); this is the code generation function. For example, the following is the Python code corresponding to the module above:

# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

These three main functions will be discussed separately later.
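
As a quick sanity check of the three steps above, the sketch below traces the same MyModule and confirms that the traced GraphModule computes the same result as the original module and that its graph contains all six node types. The variable names same and ops are only illustrative.

```python
import torch
from torch.fx import symbolic_trace


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)


module = MyModule()
traced = symbolic_trace(module)

x = torch.rand(3, 4)
# The GraphModule reuses the parameters of the original module,
# so both should produce identical outputs for the same input.
same = torch.equal(module(x), traced(x))
# One op string per node, in execution order.
ops = [node.op for node in traced.graph.nodes]
print(same)
print(ops)
```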

3

Graph generation (Symbolic Trace)

fx builds its graph by symbolic tracing. It can be understood as passing fake inputs into an nn.Module or a function: when the code runs on the fake inputs, nothing is actually computed; instead, the path of executed operations is recorded (the symbolic trace), and the complete execution record is a graph.

The inputs of the symbolic_trace function are root and concrete_args. root is the code to be traced. concrete_args is optional; it maps argument names to concrete values, so that those arguments are not traced symbolically and the trace is specialized on them.

symbolic_trace is implemented by default with the trace method of Tracer. trace returns an fx.Graph, and symbolic_trace then constructs an fx.GraphModule from that fx.Graph and the original root and returns it.

def symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
) -> GraphModule:
    tracer = Tracer()
    graph = tracer.trace(root, concrete_args)
    name = (
        root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    )
    return GraphModule(tracer.root, graph, name)

So you do not actually have to use symbolic_trace; you can call Tracer directly yourself. If you need to customize the trace logic, you can change Tracer's behavior by subclassing it and overriding its methods.

Tracer functions

The main method of Tracer is trace, which converts the input nn.Module or function into a symbolic computation graph (IR). The essence of trace is to record the corresponding operations as values are passed through the code.

The trace mechanism relies on converting inputs into an abstract value, Proxy, which stands in for a tensor during execution. Tracing converts the tensors into Proxy objects and passes them through the code, and a Proxy can be passed to regular torch operations.

A Proxy can be passed to regular torch operations because of torch's __torch_function__ protocol (https://github.com/pytorch/rfcs/blob/master/RFC-0001-torch-function-for-methods.md). If a type supports __torch_function__, it can be passed to torch's regular functions, and the logic invoked during execution is the one defined in __torch_function__. So defining Proxy's __torch_function__ to record the operation into the graph completes the trace function (https://github.com/pytorch/pytorch/blob/de586001269fa04fa76ccc64964f676a25e120b2/torch/fx/proxy.py#L449).

This mechanism can be used to implement a minimal ProxyTensor that tracks and prints torch operations. Here, addition is symbolized:

import torch


class ProxyTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=None, kwargs=None):
        if func.__name__ == 'add':
            print("\n=> torch function call:")
            print(f"==> function name: {func.__name__}")
            print(f"==> function args: ({', '.join((str(type(arg)) for arg in args))})")
            # Custom behavior for tensor addition
            result = args[0].symbolic() + " + " + args[1].symbolic()
            return result
        else:
            # For other operations, use the default behavior
            return super().__torch_function__(func, types, args=args, kwargs=kwargs)


    def symbolic(self):
        return "tensor(" + str(self.shape) + ", " + str(self.dtype) +")"


# Create the custom tensors
x = ProxyTensor([4, 5, 6])
y = ProxyTensor([1, 2, 3])


result = x - y
print(f"minus result: {result}")


result = x + y
print(f"add result: {result}")

Subtraction runs as a regular torch tensor operation:

minus result: ProxyTensor([3., 3., 3.])

For addition, the custom logic runs on the operation and its inputs and outputs instead of the tensor operation:

=> torch function call:
==> function name: add
==> function args: (<class '__main__.ProxyTensor'>, <class '__main__.ProxyTensor'>)
add result: tensor(torch.Size([3]), torch.float32) + tensor(torch.Size([3]), torch.float32)

Tracer's trace does something similar. First, it converts each input of the nn.Module or function into a Node in the graph and wraps that Node in a Proxy, which serves as the new input. Later, when a torch operation is executed on a Proxy, the custom __torch_function__ is triggered. Proxy's custom behavior is to record the executed operation as a Node in the graph and then wrap that Node in a new Proxy, which is passed on as the result of the operation. In this way, a computation graph is constructed.
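
The recording loop described above can be sketched without torch at all. The following is a toy illustration in plain Python, not the torch.fx implementation: MiniGraph, MiniProxy and mini_trace are hypothetical names, and only + and * are recorded.

```python
class MiniGraph:
    """Hypothetical stand-in for fx.Graph: a flat list of node records."""

    def __init__(self):
        self.nodes = []  # each entry: (name, op, target, args)

    def add_node(self, op, target, args):
        # Placeholders keep their own name; other nodes get a unique suffix.
        name = target if op == "placeholder" else f"{target}_{len(self.nodes)}"
        self.nodes.append((name, op, target, args))
        return name


class MiniProxy:
    """Hypothetical stand-in for fx.Proxy: records operations instead of computing."""

    def __init__(self, graph, name):
        self.graph = graph
        self.name = name

    def _record(self, target, other):
        arg = other.name if isinstance(other, MiniProxy) else other
        name = self.graph.add_node("call_function", target, (self.name, arg))
        # Wrap the new node in a proxy so the recording continues downstream.
        return MiniProxy(self.graph, name)

    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)


def mini_trace(fn, arg_names):
    graph = MiniGraph()
    proxies = [MiniProxy(graph, graph.add_node("placeholder", n, ())) for n in arg_names]
    result = fn(*proxies)  # runs the Python code; only the operations are recorded
    graph.add_node("output", "output", (result.name,))
    return graph


def f(x, y):
    return (x + y) * 2


g = mini_trace(f, ["x", "y"])
for record in g.nodes:
    print(record)
```

Running f on proxies never adds or multiplies any numbers; the returned graph simply lists two placeholders, the add, the mul, and the output, in execution order.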

Nested calls to user-defined nn.Modules and functions are also worth considering. fx traces through this nesting, so what ends up recorded are torch's built-in operations. If you want a custom function to be recorded as a single operation, like a built-in one, you can register it with torch.fx.wrap.
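
For nested nn.Modules, the effect can be observed directly. The sketch below compares the default trace with a Tracer subclass that overrides is_leaf_module (a Tracer method not otherwise covered in this article) so that one custom module is kept as a single call_module node. Inner, Outer and LeafInnerTracer are illustrative names.

```python
import torch
from torch.fx import GraphModule, Tracer, symbolic_trace


class Inner(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x):
        return self.inner(x) * 2


class LeafInnerTracer(Tracer):
    def is_leaf_module(self, m, module_qualified_name):
        # Treat Inner as an opaque leaf; defer to the default rule otherwise.
        return isinstance(m, Inner) or super().is_leaf_module(m, module_qualified_name)


# Default behavior: the tracer steps into Inner and records its operations.
default_ops = [n.op for n in symbolic_trace(Outer()).graph.nodes]

# Custom tracer: Inner shows up as one call_module node.
root = Outer()
tracer = LeafInnerTracer()
graph = tracer.trace(root)
gm = GraphModule(tracer.root, graph, root.__class__.__name__)
leaf_ops = [n.op for n in gm.graph.nodes]

print(default_ops)
print(leaf_ops)
```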

Control flow and non-torch operations expose the limitations of the trace mechanism. They are executed by Python, and the trace does not know that they exist. So an if statement may record only the branch that was taken, a for loop is unrolled, and the result of a plain Python computation is passed into torch built-in operations as a constant.
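
A small experiment makes these limitations concrete. In the sketch below (Loop and data_dependent are illustrative names), a loop bound and a flag stored as plain Python attributes are baked into the graph at trace time, while a condition that depends on a traced value cannot be decided and makes symbolic_trace raise torch.fx.proxy.TraceError.

```python
import operator

import torch
from torch.fx import symbolic_trace
from torch.fx.proxy import TraceError


class Loop(torch.nn.Module):
    def __init__(self, steps, use_relu):
        super().__init__()
        self.steps = steps
        self.use_relu = use_relu

    def forward(self, x):
        for _ in range(self.steps):  # unrolled at trace time
            x = x + 1
        if self.use_relu:            # branch chosen at trace time
            x = torch.relu(x)
        return x


gm = symbolic_trace(Loop(steps=3, use_relu=False))
num_adds = sum(n.target is operator.add for n in gm.graph.nodes)
has_relu = any(n.target is torch.relu for n in gm.graph.nodes)
print(num_adds, has_relu)


def data_dependent(x):
    # The condition depends on a traced value, so Python cannot pick a branch.
    return x if x.sum() > 0 else -x


try:
    symbolic_trace(data_dependent)
    raised = False
except TraceError:
    raised = True
print(raised)
```

The traced graph contains three separate add nodes and no relu node, so changing steps or use_relu afterwards has no effect on the GraphModule; it would have to be traced again.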

fx.GraphModule

trace returns an fx.Graph, which is then used to construct and return an fx.GraphModule. fx.GraphModule inherits from nn.Module, so it largely behaves like an nn.Module. What is special is that its forward is generated from the fx.Graph. It also has a graph attribute, which gives access to the computation graph it contains, and a code attribute of type str, which holds the Python source generated from the graph; the forward method is obtained by compiling that source.

The fx.GraphModule generated by symbolic_trace is usually used just like an ordinary nn.Module, and that is all you need to know to use it. This design reflects how easy fx is to use.

Custom Tracer

torch.fx also leaves room to customize the trace process by subclassing Tracer and overriding its methods. They are introduced below; ordinary usage rarely involves them, so you can skip this part.

There are several ways to customize:

  • create_node: Tracer calls this whenever it inserts a node into the graph, and it returns the new node. Nodes are the basic unit in which the trace is recorded, and there are the following 6 types:

    • placeholder, usually the input of the entire traced Module or function;

    • call_function, function call;

    • call_method, method call on the object;

    • call_module, call of nn.Module;

    • get_attr, acquisition of attributes on nn.Module;

    • output, usually the output of the entire traced Module or function;

  • create_proxy: As shown above, the inputs and outputs of every operation call are Proxy objects, and the conversion to Proxy goes through create_proxy. A Proxy is the abstract return value of the operation that a Node records, so the corresponding Node is passed in when the Proxy is constructed.

  • create_args_for_root: creates the inputs of the traced Module or function;

  • create_arg: prepares a value to be used as an argument of a node in the graph;

  • call_module: called when an nn.Module is encountered, to trigger node creation and related behavior;

  • getattr: called when an attribute is read from an nn.Module, to trigger node creation and related behavior.

The customization of the above methods is closely coupled with the behavior of Tracer, so it needs to be handled carefully and customized in conjunction with Tracer's code implementation.
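
As one low-risk illustration of these hooks, the sketch below overrides create_node only to log every node as it is inserted, leaving the default behavior untouched. LoggingTracer and its log attribute are hypothetical names, and the parameter list is written out on the assumption that it matches the base class in your PyTorch version.

```python
import torch
from torch.fx import GraphModule, Tracer


class LoggingTracer(Tracer):
    def __init__(self):
        super().__init__()
        self.log = []  # (op, name) of every node, in creation order

    def create_node(self, kind, target, args, kwargs, name=None, type_expr=None):
        # Delegate to the default implementation, then record what was created.
        node = super().create_node(kind, target, args, kwargs, name=name, type_expr=type_expr)
        self.log.append((node.op, node.name))
        return node


def fn(x):
    return torch.relu(x).neg()


tracer = LoggingTracer()
graph = tracer.trace(fn)
gm = GraphModule(tracer.root, graph, fn.__name__)
print(tracer.log)
```

The log shows that placeholders, operations, and the output all pass through create_node, which is why it is a convenient single point for observing a trace.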

4

Intermediate representation and graph rewriting

Tracing the torch code above records the operations in Python execution order. Each operation produces a node of type fx.Node, and together these nodes form a graph of type fx.Graph.

fx.Node and fx.Graph

fx.Node and fx.Graph are the core data structures of fx's intermediate representation. Printing the graph in the example above shows the textual form of a complete graph, that is, the intermediate representation. Each line corresponds to one node (return corresponds to the node of type output):

```python
# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""
```

fx.Graph mainly supports creating, deleting, inspecting, and modifying the nodes of a graph. The nodes attribute returns the list of all Nodes in the graph; create_node adds a new Node (shortcuts such as call_module and call_method add a Node of a specific type directly); erase_node deletes a Node; inserting_after and inserting_before set the insertion point for new Nodes; eliminate_dead_code deletes unused Nodes; lint checks the graph structure; and on_generate_code inserts custom processing during code generation.
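
Several of these operations can be tried on a graph built by hand, without tracing anything. The following is a minimal sketch; the variable names are illustrative, and an empty torch.nn.Module serves as the root because the graph refers to no attributes or submodules.

```python
import torch
from torch.fx import Graph, GraphModule

graph = Graph()
x = graph.placeholder("x")
doubled = graph.call_function(torch.mul, (x, 2))
unused = graph.call_function(torch.neg, (x,))  # nothing consumes this node
graph.output(doubled)

graph.eliminate_dead_code()  # removes the unused neg node
graph.lint()                 # raises if the graph is malformed

ops = [n.op for n in graph.nodes]
gm = GraphModule(torch.nn.Module(), graph)
out = gm(torch.tensor([1.0, 2.0]))
print(ops)
print(gm.code)
print(out)
```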

fx.Node represents a node in the graph. The op attribute of fx.Node can obtain the type of Node. As mentioned in the section about creating nodes above, there are the following 6 types of Node:

  • placeholder, usually the input of the entire traced Module or function;

  • call_function, function call;

  • call_method, method call on the object;

  • call_module, call of nn.Module;

  • get_attr, acquisition of attributes on nn.Module;

  • output, usually the output of the entire traced Module or function;

fx.Node supports append to insert a Node after this node and prepend to insert a Node before it, replace_all_uses_with to replace every dependency on this Node in the graph with a new Node (along with some other replacement operations), and format_node to format a Node for printing.

Also useful is the target attribute of fx.Node, which records the operation the node corresponds to. For placeholder, output, and call_method, target is an ordinary string name; for call_function, target is the function itself; for call_module and get_attr, target is also a string, but that string is the key for looking up the corresponding module or attribute object. This design is not very convenient and takes some getting used to. Assuming gm is a GraphModule instance, the instances behind call_module and get_attr Nodes can be found from the key as follows:

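# When node is call_module: look up its Module instance
# (named_modules() returns a generator, so build a dict first)
modules = dict(gm.named_modules())
module = modules[node.target]


# When node is get_attr: look up its attribute instance
# (works for top-level names; a dotted target must be resolved part by part)
getattr(gm, node.target)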

The meta attribute of fx.Node holds object information and code call-stack information related to the node. The object information describes the value the node produces, and the call stack helps locate the source code that the node came from. Both are very helpful for debugging.
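
One way to see meta being populated is to run the ShapeProp pass that ships with PyTorch, which stores a tensor_meta entry on each node that produces a tensor. The sketch below also resolves call_module targets with get_submodule; the dictionaries shapes and module_types are illustrative.

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

model = torch.nn.Sequential(torch.nn.Linear(4, 5), torch.nn.ReLU())
gm = symbolic_trace(model)

# Execute the graph once on a sample input and annotate every node.
ShapeProp(gm).propagate(torch.rand(3, 4))

shapes = {}
module_types = {}
for node in gm.graph.nodes:
    if node.op != "call_module":
        continue
    # The string target is the key used to find the submodule instance.
    module_types[node.target] = type(gm.get_submodule(node.target)).__name__
    shapes[node.target] = tuple(node.meta["tensor_meta"].shape)

print(module_types)
print(shapes)
```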

Graph traversal pattern

Graph traversal is the most typical graph rewriting pattern. You can use fx.Graph.nodes to get the nodes in the graph and rewrite them. The following example replaces the add operation with the bitwise_and operation.

import torch
from torch.fx import symbolic_trace
import operator


# Define an ordinary module
class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y, torch.add(x, y), x.add(y)


# Trace it
traced = symbolic_trace(M())


# Targets to match
patterns = set([operator.add, torch.add, "add"])


# Traverse the Node list of the fx.Graph and modify it
for n in traced.graph.nodes:
    # If the target of the current Node matches add
    if any(n.target == pattern for pattern in patterns):
        # Insert a bitwise_and Node after the current Node
        with traced.graph.inserting_after(n):
            new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs)
            n.replace_all_uses_with(new_node)
        # Remove the obsolete Node
        traced.graph.erase_node(n)
# Recompile the GraphModule:
# regenerate code from the new graph to get the updated GraphModule
traced.recompile()

The comments above walk through the typical pattern of modifying a graph by traversing it. More examples can be found in the torch.fx documentation [1].

In addition, when you need generic processing of complex arguments, map_aggregate is a general-purpose helper for transforming them. Given an argument that is a tuple/list/dict built from nodes and a node-processing function fn, map_aggregate returns an argument with the same tuple/list/dict structure in which every node has been transformed by fn. It is similar to ArgsTree in OneFlow.
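
As a small sketch of map_aggregate, the code below maps every Node inside the nested arguments of a torch.cat call to its name while keeping the surrounding structure. The function fn and the helper to_name are illustrative; map_aggregate is imported from torch.fx.node.

```python
import torch
from torch.fx import Node, symbolic_trace
from torch.fx.node import map_aggregate


def fn(x, y):
    return torch.cat([x, y], dim=0)


traced = symbolic_trace(fn)
cat_node = next(n for n in traced.graph.nodes if n.target is torch.cat)


def to_name(a):
    # Replace each Node by its name; leave constants such as dim=0 untouched.
    return a.name if isinstance(a, Node) else a


arg_names = map_aggregate(cat_node.args, to_name)
kwarg_names = map_aggregate(cat_node.kwargs, to_name)
print(arg_names)
print(kwarg_names)
```

The list nested inside the args tuple is preserved, which is the point of the helper: fn only ever sees the leaves.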

Interpreter pattern

The Interpreter pattern processes the graph while executing it. The essence is to traverse the nodes in the graph and execute them one by one. As mentioned above, the instance behind a node, such as an nn.Module, can be obtained through Node.target and then executed. ShapeProp is an example: it executes each Node and records the shape and dtype of the Node's actual output tensor. You can see that its core is traversing and executing the Nodes:

for node in self.graph.nodes:
    if node.op == 'placeholder':
        result = next(args_iter)
    elif node.op == 'get_attr':
        result = fetch_attr(node.target)
    elif node.op == 'call_function':
        # load_arg fetches the actual tensors, which are then passed to target to execute the operator
        result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
    elif node.op == 'call_method':
        self_obj, *args = load_arg(node.args)
        kwargs = load_arg(node.kwargs)
        result = getattr(self_obj, node.target)(*args, **kwargs)
    elif node.op == 'call_module':
        result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
    
    if isinstance(result, torch.Tensor):
        # Record the shape and dtype of the execution result
        node.shape = result.shape
        node.dtype = result.dtype

For this pattern fx provides the helper class fx.Interpreter, which implements the graph traversal above and lets you override the behavior of each Node type, so that the logic of a Node can be customized.

fx.Interpreter takes an fx.GraphModule as input and executes it with the run method:

import torch
from torch import fx


def fn(x):
    return torch.sigmoid(x).neg()


gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)


class MyInterpreter(fx.Interpreter):
    pass


result = MyInterpreter(gm).run(input)

run traverses the Nodes in the graph and calls run_node on each one, and run_node dispatches to the execution method for the Node's type:

run()
    +-- run_node()
        +-- placeholder()
        +-- get_attr()
        +-- call_function()
        +-- call_method()
        +-- call_module()
        +-- output()

Both run_node and the per-type execution methods can be overridden. Implementing ShapeProp with Interpreter shows that you do not need to write the graph traversal yourself:

from typing import Any

from torch.fx import Node


class ShapePropInterpreter(fx.Interpreter):
    def run_node(self, n : Node) -> Any:
        result = super().run_node(n)
        if isinstance(result, torch.Tensor):
            # Record the shape and dtype of the execution result
            n.shape = result.shape
            n.dtype = result.dtype
        return result


result = ShapePropInterpreter(gm).run(input)

In addition, Interpreter can be used to achieve the effect of graph rewriting. The following replaces the original sigmoid with neg at execution time:

from typing import Any, Dict, Tuple

from torch.fx.node import Target


class NegSigmSwapInterpreter(fx.Interpreter):
    def call_function(self, target : Target,
                      args : Tuple, kwargs : Dict) -> Any:
        if target == torch.sigmoid:
            # The arguments passed in here are actual values
            return torch.neg(*args, **kwargs)
        # Fall back to the default behavior for every other function
        return super().call_function(target, args, kwargs)


# Run the Interpreter
result = NegSigmSwapInterpreter(gm).run(input)

Interpreter executes the graph and customizes its behavior at the same time. Its disadvantage is that it changes only what is actually executed; it cannot change the graph structure.

Transformer pattern

Because the Interpreter pattern executes on the fly, it leaves the structure of the graph unchanged. If you want to modify the structure of the graph, you can use the Transformer pattern.

fx.Transformer inherits from Interpreter, so it supports similar overridable interfaces. The difference is that it performs symbolic execution and creates a new graph.

Pass the original GraphModule to the Transformer and call the transform method. A new graph is created, the nodes are executed symbolically in the order of the original graph, and each returned result is used to create a node in the new graph. In the end you get a new GraphModule.

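from typing import Any, Dict, Tuple

from torch.fx.node import Argument, Target


class NegSigmSwapXformer(fx.Transformer):
    def call_function(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        if target == torch.sigmoid:
            # The arguments passed in here are Proxy objects
            return torch.neg(*args, **kwargs)
        # Fall back to the default behavior for every other function
        return super().call_function(target, args, kwargs)


# The result is a GraphModule in which sigmoid is replaced by neg
transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()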

fx.Transformer provides a convenient way to do node-to-node graph rewriting.
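
The difference between the two patterns can be checked side by side. In the sketch below, SwapInterpreter, SwapTransformer and function_targets are illustrative names. The same override is applied in both classes: the Interpreter changes only the computed result, while the Transformer produces a new GraphModule whose graph really contains neg instead of sigmoid.

```python
import torch
from torch import fx


def fn(x):
    return torch.sigmoid(x).neg()


class SwapInterpreter(fx.Interpreter):
    def call_function(self, target, args, kwargs):
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)  # args are real tensors here
        return super().call_function(target, args, kwargs)


class SwapTransformer(fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)  # args are Proxy objects here
        return super().call_function(target, args, kwargs)


def function_targets(graph_module):
    return [n.target for n in graph_module.graph.nodes if n.op == "call_function"]


gm = fx.symbolic_trace(fn)
x = torch.randn(3, 4)

interpreted = SwapInterpreter(gm).run(x)       # computes neg(x).neg()
targets_after_interpreter = function_targets(gm)

transformed = SwapTransformer(gm).transform()  # builds a new GraphModule
targets_after_transformer = function_targets(transformed)

print(targets_after_interpreter)
print(targets_after_transformer)
```

After the swap both paths compute neg(x).neg(), which is x again, so the results can be compared against the input directly.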

5

Python code generation

Python code generation is triggered when the GraphModule's recompile method is called. It is internal to fx, and you usually do not need to care about it when using fx. The main implementation techniques are described here.

What Python code generation does is convert the graph into code:

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

This conversion is one-to-one: each Node is converted into the corresponding Python code. The core is the emit_node function in fx.Graph's code generation. Taking a call_method Node as an example:

elif node.op == 'call_method':
    assert isinstance(node.target, str)
    body.append(
        f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
        f'({_format_args(node.args[1:], node.kwargs)})')

This appends Python source text to body based on the information in the node, so that the Node:

%clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})

is converted to the Python code:

clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None

This emit_node logic is called from the python_code function of fx.Graph. python_code returns a PythonCode object, which contains the Python source and the global objects it refers to. The following steps then install the code generated from the graph into the GraphModule.

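# Generate the Python code object; graph is an fx.Graph
python_code = graph.python_code(root_module='self')
# Python source text
code = python_code.src
# Global objects referenced by the Python code (copied so the original stays untouched)
globals_copy = python_code.globals.copy()
# Compile and load the source with Python's bytecode compiler
# (key is the name used as the file name of the generated source)
exec(compile(code, key, 'exec'), globals_copy)
# Fetch the compiled top-level function forward
forward_fn = globals_copy['forward']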

Finally, the forward method of the GraphModule is replaced with forward_fn, which yields a GraphModule whose execution logic matches the graph.
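
The emit-then-exec technique can be sketched in plain Python. This is a toy illustration rather than the fx code generator: the tuple-based mini-IR and generate_forward are hypothetical, and the node names are chosen so that they do not collide with the function names placed in the globals.

```python
import operator

# Hypothetical mini-IR: (name, op, target, args); strings in args name earlier nodes.
nodes = [
    ("x", "placeholder", "x", ()),
    ("total", "call_function", operator.add, ("x", 3)),
    ("bits", "call_method", "bit_length", ("total",)),
    ("output", "output", "output", ("bits",)),
]


def generate_forward(nodes):
    namespace = {}  # global objects the generated source refers to
    params, body = [], []
    for name, op, target, args in nodes:
        rendered = [str(a) for a in args]
        if op == "placeholder":
            params.append(name)
        elif op == "call_function":
            # Function objects cannot be written as text, so they travel via globals.
            namespace[target.__name__] = target
            body.append(f"{name} = {target.__name__}({', '.join(rendered)})")
        elif op == "call_method":
            receiver, *rest = rendered
            body.append(f"{name} = {receiver}.{target}({', '.join(rest)})")
        elif op == "output":
            body.append(f"return {rendered[0]}")
    src = f"def forward({', '.join(params)}):\n"
    src += "".join(f"    {line}\n" for line in body)
    # Compile the text and load the resulting function into the namespace.
    exec(compile(src, "<generated>", "exec"), namespace)
    return src, namespace["forward"]


src, forward = generate_forward(nodes)
print(src)
print(forward(5))
```

Each node becomes exactly one line of source, and the generated function is recovered from the namespace by name, which mirrors how forward_fn is obtained above.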

6

torch.fx and torch.compile

With torch.compile (TorchDynamo) in torch 2.0, a custom compiler backend can be supplied when a function or nn.Module is passed to torch.compile.

The custom_backend below holds the customized compilation logic. torch.compile traces the torch code into an fx.GraphModule object and passes it to custom_backend, where you can implement compilation logic on top of the fx.GraphModule, produce a custom function, and return it to torch.compile. When opt_model in the example below is executed for the first time, custom_backend is triggered, and the resulting function (compiled and optimized) is cached. Later executions use the compiled function directly, which is how execution is optimized.

from typing import List

import torch


def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("custom backend called with FX graph:")
    print(gm.graph)
    return gm.forward


# init_model is a placeholder for your own function that builds an nn.Module
opt_model = torch.compile(init_model(), backend=custom_backend)

7

Summary

torch.fx is a Python-to-Python code conversion tool officially released by PyTorch. It provides tools for tracing code into graphs, rewriting graphs, and regenerating Python code, and it is both flexible and easy to use. This article has introduced its core features and some practical tips.

OneFlow uses torch.fx and torch.compile to convert Torch code to OneFlow code to compile and accelerate Torch code more easily.

References

[1]. torch fx official documentation. https://pytorch.org/docs/stable/fx.html

[2]. torch.fx: Practical Program Capture and Transformation for Deep Learning in Python. https://arxiv.org/pdf/2112.08429.pdf

[3]. torch fx is used to convert torch to oneflow. https://github.com/Oneflow-Inc/diffusers/pull/237

[4]. Adapted to PyTorch FX, OneFlow makes quantization awareness training easier

Origin blog.csdn.net/OneFlow_Official/article/details/132750538