Common Subexpression Elimination and Dead Code Elimination in a Deep Learning Compiler

0x0. Preamble

More deep learning compiler notes can be found at https://github.com/BBuf/tvm_mlir_learn. I also maintain a CUDA learning repository at https://github.com/BBuf/how-to-optim-algorithm-in-cuda and a repository on how to learn deep learning frameworks (PyTorch and OneFlow) at https://github.com/BBuf/how-to-learn-deep-learning-framework; feel free to give them a star if they are useful to you. A series of articles on LLM training and inference is collected under https://github.com/BBuf/how-to-optim-algorithm-in-cuda/tree/master/large-language-model-note.

TL;DR: Last time I introduced the Layerout Transform optimization of the deep learning compiler. In that article I mentioned that I would also cover the implementation of the constant folding Pass, but before doing so I would like to introduce a similar optimization: Common Subexpression Elimination (CSE). As before, I use the CSE Pass implemented in OneFlow on top of MLIR as the example. While analyzing the code I found that the MLIR-based common subexpression elimination also implements dead code elimination along the way. In addition, when eliminating common subexpressions, two duplicate operations can only be merged if they are in the same basic block and there is no other side-effecting operation between them. In OneFlow's implementation, only the OneFlow-specific attributes of UserOp, namely OpName and SymbolID, are erased and replaced with magic values, because these two attributes should not prevent common subexpression elimination. This optimization is quite useful and plays a significant role in OneFlow's Stable Diffusion optimization.

0x1. Effect

Common subexpression elimination does something very simple: it folds identical expressions into a single one to avoid redundant computation. Let's illustrate with the two tests OneFlow wrote for the CSE Pass. They live in https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/test/OneFlow/cse.mlir, which provides an MLIR Module containing two functions: @Cast_1__FUSE__ScalarMulByTensor_2 and @f2.

The first function, @Cast_1__FUSE__ScalarMulByTensor_2, takes a tensor of shape 96x96xi64 as input and performs two identical cast operations that convert the input into 96x96xf32 tensors. It then adds the two resulting tensors with the oneflow.add_n operation and returns the resulting 96x96xf32 tensor. The FileCheck directives verify the existence of the "oneflow.cast" op and of an "oneflow.add_n2" op whose op_name attribute is "ScalarMulByTensor_2". A quick note on the CHECK syntax: a line such as CHECK: %[[OUT:[a-zA-Z0-9_]+]] = "oneflow.cast" is a FileCheck directive used to verify that the generated code matches expectations. FileCheck is part of the LLVM project and provides pattern-matching facilities for compiler tests. %[[OUT:[a-zA-Z0-9_]+]] is a capture group matching a string that starts with % followed by letters, digits, or underscores; this string corresponds to an SSA value name in MLIR. "oneflow.cast" means we expect to find an operation named "oneflow.cast".

The second function, @f2, takes three input tensors: one of shape 2x64x64x320xf16, one of shape 320x320x3x3xf16, and one of shape 320xf16. It applies the same transpose to the second input tensor twice and performs two conv2d operations using the transposed tensors, the first input tensor, and the third input tensor. The function returns two result tensors of shape 2x64x64x320xf16. The FileCheck directives verify the existence of the "oneflow.conv2d" operation with a scope_symbol_id attribute equal to 163 and check the two returned tensors.

These two functions have one thing in common: each of them contains a pair of identical Ops. After building OneFlow, we can use the following command to add the CSE Passes to the oneflow-opt pass pipeline and transform the MLIR expression, then focus on the transformed IR. The command is as follows:

oneflow/build/oneflow/ir/bin/oneflow-opt oneflow/oneflow/ir/test/OneFlow/cse.mlir -cse-with-attributes-ignored -cse -cse-put-attributes -canonicalize

To explain a few options here:

  • cse-with-attributes-ignored: This option tells the optimizer to hide the OneFlow IR-specific attributes that would otherwise prevent Common Subexpression Elimination (here OpName and SymbolID) before CSE runs.
  • cse: This option turns on common subexpression elimination (CSE). CSE is a compiler optimization technique that removes redundant subexpressions, reducing the amount of computation and speeding up program execution.
  • cse-put-attributes: This option instructs the optimizer to put the original attributes back onto the operations after CSE has run. This ensures that the operations' attribute information is preserved across the optimization. (It also implies that the original attributes must be saved somewhere while CSE runs.)
  • canonicalize: This option enables canonicalization. Canonicalization rewrites operations and expressions into a standard form, which simplifies subsequent optimizations and improves their effectiveness. (For the two examples given here, turning canonicalize off does not change the output IR.)

Next is the MLIR Module output after running the above command.

module {
  func.func @Cast_1__FUSE__ScalarMulByTensor_2(%arg0: tensor<96x96xi64>) -> tensor<96x96xf32> {
    %0 = "oneflow.cast"(%arg0) {device_name = ["0:0"], device_tag = "cpu", dtype = 2 : i32, hierarchy = [1], op_name = "Cast_1", op_type_name = "cast", pin_memory = false, scope_symbol_id = 4611686018427416574 : i64} : (tensor<96x96xi64>) -> tensor<96x96xf32>
    %1 = "oneflow.add_n2"(%0, %0) {device_name = ["0:0"], device_tag = "cpu", hierarchy = [1], op_name = "ScalarMulByTensor_2", op_type_name = "add_n", scope_symbol_id = 4611686018427416574 : i64} : (tensor<96x96xf32>, tensor<96x96xf32>) -> tensor<96x96xf32>
    return %1 : tensor<96x96xf32>
  }
  func.func @f2(%arg0: tensor<2x64x64x320xf16>, %arg1: tensor<320x320x3x3xf16>, %arg2: tensor<320xf16>) -> (tensor<2x64x64x320xf16>, tensor<2x64x64x320xf16>) {
    %0 = "oneflow.transpose"(%arg1) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.resnets.0.conv1-conv2d-31_transpose_input_1", perm = [0 : si32, 2 : si32, 3 : si32, 1 : si32], scope_symbol_id = 163 : i64} : (tensor<320x320x3x3xf16>) -> tensor<320x3x3x320xf16>
    %1 = "oneflow.conv2d"(%arg0, %0, %arg2) {data_format = "channels_last", device_name = ["@0:0"], device_tag = "cuda", dilation_rate = [1 : si32, 1 : si32], filters = 320 : si32, groups = 1 : si32, hierarchy = [1], kernel_size = [3 : si32, 3 : si32], op_name = "unet.down_blocks.0.resnets.0.conv1-conv2d-31", operand_segment_sizes = array<i32: 1, 1, 1, 0>, padding_before = [1 : si32, 1 : si32], scope_symbol_id = 163 : i64, strides = [1 : si32, 1 : si32], tuning_cache = ""} : (tensor<2x64x64x320xf16>, tensor<320x3x3x320xf16>, tensor<320xf16>) -> tensor<2x64x64x320xf16>
    return %1, %1 : tensor<2x64x64x320xf16>, tensor<2x64x64x320xf16>
  }
}

Compared with the original MLIR ModuleOp, only one of each pair of common subexpressions (the cast and the transpose) is retained in these two functions, which is exactly what common subexpression elimination is meant to achieve. In the OneFlow compiler, this optimization was first introduced for OneFlow's Stable Diffusion, where it accelerated model inference.

0x2. Principle & Code Implementation

The idea behind OneFlow's CSE is that the OpName and SymbolID attributes of OneFlow's UserOp have to be ignored. These two attributes have no effect on whether two operations compute the same value, but they are added by the OneFlow system and differ between otherwise identical operations, so a preprocessing pass must hide these two inconsistencies. Then, after the CSE Pass provided by MLIR has run, a postprocessing pass must put the ignored attributes back. Only then can the optimized IR be converted back to an OneFlow graph and executed correctly.

First, two CSE-related Pass classes are defined with ODS in https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/include/OneFlow/OneFlowPasses.td#L156-L172, and MLIR automatically generates the base definitions of these two Passes from them. Let's take a closer look at the details:

def CSEWithAttributesIgnored : Pass<"cse-with-attributes-ignored", "ModuleOp"> {
  // Defines a Pass named "cse-with-attributes-ignored" that runs on MLIR module operations (ModuleOp).
  // summary and description give a short and a detailed explanation of what the Pass does:
  // it prepares for CSE by ignoring OneFlow attributes such as the op name and symbol id.
  let summary = "ignore oneflow attributes to have cse work";
  let description = [{
    cse and ignore oneflow attributes like op name, symbol id, etc.
  }];
  // constructor specifies the function used to create this Pass, i.e. mlir::oneflow::createCSEWithAttributesIgnored().
  let constructor = "mlir::oneflow::createCSEWithAttributesIgnored()";
  // dependentDialects lists the dialects this Pass depends on; here it is empty, i.e. no dependencies.
  let dependentDialects = [];
}

def CSEPutAttributes : Pass<"cse-put-attributes", "ModuleOp"> {
  let summary = "cse and ignore oneflow attributes";
  let description = [{
    put back oneflow attributes like op name, symbol id, etc.
  }];
  let constructor = "mlir::oneflow::createCSEPutAttributes()";
  let dependentDialects = [];
}

As we can see, the pre- and post-processing Passes of CSE boil down to the two functions createCSEWithAttributesIgnored and createCSEPutAttributes. They are declared in https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/include/OneFlow/Transform/CSEWithAttributesIgnored.h#L25-L33:

// CSEState has two members:
// scopeSymbolIDs: an llvm::DenseMap mapping Operation* pointers to IntegerAttr attributes,
//   used to remember each operation's scope symbol id.
// opNames: an llvm::DenseMap mapping Operation* pointers to StringAttr attributes,
//   used to remember each operation's name.
struct CSEState {
  llvm::DenseMap<Operation*, IntegerAttr> scopeSymbolIDs;
  llvm::DenseMap<Operation*, StringAttr> opNames;
};
// Returns a std::unique_ptr<mlir::Pass>; as the name suggests, it creates the pre-processing
// CSE Pass that erases (ignores) the OneFlow attributes.
std::unique_ptr<mlir::Pass> createCSEWithAttributesIgnored();
// Also returns a std::unique_ptr<mlir::Pass>; it creates the post-processing Pass that puts the
// attributes back.
std::unique_ptr<mlir::Pass> createCSEPutAttributes();
// Takes a std::shared_ptr<CSEState> and returns a std::pair of two std::unique_ptr<Pass> objects:
// a pair of CSE Passes that share the given CSEState.
std::pair<std::unique_ptr<Pass>, std::unique_ptr<Pass>> createCSEPasses(
    std::shared_ptr<CSEState> state);
// Takes a std::shared_ptr<CSEState> and registers a group of CSE Passes sharing that CSEState.
void registerCSEPasses(std::shared_ptr<CSEState> state);
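
Before diving into the implementations, here is a minimal sketch of how such a pass pair might be wired around MLIR's built-in CSE in a C++ pass pipeline, mirroring the oneflow-opt command from section 0x1. It is an illustration only: buildCSEPipeline is a made-up helper name, the include path and the mlir::oneflow namespace are assumed from the ODS constructor strings above, and the order of the pair returned by createCSEPasses (erase first, put back second) is assumed; mlir::createCSEPass() and mlir::createCanonicalizerPass() are the standard MLIR pass constructors.

#include <memory>
#include <utility>
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "OneFlow/Transform/CSEWithAttributesIgnored.h"  // assumed include path

// Hypothetical helper: erase the OneFlow-specific attributes, run MLIR's
// generic CSE, restore the attributes, then canonicalize.
void buildCSEPipeline(mlir::PassManager& pm) {
  auto state = std::make_shared<mlir::oneflow::CSEState>();
  auto passes = mlir::oneflow::createCSEPasses(state);  // pair sharing one CSEState
  pm.addPass(std::move(passes.first));          // cse-with-attributes-ignored (assumed order)
  pm.addPass(mlir::createCSEPass());            // the generic -cse pass
  pm.addPass(std::move(passes.second));         // cse-put-attributes
  pm.addPass(mlir::createCanonicalizerPass());  // -canonicalize
}

The important design point is that both OneFlow passes hold the same shared CSEState, so whatever the erase pass records can be read back by the put-back pass after MLIR's CSE has run in between.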

Next, let's look at the concrete implementation of these Passes. The code is in https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/lib/OneFlow/Transform/CSEWithAttributesIgnored.cpp.

First look at createCSEWithAttributesIgnored:


struct EraseAttributes : public mlir::OpInterfaceRewritePattern<UserOpCompatible> {
  explicit EraseAttributes(mlir::MLIRContext* context, std::shared_ptr<CSEState> state)
      : OpInterfaceRewritePattern<UserOpCompatible>(context, /*benefit=*/1), state_{state} {}
  mlir::LogicalResult matchAndRewrite(UserOpCompatible op,
                                      mlir::PatternRewriter& rewriter) const override {
    if (op->getAttrOfType<StringAttr>(OpTrait::IsOpConfCompatible<void>::getOpNameAttr())
            .getValue()
            .str()
        != MAGIC_OP_NAME) {
      if (state_) {
        state_->opNames[op] =
            op->getAttrOfType<StringAttr>(OpTrait::IsOpConfCompatible<void>::getOpNameAttr());
        state_->scopeSymbolIDs[op] = op->getAttrOfType<IntegerAttr>(
            OpTrait::IsOpConfCompatible<void>::getScopeSymbolIDAttr());
      }
      op->setAttr(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(),
                  rewriter.getStringAttr(MAGIC_OP_NAME));
      op->setAttr(OpTrait::IsOpConfCompatible<void>::getScopeSymbolIDAttr(),
                  rewriter.getI64IntegerAttr(MAGIC_SCOPE_SYMBOL_ID));
      return success();
    } else {
      return failure();
    }
  }

 private:
  std::shared_ptr<CSEState> state_;
};

class CSEWithAttributesIgnored : public CSEWithAttributesIgnoredBase<CSEWithAttributesIgnored> {
 public:
  explicit CSEWithAttributesIgnored() {}
  explicit CSEWithAttributesIgnored(std::shared_ptr<CSEState> state) : state_(state) {}
  void runOnOperation() override {
    Operation* op = getOperation();
    RewritePatternSet patterns(op->getContext());
    patterns.add<EraseAttributes>(op->getContext(), state_);
    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
  }

 private:
  std::shared_ptr<CSEState> state_;
};

std::unique_ptr<Pass> createCSEWithAttributesIgnored() {
  return std::make_unique<CSEWithAttributesIgnored>();
}

This code defines an EraseAttributes rewrite pattern that removes certain attributes from an op. It inherits from OpInterfaceRewritePattern, which means it matches any op implementing the UserOpCompatible interface. The EraseAttributes constructor accepts an MLIRContext* and a shared_ptr<CSEState>; the CSEState is used to remember the attributes of the ops that have been rewritten. The matchAndRewrite method checks whether the op has a StringAttr attribute named OpNameAttr, and if it does and its value is not equal to MAGIC_OP_NAME, the method will:

  • Record the OpNameAttr and ScopeSymbolIDAttr attributes of op in CSEState.
  • Set OpNameAttr to MAGIC_OP_NAME and ScopeSymbolIDAttr to MAGIC_SCOPE_SYMBOL_ID.

Then CSEWithAttributesIgnored, which inherits from CSEWithAttributesIgnoredBase, overrides its runOnOperation method. The method instantiates a RewritePatternSet, adds the EraseAttributes pattern to it, and applies the patterns greedily to strip the attributes from the user ops. It also holds a shared_ptr to the CSEState, which EraseAttributes uses to record what it erased. Note that CSEWithAttributesIgnoredBase is the Pass base class automatically generated from the ODS definition. The createCSEWithAttributesIgnored function simply creates a CSEWithAttributesIgnored pass and returns it.

Next, let's look at the implementation of createCSEPutAttributes:

struct PutAttributes : public mlir::OpInterfaceRewritePattern<UserOpCompatible> {
  explicit PutAttributes(mlir::MLIRContext* context, std::shared_ptr<CSEState> state)
      : OpInterfaceRewritePattern<UserOpCompatible>(context, /*benefit=*/1), state_{state} {}
  mlir::LogicalResult matchAndRewrite(UserOpCompatible op,
                                      mlir::PatternRewriter& rewriter) const override {
    if (op->getAttrOfType<StringAttr>(OpTrait::IsOpConfCompatible<void>::getOpNameAttr())
            .getValue()
            .str()
        == MAGIC_OP_NAME) {
      if (state_) {
        op->setAttr(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(), state_->opNames[op]);
        op->setAttr(OpTrait::IsOpConfCompatible<void>::getScopeSymbolIDAttr(),
                    state_->scopeSymbolIDs[op]);
      }
      return success();
    } else {
      return failure();
    }
  }

 private:
  std::shared_ptr<CSEState> state_;
};

class CSEPutAttributes : public CSEPutAttributesBase<CSEPutAttributes> {
 public:
  explicit CSEPutAttributes() {}
  explicit CSEPutAttributes(std::shared_ptr<CSEState> state) { state_ = state; }

  void runOnOperation() override {
    Operation* op = getOperation();
    RewritePatternSet patterns(op->getContext());
    patterns.add<PutAttributes>(op->getContext(), state_);
    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
  }

 private:
  std::shared_ptr<CSEState> state_;
};

std::unique_ptr<Pass> createCSEPutAttributes() { return std::make_unique<CSEPutAttributes>(); }

The PutAttributes pattern is the inverse of EraseAttributes: it restores the previously erased attributes onto the op. Its constructor also accepts an MLIRContext* and a shared_ptr<CSEState>, and it uses the CSEState to look up the attribute values that were erased earlier. The matchAndRewrite method checks whether the op's StringAttr attribute OpNameAttr equals MAGIC_OP_NAME. If so, it looks up the original OpNameAttr and ScopeSymbolIDAttr values in the CSEState and sets both attributes back to their original values.

The two passes above are only the pre- and post-processing in OneFlow; the real work is done by the CSE Pass that ships with MLIR (oneflow/build/oneflow/ir/llvm_monorepo-src/mlir/lib/Transforms/CSE.cpp). Let's analyze it.

struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
  static unsigned getHashValue(const Operation *opC) {
    return OperationEquivalence::computeHash(
        const_cast<Operation *>(opC),
        /*hashOperands=*/OperationEquivalence::directHashValue,
        /*hashResults=*/OperationEquivalence::ignoreHashValue,
        OperationEquivalence::IgnoreLocations);
  }
  static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
    auto *lhs = const_cast<Operation *>(lhsC);
    auto *rhs = const_cast<Operation *>(rhsC);
    if (lhs == rhs)
      return true;
    if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
        rhs == getTombstoneKey() || rhs == getEmptyKey())
      return false;
    return OperationEquivalence::isEquivalentTo(
        const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
        OperationEquivalence::IgnoreLocations);
  }
};

The SimpleOperationInfo struct inherits from llvm::DenseMapInfo<Operation *>. Its purpose is to provide custom hash and equality functions for Operation objects used as keys in an LLVM DenseMap. It overrides two methods:

  • getHashValue: Computes the hash value for an Operation*. It uses OperationEquivalence::computeHash and passes hashOperands=directHashValue and hashResults=ignoreHashValue, which means the hash is computed directly over the op's operands while its results are ignored.
  • isEqual: Checks whether two Operation* are equal. It first checks whether they are the same op and returns true if so. Otherwise, it checks whether the two ops are equivalent using OperationEquivalence::isEquivalentTo, passing IgnoreLocations so that the ops' source location information is ignored.

So this DenseMapInfo makes it possible to use Operation* as a DenseMap key while ignoring results and source locations: the operands (together with the rest of the operation's structure considered by OperationEquivalence, including its attributes) drive the equality checks and hash values, which is exactly why OneFlow has to hide OpName and SymbolID first.
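
To make the effect of this concrete, here is a tiny sketch (not from the MLIR sources) of how a map keyed this way treats two structurally equivalent operations as the same key. wouldFoldTogether is a made-up helper name, and it assumes the SimpleOperationInfo definition above is visible; the real pass uses a ScopedHashTable rather than a plain DenseMap, but the key behaviour is the same.

#include "llvm/ADT/DenseMap.h"
#include "mlir/IR/Operation.h"

// Illustrative only: report whether `later` would hit the entry recorded for
// `earlier`, i.e. whether the two ops hash and compare as equivalent keys
// (same structure and operands; results and locations are ignored).
static bool wouldFoldTogether(mlir::Operation *earlier, mlir::Operation *later) {
  llvm::DenseMap<mlir::Operation *, mlir::Operation *, SimpleOperationInfo> knownValues;
  knownValues.try_emplace(earlier, earlier);
  return knownValues.lookup(later) == earlier;
}

If wouldFoldTogether returns true for two ops, CSE can redirect the uses of the later op to the earlier one and mark the later one for erasure.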

/// Simple common sub-expression elimination.
// The CSE struct implements simple common subexpression elimination, a compiler optimization
// technique that removes redundant computation from a program to improve execution efficiency.
struct CSE : public impl::CSEBase<CSE> {
  /// Shared implementation of operation elimination and scoped map definitions.
  // AllocatorTy and ScopedMapTy define the allocator and the scoped map. ScopedMapTy is a
  // scoped hash table used to record the mapping between operations.
  using AllocatorTy = llvm::RecyclingAllocator<
      llvm::BumpPtrAllocator,
      llvm::ScopedHashTableVal<Operation *, Operation *>>;
  using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
                                            SimpleOperationInfo, AllocatorTy>;

  /// Cache holding MemoryEffects information between two operations. The first
  /// operation is stored as the key. The second operation is stored inside a
  /// pair in the value. The pair also holds the MemoryEffects between those
  /// two operations. If the MemoryEffects is nullptr then we assume there is
  /// no operation with MemoryEffects::Write between the two operations.
  // MemEffectsCache caches MemoryEffects information between pairs of operations;
  // MemoryEffects describes how an operation affects memory.
  using MemEffectsCache =
      DenseMap<Operation *, std::pair<Operation *, MemoryEffects::Effect *>>;

  /// Represents a single entry in the depth first traversal of a CFG.
  // CFGStackNode represents one node of the depth-first traversal of the control flow graph
  // (CFG): the scope, the dominator tree node, the child iterator, and so on.
  struct CFGStackNode {
    CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
        : scope(knownValues), node(node), childIterator(node->begin()) {}

    /// Scope for the known values.
    ScopedMapTy::ScopeTy scope;

    DominanceInfoNode *node;
    DominanceInfoNode::const_iterator childIterator;

    /// If this node has been fully processed yet or not.
    bool processed = false;
  };

  /// Attempt to eliminate a redundant operation. Returns success if the
  /// operation was marked for removal, failure otherwise.
  // simplifyOperation tries to eliminate a redundant operation; it returns success if the
  // operation was marked for removal, failure otherwise.
  LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
                                  bool hasSSADominance);
  // simplifyBlock simplifies the given basic block (Block).
  void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
  // simplifyRegion simplifies the given region (Region).
  void simplifyRegion(ScopedMapTy &knownValues, Region &region);

  // runOnOperation overrides the base class method and drives the CSE optimization.
  void runOnOperation() override;

private:
  // replaceUsesAndDelete replaces the uses of an operation and deletes it.
  void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
                            Operation *existing, bool hasSSADominance);

  /// Check if there is side-effecting operations other than the given effect
  /// between the two operations.
  // hasOtherSideEffectingOpInBetween checks whether there are other side-effecting operations
  // between the two given operations.
  bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);

  /// Operations marked as dead and to be erased.
  // opsToErase stores the operations that are scheduled to be erased.
  std::vector<Operation *> opsToErase;
  // domInfo is a pointer to the dominance information (DominanceInfo).
  DominanceInfo *domInfo = nullptr;
  // memEffectsCache caches memory-effect information between operations.
  MemEffectsCache memEffectsCache;
};
} // namespace

Let's start with the core runOnOperation method.

void CSE::runOnOperation() {
  /// A scoped hash table of defining operations within a region.
  // knownValues is a scoped hash table storing the operations defined within a region.
  ScopedMapTy knownValues;

  // Fetch the dominance information from the DominanceInfo analysis and keep it in domInfo.
  domInfo = &getAnalysis<DominanceInfo>();
  // Get the current operation (rootOp) and run simplifyRegion on each of its regions.
  Operation *rootOp = getOperation();

  for (auto &region : rootOp->getRegions())
    simplifyRegion(knownValues, region);

  // If opsToErase (the operations to delete) is empty, nothing was removed, so all analyses stay valid.
  // If no operations were erased, then we mark all analyses as preserved.
  if (opsToErase.empty())
    return markAllAnalysesPreserved();

  /// Erase any operations that were marked as dead during simplification.
  // Otherwise, walk opsToErase, erase each operation in it, and then clear the vector.
  for (auto *op : opsToErase)
    op->erase();
  opsToErase.clear();

  // We currently don't remove region operations, so mark dominance as
  // preserved.
  // Since this code never removes region operations, dominance (DominanceInfo) and post-dominance
  // (PostDominanceInfo) information is marked as preserved, and domInfo is reset to nullptr.
  markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
  domInfo = nullptr;
}

Here the pass first obtains the dominance information of the Regions in the current ModuleOp; it is needed for the dominator-tree traversal below, and after the CSE has deleted ops the pass reports which analyses remain valid. The core is the simplifyRegion function, which implements the CSE details: it walks the basic blocks of a region along the dominator tree and calls simplifyBlock() to simplify each basic block.

// The function takes a ScopedMapTy reference knownValues and a Region reference region as parameters.
void CSE::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
  // If the region is empty there is nothing to do.
  if (region.empty())
    return;
  // Check whether the region has SSA dominance (static single assignment dominance) and store
  // the result in hasSSADominance.
  bool hasSSADominance = domInfo->hasSSADominance(&region);

  // If the region only contains one block, then simplify it directly.
  // Create a ScopedMapTy::ScopeTy named scope and call simplifyBlock() on that single block.
  if (region.hasOneBlock()) {
    ScopedMapTy::ScopeTy scope(knownValues);
    simplifyBlock(knownValues, &region.front(), hasSSADominance);
    return;
  }

  // If the region does not have dominanceInfo, then skip it.
  // TODO: Regions without SSA dominance should define a different
  // traversal order which is appropriate and can be used here.
  // If the region has no SSA dominance (hasSSADominance is false), skip it. The TODO notes that
  // regions without SSA dominance would need a different, suitable traversal order.
  if (!hasSSADominance)
    return;

  // Note, deque is being used here because there was significant performance
  // gains over vector when the container becomes very large due to the
  // specific access patterns. If/when these performance issues are no
  // longer a problem we can change this to vector. For more information see
  // the llvm mailing list discussion on this:
  // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
  // stack is a std::deque of std::unique_ptr<CFGStackNode>; deque is used because it performs
  // better than vector once the container becomes very large.
  std::deque<std::unique_ptr<CFGStackNode>> stack;

  // Process the nodes of the dom tree for this region.
  // Push the root node of the region's dominator tree onto the stack.
  stack.emplace_back(std::make_unique<CFGStackNode>(
      knownValues, domInfo->getRootNode(&region)));
  // While the stack is not empty, repeat the following:
  while (!stack.empty()) {
    // Take the node at the top of the stack (currentNode).
    auto &currentNode = stack.back();

    // Check to see if we need to process this node.
    // If the current node has not been processed yet, mark it as processed and call
    // simplifyBlock() on the basic block it belongs to.
    if (!currentNode->processed) {
      currentNode->processed = true;
      simplifyBlock(knownValues, currentNode->node->getBlock(),
                    hasSSADominance);
    }

    // Otherwise, check to see if we need to process a child node.
    // If the current node's child iterator has not reached the end, push the next child onto the stack.
    if (currentNode->childIterator != currentNode->node->end()) {
      auto *childNode = *(currentNode->childIterator++);
      stack.emplace_back(
          std::make_unique<CFGStackNode>(knownValues, childNode));
    } else {
      // Finally, if the node and all of its children have been processed
      // then we delete the node.
      // Once the node and all of its children have been processed, pop it off the stack.
      stack.pop_back();
    }
  }
}

The execution flow of the function is explained in the comments. After this step, the concrete work of CSE actually happens in the simplifyBlock function, so let's keep following the code. This function takes a ScopedMapTy reference knownValues, a Block pointer bb, and a boolean hasSSADominance as parameters; as the name suggests, its purpose is to simplify the given basic block.

void CSE::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
                        bool hasSSADominance) {
  // Walk all operations (op) in the basic block bb.
  for (auto &op : *bb) {
    // Most operations don't have regions, so fast path that case.
    // If the operation does contain regions, do the following:
    if (op.getNumRegions() != 0) {
      // If this operation is isolated above, we can't process nested regions
      // with the given 'knownValues' map. This would cause the insertion of
      // implicit captures in explicit capture only regions.
      // If the operation has the IsIsolatedFromAbove trait, its nested regions cannot be
      // processed with the given knownValues map, because that could introduce implicit captures
      // into regions that only allow explicit captures. In that case, create a fresh
      // nestedKnownValues map and call simplifyRegion() on each region of the operation.
      if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
        ScopedMapTy nestedKnownValues;
        for (auto &region : op.getRegions())
          simplifyRegion(nestedKnownValues, region);
      } else {
        // Otherwise, process nested regions normally.
        // If the operation is not IsIsolatedFromAbove, process the nested regions normally by
        // calling simplifyRegion() on each region with the existing knownValues map.
        for (auto &region : op.getRegions())
          simplifyRegion(knownValues, region);
      }
    }
    // If the operation was simplified (simplifyOperation() returned success), skip any regions it
    // holds and continue with the next operation.
    // If the operation is simplified, we don't process any held regions.
    if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
      continue;
  }
  // Clear the MemoryEffects cache since its usage is by block only.
  // After all operations are processed, clear memEffectsCache because it is only valid within a
  // single basic block.
  memEffectsCache.clear();
}

Inside simplifyBlock, simplifyOperation is called to optimize each Operation, so let's follow this last function. Its parameters are the same as simplifyBlock's: a ScopedMapTy reference knownValues, an Operation pointer op, and a boolean hasSSADominance.

/// Attempt to eliminate a redundant operation.
LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
                                     bool hasSSADominance) {
  // Don't simplify terminator operations.
  // If the operation is a terminator (has the IsTerminator trait), don't simplify it.
  if (op->hasTrait<OpTrait::IsTerminator>())
    return failure();

  // If the operation is already trivially dead just add it to the erase list.
  // If the operation is already trivially dead, add it to the opsToErase list, bump the
  // dead-code-elimination counter, and return success.
  if (isOpTriviallyDead(op)) {
    opsToErase.push_back(op);
    ++numDCE;
    return success();
  }

  // Don't simplify operations with regions that have multiple blocks.
  // TODO: We need additional tests to verify that we handle such IR correctly.
  // Don't simplify operations whose regions contain more than one block; the TODO notes that more
  // tests are needed to verify such IR is handled correctly.
  if (!llvm::all_of(op->getRegions(), [](Region &r) {
        return r.getBlocks().empty() || llvm::hasSingleElement(r.getBlocks());
      }))
    return failure();

  // Some simple use case of operation with memory side-effect are dealt with
  // here. Operations with no side-effect are done after.
  // Simple cases of operations with memory side effects are handled first; side-effect-free
  // operations are handled afterwards.
  if (!isMemoryEffectFree(op)) {
    auto memEffects = dyn_cast<MemoryEffectOpInterface>(op);
    // TODO: Only basic use case for operations with MemoryEffects::Read can be
    // eliminated now. More work needs to be done for more complicated patterns
    // and other side-effects.
    // If the operation is not memory-effect free, try to get its MemoryEffectOpInterface.
    // If it has no MemoryEffectOpInterface, or it has effects other than MemoryEffects::Read,
    // return failure.
    if (!memEffects || !memEffects.onlyHasEffect<MemoryEffects::Read>())
      return failure();

    // Look for an existing definition for the operation.
    // Look up an existing definition for the operation. If one is found, it is in the same basic
    // block, and there is no other side-effecting operation between the two, the redundant
    // operation can be removed: call replaceUsesAndDelete() to replace its uses and delete it.
    if (auto *existing = knownValues.lookup(op)) {
      if (existing->getBlock() == op->getBlock() &&
          !hasOtherSideEffectingOpInBetween(existing, op)) {
        // The operation that can be deleted has been reach with no
        // side-effecting operations in between the existing operation and
        // this one so we can remove the duplicate.
        replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
        return success();
      }
    }
    // Otherwise, record the operation in the knownValues map and return failure.
    knownValues.insert(op, op);
    return failure();
  }

  // Look for an existing definition for the operation.
  // Look up an existing definition for the operation. If one is found, call
  // replaceUsesAndDelete() to replace the uses and delete the operation, bump the CSE counter,
  // and return success.
  if (auto *existing = knownValues.lookup(op)) {
    replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
    ++numCSE;
    return success();
  }

  // Otherwise, we add this operation to the known values map.
  // Otherwise, add this operation to the knownValues map and return failure.
  knownValues.insert(op, op);
  return failure();
}

We can see that simplifyOperation performs not only common subexpression elimination (CSE) but also dead code elimination (DCE). Moreover, when processing an Operation it takes the Operation's memory side effects into account, as well as whether the Operation owns regions with multiple basic blocks.
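
The "same block, no write in between" condition for ops that read memory deserves to be spelled out. Below is a simplified sketch of that check only, not the actual hasOtherSideEffectingOpInBetween from CSE.cpp (which additionally memoizes its answer in memEffectsCache); mayWriteBetween is a made-up name, and the loop simply walks the ops between the two candidates and conservatively treats any op with unknown or write effects as a blocker.

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

// Illustrative sketch: walk the ops strictly between `fromOp` and `toOp`
// (assumed to be in the same block, with fromOp first) and report whether any
// of them may write to memory, which would make merging two reads unsafe.
static bool mayWriteBetween(mlir::Operation *fromOp, mlir::Operation *toOp) {
  assert(fromOp->getBlock() == toOp->getBlock() && "expects ops in one block");
  for (mlir::Operation *cur = fromOp->getNextNode(); cur && cur != toOp;
       cur = cur->getNextNode()) {
    auto iface = llvm::dyn_cast<mlir::MemoryEffectOpInterface>(cur);
    if (!iface)
      return true;  // unknown effects: be conservative and assume a write
    llvm::SmallVector<mlir::MemoryEffects::EffectInstance> effects;
    iface.getEffects(effects);
    for (const auto &effect : effects)
      if (llvm::isa<mlir::MemoryEffects::Write>(effect.getEffect()))
        return true;
  }
  return false;
}

Only when this kind of check comes back false (and the two reads sit in the same block) is it safe for CSE to merge two operations that touch memory; purely side-effect-free operations skip the check entirely.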

0x3. Summary

While reading the code I found that the MLIR-based common subexpression elimination also implements dead code elimination along the way. In addition, when eliminating common subexpressions, two duplicate operations can only be merged if they are in the same basic block and there is no other side-effecting operation between them. In OneFlow's implementation, only the OneFlow-specific attributes of UserOp, namely OpName and SymbolID, are erased and replaced with magic values, because these two attributes should not prevent common subexpression elimination. This optimization is quite useful and plays a significant role in OneFlow's Stable Diffusion optimization.

0x4. Related links

  • Analysis of TVM's CSE Pass implementation: https://blog.csdn.net/Eurypterid/article/details/123118666

Origin: https://blog.csdn.net/just_sort/article/details/130907187