Layout Transform Optimization in Deep Learning Compilers

More material on deep learning compilers can be found at https://github.com/BBuf/tvm_mlir_learn. The author also maintains a CUDA learning repository, https://github.com/BBuf/how-to-optim-algorithm-in-cuda, and a repository on how to learn deep learning frameworks (PyTorch and OneFlow), https://github.com/BBuf/how-to-learn-deep-learning-framework. Feel free to star them if you find them useful. A series of notes on LLM training and inference is collected under https://github.com/BBuf/how-to-optim-algorithm-in-cuda/tree/master/large-language-model-note.

This article uses "interface" and "Interface" interchangeably; both refer to the MLIR Interface.

0x0. Background

Continuing the series on optimization work in deep learning compilers, this article introduces how Layout Transform is implemented on top of MLIR in OneFlow. For 2D convolutional neural networks there are generally two data formats, NCHW and NHWC. Convolutions computed in NHWC format may achieve better performance. However, deep learning networks are generally trained in NCHW, so the Layout Transform from NCHW to NHWC is usually performed only at inference time. This raises two problems. First, for an operator such as Conv2D, the weight saved after NCHW training has the format [out_channels, in_channels, *kernel_size], and it must be converted for NHWC inference. Second, operators without weights should also support NHWC computation wherever possible, to reduce the extra overhead of the Transpose operations inserted before and after the convolution operators.

For example, consider a small network x->conv->relu->conv->relu->out. To execute it in NHWC format, besides converting the weights of the two convolutions we also need to insert a transpose before and after each conv to change the data format it sees: x->transpose(0, 2, 3, 1)->conv->transpose(0, 3, 1, 2)->relu->transpose(0, 2, 3, 1)->conv->transpose(0, 3, 1, 2)->relu->out. Careful readers will notice that many of these Transposes are redundant: because ReLU also supports NHWC, the network can be simplified to x->transpose(0, 2, 3, 1)->conv->relu->conv->relu->transpose(0, 3, 1, 2)->out. This cuts the Transpose Op overhead in half.
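To make the two permutations concrete, here is a minimal C++ sketch that applies them to tensor shapes. `permuteShape`, `kToChannelLast`, and `kToChannelFirst` are hypothetical names written for this illustration, not OneFlow or MLIR functions. The sketch assumes the usual transpose convention in which output dimension i takes input dimension perm[i].

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using Shape4 = std::array<int64_t, 4>;
using Perm4 = std::array<int, 4>;

// NCHW -> NHWC and NHWC -> NCHW, the two permutations used in the text.
constexpr Perm4 kToChannelLast = {0, 2, 3, 1};
constexpr Perm4 kToChannelFirst = {0, 3, 1, 2};

// Hypothetical helper: out[i] = shape[perm[i]] (the usual transpose convention).
Shape4 permuteShape(const Shape4& shape, const Perm4& perm) {
  Shape4 out{};
  for (int i = 0; i < 4; ++i) {
    assert(perm[i] >= 0 && perm[i] < 4);  // perm must index a valid dimension
    out[i] = shape[perm[i]];
  }
  return out;
}
```

With this convention, an activation of shape (16, 3, 224, 224) becomes (16, 224, 224, 3) under `kToChannelLast`, and a convolution weight of shape [64, 3, 7, 7] becomes [64, 7, 7, 3]. Applying `kToChannelFirst` afterwards restores the original shape.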

Transposes must be minimized because the transpose operator itself has run-time and scheduling overhead. If we do not minimize the number of transposes, the speedup gained from switching to NHWC may be cancelled out by the transpose overhead. We implemented the Layout Transform optimization described above in OneFlow; the test results are given below.

This optimization was tested on a V100. The test code is at https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/test/OneFlow/auto_nhwc/test_resnet101_benchmark.py. The test settings and performance results are as follows:

  • The AMP option of nn.Graph is turned on.
  • The network is ResNet101, and only forward inference is run.

batch_size    nchw    auto nhwc
16            14s     13s
32            24s     22s
64            44s     38s

With a batch size of 64, a 13.6% speedup is obtained ((44 - 38) / 44). As the batch size decreases, the speedup shrinks, but some speedup is always retained. Note that the weights are transposed in advance, so they incur no extra run-time overhead. This is done with constant folding, which will be discussed in the next article.

0x1. Implementation walkthrough

Three problems need to be solved in the implementation: how to determine which operators support NHWC computation, how to insert Transpose operators, and how to eliminate redundant Transpose pairs.

0x1.1 Determine which operators support NHWC operations based on Interface

In OneFlow, to make an Op support NHWC computation we only need to declare NCHWCompatibleInterface in the Op's definition. Take convolution as an example:

def OneFlow_Conv2DOp : OneFlow_ConvolutionBaseOp<"conv2d", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, DeclareOpInterfaceMethods<NCHWCompatibleInterface>]> {
}

DeclareOpInterfaceMethods<NCHWCompatibleInterface> here means that the Operator implements NCHWCompatibleInterface, which defines the methods that an Operator compatible with the NCHW format must implement.

If we want any other Op to support NHWC computation, we only need to declare this interface on it and implement the interface's member functions. Next, let's look at the definition of NCHWCompatibleInterface.

def NCHWCompatibleInterface : OpInterface<"NCHWCompatible"> {
  let description = [{
    Interface of NCHW compatibility
  }];

  let methods = [
    InterfaceMethod<"",
        "bool", "IsNCHW", (ins)
    >,
    InterfaceMethod<"Create NHWC op and return the new op's results to be transposed",
        "llvm::SmallVector<mlir::Value, 4>", "NchwToNhwc", (ins "llvm::SmallVector<mlir::Value, 4>": $transposed_inputs, "PatternRewriter&": $rewriter)
    >,
    InterfaceMethod<"",
        "llvm::DenseSet<mlir::Value>", "OperandsToTranspose", (ins)
    >,
    InterfaceMethod<"",
        "llvm::DenseSet<mlir::Value>", "ResultsToTranspose", (ins)
    >,
  ];
  let cppNamespace = "::mlir::oneflow";
}

This interface is defined by deriving from OpInterface, the base class for Operation Interfaces in MLIR. NCHWCompatibleInterface represents an Operation Interface for Operators that are compatible with the NCHW format. It defines the following methods:

  • IsNCHW: Returns a bool indicating whether the current Operator is processing its input in NCHW format.
  • NchwToNhwc: Accepts the transposed inputs and a rewriter, creates the NHWC version of the Operator, and returns the new Operator's results.
  • OperandsToTranspose: Returns the set of input values that require a Transpose.
  • ResultsToTranspose: Returns the set of output values that require a Transpose.
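As a rough mental model only, the four methods can be pictured as an abstract base class in plain C++. The sketch below is not the code MLIR generates (MLIR interfaces are implemented with a concept/model mechanism rather than virtual inheritance), and `ToyValue`, `ToyNchwCompatible`, and `ToyRelu` are hypothetical names. It also assumes that a layout-agnostic op such as ReLU reports IsNCHW() == false.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::Value: a name plus a layout tag.
struct ToyValue {
  std::string name;
  std::string layout;  // "NCHW" or "NHWC"
};

// Plain C++ analogy of the four interface methods (not MLIR-generated code).
class ToyNchwCompatible {
 public:
  virtual ~ToyNchwCompatible() = default;
  virtual bool IsNCHW() const = 0;
  virtual std::set<std::string> OperandsToTranspose() const = 0;
  virtual std::set<std::string> ResultsToTranspose() const = 0;
  // Receives the already-transposed inputs and returns the NHWC results.
  virtual std::vector<ToyValue> NchwToNhwc(
      const std::vector<ToyValue>& transposed_inputs) const = 0;
};

// A layout-agnostic elementwise op with one input "x" and one output "y".
class ToyRelu : public ToyNchwCompatible {
 public:
  // Assumption: an op without a data_format attribute reports false here.
  bool IsNCHW() const override { return false; }
  std::set<std::string> OperandsToTranspose() const override { return {"x"}; }
  std::set<std::string> ResultsToTranspose() const override { return {"y"}; }
  std::vector<ToyValue> NchwToNhwc(
      const std::vector<ToyValue>& transposed_inputs) const override {
    assert(transposed_inputs.size() == 1 && transposed_inputs[0].layout == "NHWC");
    return {ToyValue{"y", "NHWC"}};  // elementwise: the output keeps the input layout
  }
};
```

The point of the analogy is that a rewrite pattern can be written once against the interface and then applies to every op that implements it.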

Next, let's take a look at the implementation of the NCHWCompatibleInterface interface corresponding to Conv2D Op:

bool Conv2DOp::IsNCHW() { return this->getDataFormat().str() == "channels_first"; }

llvm::DenseSet<Value> Conv2DOp::OperandsToTranspose() {
  if (this->get_addToOutput()) {
    return {this->getIn(), this->getWeight(), this->get_addToOutput()};
  } else {
    return {this->getIn(), this->getWeight()};
  }
}

llvm::DenseSet<Value> Conv2DOp::ResultsToTranspose() { return {this->getOut()}; }

llvm::SmallVector<Value, 4> Conv2DOp::NchwToNhwc(llvm::SmallVector<Value, 4> value,
                                                 PatternRewriter& rewriter) {
  auto conv_op = *this;
  SmallVector<Value, 4> operands;
  // value[0] and value[1] are the transposed input and weight
  operands.push_back(value[0]);
  operands.push_back(value[1]);
  // bias is passed through without a transpose
  if (conv_op.getBias()) operands.push_back(conv_op.getBias());
  if (this->get_addToOutput()) { operands.push_back(value[2]); }
  NamedAttrList attributes = conv_op->getAttrs();
  attributes.set(conv_op.getDataFormatAttrName(), rewriter.getStringAttr("channels_last"));
  auto res = rewriter
                 .create<oneflow::Conv2DOp>(conv_op.getLoc(), getNHWCResultTypes(conv_op), operands,
                                            attributes)
                 ->getResults();
  llvm::SmallVector<Value, 4> results;
  results.push_back(res[0]);
  return results;
}

The IsNCHW method returns a bool indicating whether this Conv2DOp uses the NCHW format, which it determines from the Operation's data_format attribute. The OperandsToTranspose method returns the set of input values that need a Transpose. For Conv2DOp the inputs are input, weight, bias (optional), and addto_output (optional). bias does not need a Transpose, and addto_output is a special operand that OneFlow uses for operator fusion; readers can ignore it. The ResultsToTranspose method returns the set of output values that need a Transpose. Conv2DOp has only one output, so the output feature map is returned. The NchwToNhwc method accepts the already-transposed (NHWC) operands and a rewriter, and returns the results in NHWC format. It performs the conversion by creating a new Conv2DOp with the data_format attribute set to channels_last.

0x1.2 Insert Transpose operator

The next step is to greedily insert Transpose operators around the operators in the network. The idea is to insert a Transpose before and after as many operators as possible, so that eliminating Transpose pairs afterwards gives the best result. The logic for inserting Transposes is described by the following Pattern code:

struct AutoNhwcPattern : public OpInterfaceRewritePattern<NCHWCompatible> {
  explicit AutoNhwcPattern(mlir::MLIRContext* context)
      : OpInterfaceRewritePattern<NCHWCompatible>(context, /*benefit=*/1) {}

 public:
  LogicalResult matchAndRewrite(NCHWCompatible op, PatternRewriter& rewriter) const override {
    if (op->hasTrait<OpTrait::IsOpConfCompatible>()) {
      for (mlir::Value operand : op.OperandsToTranspose()) {
        if (operand.getType().cast<mlir::RankedTensorType>().getShape().size() != 4) {
          return failure();
        }
      }
      const auto device_name = OpTrait::IsOpConfCompatible<void>::getDeviceTag(op)
                                   .cast<mlir::StringAttr>()
                                   .getValue()
                                   .str();
      if (device_name == "cpu") { return failure(); }
    }
    llvm::SmallVector<int32_t> perm = getChannelLastTransposePerm();
    llvm::SmallVector<int32_t> result_perm = getChannelFirstTransposePerm();

    NamedAttrList transpose_attributes;
    if (InitTransposeAttributes(op, transpose_attributes, rewriter).succeeded()) {
      transpose_attributes.append(llvm::StringRef("perm"), getSI32ArrayAttr(rewriter, perm));
    } else {
      return failure();
    }
    // when op has no sense of data_format and pre op is transpose, we greedily insert transpose
    // into this op, seeking more opportunities to eliminate transpose pattern.
    const bool greedily_transpose_flag = !op.IsNCHW() && IsInsertTransposeOpBefore(op, rewriter);

    if (op.IsNCHW() || greedily_transpose_flag) {
      // create transpose op for input operand
      SmallVector<Value, 4> tranposed_operands;
      llvm::DenseSet<Value> operand_transpose = op.OperandsToTranspose();
      int num_transposed_operand = 0;
      for (Value operand : op->getOperands()) {
        if (operand_transpose.find(operand) != operand_transpose.end()) {
          SmallVector<Value, 4> input_res = getInputOperandTransposeOp(
              op, operand, transpose_attributes, num_transposed_operand, rewriter);
          tranposed_operands.push_back(input_res[0]);
          num_transposed_operand += 1;
        }
      }
      // create NHWC op
      SmallVector<Value, 4> created_results = op.NchwToNhwc(tranposed_operands, rewriter);
      // create transpose op for results
      int num_transposed_result = 0;
      transpose_attributes.set(llvm::StringRef("perm"), getSI32ArrayAttr(rewriter, result_perm));
      llvm::DenseSet<Value> transpose_result = op.ResultsToTranspose();

      for (Value result : op->getOpResults()) {
        if (transpose_result.find(result) != transpose_result.end()) {
          if (auto result_transpose_op =
                  getResultTransposeOp(op, created_results[num_transposed_result],
                                       transpose_attributes, num_transposed_result, rewriter)) {
            result.replaceAllUsesWith(result_transpose_op);
            num_transposed_result += 1;
          } else {
            return failure();
          }
        }
      }
    }
    return success();
  }
};

First, AutoNhwcPattern inherits from OpInterfaceRewritePattern, a base class for patterns that rewrite Operations implementing a given Interface. AutoNhwcPattern rewrites Operations that implement the NCHWCompatible Interface, converting them from NCHW to NHWC. It overrides the matchAndRewrite method, which is called for every Operation that implements NCHWCompatible. matchAndRewrite first checks whether the Operation meets the conversion conditions: every operand to be transposed must be 4-dimensional, and the Operation must not run on a CPU device. If a condition is not met, it returns failure. Otherwise, it obtains the permutations for NCHW->NHWC and NHWC->NCHW and initializes the attributes of the Transpose Operations. Then, if the current Op is in NCHW format, or it is insensitive to data_format and one of its inputs is produced by an NHWC->NCHW Transpose Op, Transposes are inserted before and after it to create more opportunities for elimination.
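The decision made inside matchAndRewrite can be condensed into a single predicate. The following is a simplified sketch in plain C++ with no MLIR dependency; `ToyOpInfo` and `shouldRewriteToNhwc` are hypothetical names, and the sketch ignores details of the real pattern such as the IsOpConfCompatible trait check and the initialization of the Transpose attributes.

```cpp
#include <cassert>
#include <string>

// Hypothetical summary of what the pattern inspects on one op.
struct ToyOpInfo {
  int operand_rank;         // rank of the operands that would be transposed
  std::string device;       // "cpu", "cuda", ...
  bool is_nchw;             // what IsNCHW() would return
  bool fed_by_nchw_output;  // what IsInsertTransposeOpBefore() would return
};

// Returns true when the sketch would wrap the op in Transposes.
bool shouldRewriteToNhwc(const ToyOpInfo& op) {
  assert(op.operand_rank >= 0);
  if (op.operand_rank != 4) return false;  // only 4-D tensors are handled
  if (op.device == "cpu") return false;    // the pattern does not apply on CPU
  // Layout-agnostic op whose input comes from an NHWC->NCHW Transpose:
  // rewrite it greedily so that the Transposes around it can cancel later.
  const bool greedy = !op.is_nchw && op.fed_by_nchw_output;
  return op.is_nchw || greedy;
}
```

For example, a CUDA conv in NCHW format is rewritten, a CUDA ReLU that directly follows a rewritten conv is rewritten greedily, and a ReLU with no Transpose in front of it is left alone.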

Several helper functions are also involved here. Let's go through them:

llvm::SmallVector<int32_t> getChannelLastTransposePerm() { return {0, 2, 3, 1}; }

llvm::SmallVector<int32_t> getChannelFirstTransposePerm() { return {0, 3, 1, 2}; }

llvm::SmallVector<mlir::Value, 4> getInputOperandTransposeOp(NCHWCompatible op, Value val,
                                                             NamedAttrList transpose_attributes,
                                                             int num_transposed_operand,
                                                             PatternRewriter& rewriter) {
  std::string transpose_name = OpTrait::IsOpConfCompatible<void>::getOpName(op).str()
                               + "_transpose_input_" + std::to_string(num_transposed_operand);
  transpose_attributes.set(llvm::StringRef(OpTrait::IsOpConfCompatible<void>::getOpNameAttr()),
                           rewriter.getStringAttr(transpose_name));
  SmallVector<Value, 4> input_operands;
  input_operands.push_back(val);
  auto res = rewriter
                 .create<oneflow::TransposeOp>(op.getLoc(), getNHWCType(val.getType()),
                                               input_operands, transpose_attributes)
                 ->getResults();
  return res;
}

TransposeOp getResultTransposeOp(NCHWCompatible op, Value val, NamedAttrList transpose_attributes,
                                 int num_transposed_result, PatternRewriter& rewriter) {
  std::string transpose_name = OpTrait::IsOpConfCompatible<void>::getOpName(op).str()
                               + "_transpose_output_" + std::to_string(num_transposed_result);
  transpose_attributes.set(llvm::StringRef(OpTrait::IsOpConfCompatible<void>::getOpNameAttr()),
                           rewriter.getStringAttr(transpose_name));
  SmallVector<Value, 4> operands;
  operands.push_back(val);
  TransposeOp transpose_op = rewriter.create<oneflow::TransposeOp>(
      op.getLoc(), getNCHWType(val.getType()), operands, transpose_attributes);
  return transpose_op;
}

bool IsInsertTransposeOpBefore(NCHWCompatible op, PatternRewriter& rewriter) {
  bool insert_transpose_op_flag = false;
  for (mlir::Value operand : op->getOperands()) {
    TransposeOp transposeInputOp = operand.getDefiningOp<TransposeOp>();
    if (!transposeInputOp) continue;
    const auto perm = transposeInputOp.getPermAttr();
    if (perm.size() == 4 && perm[0] == rewriter.getSI32IntegerAttr(0)
        && perm[1] == rewriter.getSI32IntegerAttr(3) && perm[2] == rewriter.getSI32IntegerAttr(1)
        && perm[3] == rewriter.getSI32IntegerAttr(2)) {
      insert_transpose_op_flag = true;
      break;
    }
  }
  return insert_transpose_op_flag;
}

getChannelLastTransposePerm and getChannelFirstTransposePerm return the permutations for NCHW->NHWC, (0, 2, 3, 1), and NHWC->NCHW, (0, 3, 1, 2), respectively. getInputOperandTransposeOp creates a Transpose Operation for one input of the Operation: it builds a TransposeOp from the input value, the Transpose attributes, and the rewriter, and returns its results. Similarly, getResultTransposeOp creates a Transpose Operation for one output of the Operation: it builds a TransposeOp from the output value, the Transpose attributes, and the rewriter, and returns that Operation. IsInsertTransposeOpBefore checks whether any input of the Operation is produced by a Transpose Operation that converts NHWC to NCHW; it returns true if so and false otherwise.
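The two permutations are inverses of each other, which is what makes an adjacent pair of these Transposes removable. A minimal sketch that checks this by composing them is shown below. `composePerm` and `isIdentity` are hypothetical helpers, and the code assumes the convention that a transpose with permutation p maps output dimension i to input dimension p[i].

```cpp
#include <array>
#include <cassert>

using Perm4 = std::array<int, 4>;

// Applying `first` and then `second` equals a single transpose whose
// permutation is perm[i] = first[second[i]].
Perm4 composePerm(const Perm4& first, const Perm4& second) {
  Perm4 out{};
  for (int i = 0; i < 4; ++i) {
    assert(second[i] >= 0 && second[i] < 4);  // must index a valid dimension
    out[i] = first[second[i]];
  }
  return out;
}

// True when the permutation leaves every dimension in place.
bool isIdentity(const Perm4& perm) {
  for (int i = 0; i < 4; ++i) {
    if (perm[i] != i) return false;
  }
  return true;
}
```

Composing (0, 2, 3, 1) with (0, 3, 1, 2) in either order gives (0, 1, 2, 3), while applying (0, 2, 3, 1) twice does not, so only opposite pairs cancel.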

0x1.3 Eliminate redundant Transpose pairs

Next, we need to eliminate as many adjacent Transpose pairs as possible from the graph into which Transpose Ops have been inserted. The implementation is as follows:

bool IsRedundantTransposeMatch(ArrayAttr pre, ArrayAttr afe, mlir::PatternRewriter& rewriter) {
  const auto prePerm = pre.getValue().vec();
  const auto afePerm = afe.getValue().vec();
  if (prePerm.size() == 4 && afePerm.size() == 4) {
    // handle nchw->nhwc->nchw: (0, 2, 3, 1) -> (0, 3, 1, 2)
    if (prePerm[0] == afePerm[0] && prePerm[1] == afePerm[3] && prePerm[2] == afePerm[1]
        && prePerm[3] == afePerm[2] && prePerm[0] == rewriter.getSI32IntegerAttr(0)
        && prePerm[1] == rewriter.getSI32IntegerAttr(2)
        && prePerm[2] == rewriter.getSI32IntegerAttr(3)
        && prePerm[3] == rewriter.getSI32IntegerAttr(1))
      return true;
    // handle nhwc->nchw->nhwc: (0, 3, 1, 2) -> (0, 2, 3, 1)
    if (prePerm[0] == afePerm[0] && prePerm[1] == afePerm[2] && prePerm[2] == afePerm[3]
        && prePerm[3] == afePerm[1] && prePerm[0] == rewriter.getSI32IntegerAttr(0)
        && prePerm[1] == rewriter.getSI32IntegerAttr(3)
        && prePerm[2] == rewriter.getSI32IntegerAttr(1)
        && prePerm[3] == rewriter.getSI32IntegerAttr(2))
      return true;
  }
  return false;
}

struct AutoNhwcEliminateRedundantTransposePattern : public mlir::OpRewritePattern<TransposeOp> {
  explicit AutoNhwcEliminateRedundantTransposePattern(mlir::MLIRContext* context)
      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
  mlir::LogicalResult matchAndRewrite(TransposeOp op,
                                      mlir::PatternRewriter& rewriter) const override {
    mlir::Value transposeInput = op.getOperand();
    TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();

    if (!transposeInputOp
        || !IsRedundantTransposeMatch(op.getPermAttr(), transposeInputOp.getPermAttr(), rewriter)) {
      return failure();
    }
    rewriter.replaceOp(op, {transposeInputOp.getOperand()});
    return success();
  }
};

The IsRedundantTransposeMatch method checks whether two consecutive Transpose Operations cancel each other by comparing their perm attributes. If the sequence is NHWC->NCHW->NHWC or NCHW->NHWC->NCHW, the pair is redundant. Like AutoNhwcPattern, AutoNhwcEliminateRedundantTransposePattern is a rewrite pattern; it inherits from OpRewritePattern and matches TransposeOp. Its matchAndRewrite method first takes the input of the TransposeOp and checks whether that input is itself produced by a TransposeOp. If it is not, or if the two Transposes are not redundant, failure is returned. Otherwise the method replaces the TransposeOp with the input of the preceding TransposeOp, which eliminates the pair, and returns success.
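The combined effect of greedy insertion and pair elimination can be seen on the small x->conv->relu->conv->relu->out network from the background section. The sketch below models the network as a linear list of op names and cancels adjacent inverse Transposes. It is a toy model, not the MLIR pattern driver; the names `to_nhwc` (for transpose(0, 2, 3, 1)) and `to_nchw` (for transpose(0, 3, 1, 2)) and both functions are hypothetical.

```cpp
#include <cassert>
#include <string>
#include <vector>

// "to_nhwc" stands for transpose(0, 2, 3, 1), "to_nchw" for transpose(0, 3, 1, 2).
bool isInversePair(const std::string& a, const std::string& b) {
  return (a == "to_nhwc" && b == "to_nchw") || (a == "to_nchw" && b == "to_nhwc");
}

// Removes adjacent inverse Transpose pairs from a linear chain of ops.
std::vector<std::string> eliminateRedundantTransposes(const std::vector<std::string>& chain) {
  std::vector<std::string> out;
  for (const std::string& op : chain) {
    if (!out.empty() && isInversePair(out.back(), op)) {
      out.pop_back();  // the two Transposes cancel, so drop both
    } else {
      out.push_back(op);
    }
  }
  return out;
}
```

If every conv and relu is wrapped as to_nhwc, op, to_nchw, the chain contains eight Transposes. After elimination only the leading to_nhwc and the trailing to_nchw remain, which matches the simplified network described in the background section.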

Finally, the two Patterns introduced above are wrapped in AutoNhwcPass, which is applied to the MLIR computation graph to complete the global optimization. As the code below shows, the optimization only takes effect when the ONEFLOW_MLIR_PREFER_NHWC environment variable is turned on.

void populateAutoNhwcPatterns(::mlir::RewritePatternSet& patterns) {
  bool enable_nhwc = ::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_PREFER_NHWC", false);
  if (enable_nhwc) {
    patterns.add<AutoNhwcPattern>(patterns.getContext());
    patterns.add<AutoNhwcEliminateRedundantTransposePattern>(patterns.getContext());
  }
}

class AutoNhwcPass : public AutoNhwcPassBase<AutoNhwcPass> {
  void runOnOperation() override {
    Operation* op = getOperation();
    RewritePatternSet patterns(op->getContext());
    oneflow::populateAutoNhwcPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
  }
};

Supplement: 0x1.4 Eliminating the Transpose of weights

We should also briefly explain how the Transpose of weights is handled. In 0x1.2 we also inserted a Transpose Op for each weight (a constant op). Because the weight is constant, this Transpose Op can be folded at compile time. This is done at https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/oneflow-translate/lib/OneFlow/MLIROneFlowTranslation.cpp#L808-L811. We will look at the implementation of constant folding in a later article.
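To illustrate what folding the weight Transpose amounts to, the sketch below reorders a dense convolution weight from [O, I, H, W] to [O, H, W, I] once, ahead of time. This is not OneFlow's constant folding implementation, which operates on MLIR constant ops; `foldWeightTranspose` is a hypothetical helper, and the row-major dense layout is an assumption of the sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reorders a row-major conv weight from [O, I, H, W] to [O, H, W, I], i.e. the
// effect of transpose(0, 2, 3, 1), so no Transpose op is needed at run time.
std::vector<float> foldWeightTranspose(const std::vector<float>& w, std::size_t O, std::size_t I,
                                       std::size_t H, std::size_t W) {
  assert(w.size() == O * I * H * W);
  std::vector<float> out(w.size());
  for (std::size_t o = 0; o < O; ++o) {
    for (std::size_t i = 0; i < I; ++i) {
      for (std::size_t h = 0; h < H; ++h) {
        for (std::size_t x = 0; x < W; ++x) {
          const std::size_t src = ((o * I + i) * H + h) * W + x;  // index in OIHW
          const std::size_t dst = ((o * H + h) * W + x) * I + i;  // index in OHWI
          out[dst] = w[src];
        }
      }
    }
  }
  return out;
}
```

Because the reordering happens once when the graph is compiled, the inference graph only sees the already-converted weight, which is why the weights add no Transpose overhead to the benchmark.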

0x2. Conclusion

This article introduced the Layout Transform in the OneFlow compiler. This technique also played an important role in the later OneFlow version of Stable Diffusion, improving its inference speed. TVM's Ansor has a similar optimization: different layouts are treated as part of an Op's strategy and affect its schedule, so the Layout Transform is considered during the search, giving a larger search space and better results. The approach described here is not the only way to deal with the extra overhead of Transpose; it is simply one that I personally find relatively simple, and readers with similar needs are free to use it.

Origin blog.csdn.net/just_sort/article/details/130738580