深層学習コンパイラーのレイヤーアウト変換の最適化

ディープ ラーニング コンパイラーの詳細については、https://github.com/BBuf/tvm_mlir_learn を参照してください。同時に、cuda 学習ウェアハウス https://github.com/BBuf/how-to-optim-algorithm-in-cuda と深層学習フレームワーク (PyTorch と OneFlow) の学習方法の学習ウェアハウスも維持しています。 https://github .com/BBuf/how-to-learn-deep-learning-framework 、困っている友達は小さな星をクリックしてください。LLM のトレーニングと推論に関連する一連の記事は、ディレクトリ https://github.com/BBuf/how-to-optim-algorithm-in-cuda/tree/master/large- language-model-note に集められています。

この記事の説明では、インターフェイスとインターフェイスが混在していますが、これら 2 つは同じものであり、両方とも MLIR のインターフェイスを表します。

0x0.背景

深層学習コンパイラの最適化作業の解釈に引き続き、この記事では、OneFlow システムで MLIR に基づく Layerout Transform を実装する方法を紹介します。2次元畳み込みニューラルネットワークでは、NCHWデータ形式の他にNHWCデータ形式が一般的であり、畳み込み演算の場合はNHWC形式で計算した方が性能が良い場合があります。ただし、深層学習ネットワークのトレーニングは通常、NCHW を使用して実行され、通常、推論中に NCHW から NHWC へのレイヤーアウト変換のみを実行します。ここには 2 つの問題があります: まず、Conv2D などの演算子の場合、NCHW でトレーニングされたときに保存される重み形式は [out_channels, in_channels, *kernel_size] ですが、NHWC 形式で推論するときに重み形式を変換する必要があります。 、重みのない演算子の場合、畳み込み演算子の前後に挿入される転置演算によって生じる余分なオーバーヘッドを減らすために、演算子が可能な限り NHWC 演算をサポートするようにする必要もあります。たとえば、次のような小規模ネットワーク x->conv->relu->conv->relu->out があるとします。これを NHWC 形式で実行する場合、2 つの畳み込みの重みを変更することに加えて、次のようにします。 conv 演算子に入力されるデータ形式を変更するには、conv の前後に transpose を挿入する必要もあります。つまり、x->transpose(0, 2, 3, 1)->conv->transpose(0, 3, 1, 2) ) -> relu -> transpose(0, 2, 3, 1) -> conv -> transpose(0, 3, 1, 2) -> relu -> out次に、注意深い読者は、ここに実際に多くの冗長な Transpose があることがわかるでしょう。ReLU は NHWC 形式での操作をサポートしているため、このネットワークはx->transpose(0, 2, 3,これにより、Transpose Op のオーバーヘッドが半分になります。

transpose を簡略化する理由は、transpose 演算子自体に実行やスケジューリングのオーバーヘッドがあるためで、transpose の回数を最小限に抑えないと、NHWC への切り替えによる計算高速化が transpose のオーバーヘッドでカバーされてしまう可能性があるためです。OneFlow に基づいて上記の Layerout Transform 最適化を実装し、テスト結果を以下に示します。

この最適化は V100 でテストされました。テスト コードは https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/test/OneFlow/auto_nhwc/test_resnet101_benchmark.py にあります。パフォーマンス結果は次のとおりです。以下に続きます:

  • nn.Graph の AMP オプションをオンにします。
  • ネットワークは ResNet101 を選択し、それに対して前向き推論を実行します。
バッチサイズ んち 自動nhwc
16 14秒 13秒
32 24時 22秒
64 44秒 38秒

BatchSize=64 の場合は 13.6% の加速が得られ、BatchSize が小さくなるにつれて加速率は低下しますが、一定の加速は常に維持されます。重みパラメータ部分は事前に転置されるため、この部分に追加のオーバーヘッドは発生しないことに注意してください。実際には定数折りという手法を使って完成させましたが、これについては次の記事で説明します。

0x1. 解析を実装する

実装に関しては、3 つの問題を解決する必要があります。1 つ目は、NHWC 操作をサポートする演算子を決定する方法、2 つ目は、Transpose 演算子を挿入すること、3 つ目は、冗長な Transpose ペアを削除することです。

0x1.1 インターフェイスに基づいて NHWC 操作をサポートするオペレーターを決定します

OneFlow では、Op で NHWC 計算をサポートする必要がある場合、Op の定義時に NCHWCompatibilityInterface を宣言するだけで済みます。畳み込みを例に挙げます。

def OneFlow_Conv2DOp : OneFlow_ConvolutionBaseOp<"conv2d", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>, DeclareOpInterfaceMethods<NCHWCompatibleInterface>]> {
    
    }

ここでの DeclareOpInterfaceMethods は、この Operator が、NCHW 形式と互換性のある Operator が実装する必要があるメソッドを定義する NCHWSupportInterface インターフェイスを実装することを意味します。

他の Op に NHWC 操作をサポートさせたい場合は、このインターフェイスを定義し、このインターフェイスのメンバー関数を書き直すだけで済みます。

def NCHWCompatibleInterface : OpInterface<"NCHWCompatible"> {
    
    
  let description = [{
    
    
    Interface of NCHW compatibility
  }];

  let methods = [
    InterfaceMethod<"",
        "bool", "IsNCHW", (ins)
    >,
    InterfaceMethod<"Create NHWC op and return the new op's results to be transposed",
        "llvm::SmallVector<mlir::Value, 4>", "NchwToNhwc", (ins "llvm::SmallVector<mlir::Value, 4>": $transposed_inputs, "PatternRewriter&": $rewriter)
    >,
    InterfaceMethod<"",
        "llvm::DenseSet<mlir::Value>", "OperandsToTranspose", (ins)
    >,
    InterfaceMethod<"",
        "llvm::DenseSet<mlir::Value>", "ResultsToTranspose", (ins)
    >,
  ];
  let cppNamespace = "::mlir::oneflow";
}

このインターフェイスは、MLIR フレームワークのオペレーター インターフェイスを記述する基本クラスである OpInterface インターフェイスを継承します。NCHWSupportInterface は、NCHW 形式と互換性のあるオペレータ インターフェイスを表します。NCHWCompleteInterface は、いくつかのメソッドを定義します。

  • IsNCHW: 現在の Operator がどのような条件で入力データを NCHW 形式で処理しているかを示すブール値を返します。
  • NchwToNhwc: Transpose およびリライター (リライター) 後の入力を受け入れ、NCHW 形式から NHWC 形式に変換するために使用されます。
  • OperandsToTranspose: 転置を必要とする入力値のセットを返します。
  • ResultsToTranspose: Transpose を必要とする出力値のセットを返します。

次に、Conv2D Op に対応する NCHWPracticalInterface インターフェイスの実装を見てみましょう。

bool Conv2DOp::IsNCHW() {
    
     return this->getDataFormat().str() == "channels_first"; }

llvm::DenseSet<Value> Conv2DOp::OperandsToTranspose() {
    
    
  if (this->get_addToOutput()) {
    
    
    return {
    
    this->getIn(), this->getWeight(), this->get_addToOutput()};
  } else {
    
    
    return {
    
    this->getIn(), this->getWeight()};
  }
}

llvm::DenseSet<Value> Conv2DOp::ResultsToTranspose() {
    
     return {
    
    this->getOut()}; }

llvm::SmallVector<Value, 4> Conv2DOp::NchwToNhwc(llvm::SmallVector<Value, 4> value,
                                                 PatternRewriter& rewriter) {
    
    
  auto conv_op = *this;
  SmallVector<Value, 4> operands;
  operands.push_back(value[0]);
  operands.push_back(value[1]);
  if (conv_op.getBias()) operands.push_back(conv_op.getBias());
  if (this->get_addToOutput()) {
    
     operands.push_back(value[2]); }
  NamedAttrList attributes = conv_op->getAttrs();
  attributes.set(conv_op.getDataFormatAttrName(), rewriter.getStringAttr("channels_last"));
  auto res = rewriter
                 .create<oneflow::Conv2DOp>(conv_op.getLoc(), getNHWCResultTypes(conv_op), operands,
                                            attributes)
                 ->getResults();
  llvm::SmallVector<Value, 4> results;
  results.push_back(res[0]);
  return results;
}

このうち、IsNCHW メソッドは、Conv2DOp オペレーションが NCHW 形式を使用するかどうかを示すブール値を返します。これは、オペレーションの data_format 属性を調べることによって行われます。OperandsToTranspose メソッドは、Transpose を必要とする入力値のコレクションを返します。Conv2DOp の場合、主な入力には input、weight、bias (オプション)、および addto_output (オプション) が含まれます (バイアスには Transpose は必要ありません)。この addto_output はオペレータ フュージョン用の OneFlow の特別な出力です。読者は無視できます。ResultsToTranspose メソッドは、Transpose を必要とする出力値のコレクションを返します。Conv2DOp の場合、出力は 1 つだけであるため、出力特徴マップの値が返されます。NchwToNhwc メソッドは、入力値とリライターを NCHW 形式で受け入れ、結果値を NHWC 形式で返します。新しい Conv2DOp オペレーションを作成し、data_format 属性を Channels_last に設定することで、NCHW から NHWC への変換を実装します。

0x1.2 転置演算子の挿入

次のステップは、ネットワーク内の演算子に Transpose 演算子を貪欲に挿入することです。ここでの考え方は、ネットワーク内の各演算子の前後に可能な限り Transpose を挿入しようとすることです。このようにして、最適な値を取得できます。転置ペアが削除されたときの解決策。ネットワーク内の演算子に Transpose を挿入するロジックは、次のパターン コードで説明されています。

struct AutoNhwcPattern : public OpInterfaceRewritePattern<NCHWCompatible> {
    
    
  explicit AutoNhwcPattern(mlir::MLIRContext* context)
      : OpInterfaceRewritePattern<NCHWCompatible>(context, /*benefit=*/1) {
    
    }

 public:
  LogicalResult matchAndRewrite(NCHWCompatible op, PatternRewriter& rewriter) const override {
    
    
    if (op->hasTrait<OpTrait::IsOpConfCompatible>()) {
    
    
      for (mlir::Value operand : op.OperandsToTranspose()) {
    
    
        if (operand.getType().cast<mlir::RankedTensorType>().getShape().size() != 4) {
    
    
          return failure();
        }
      }
      const auto device_name = OpTrait::IsOpConfCompatible<void>::getDeviceTag(op)
                                   .cast<mlir::StringAttr>()
                                   .getValue()
                                   .str();
      if (device_name == "cpu") {
    
     return failure(); }
    }
    llvm::SmallVector<int32_t> perm = getChannelLastTransposePerm();
    llvm::SmallVector<int32_t> result_perm = getChannelFirstTransposePerm();

    NamedAttrList transpose_attributes;
    if (InitTransposeAttributes(op, transpose_attributes, rewriter).succeeded()) {
    
    
      transpose_attributes.append(llvm::StringRef("perm"), getSI32ArrayAttr(rewriter, perm));
    } else {
    
    
      return failure();
    }
    // when op op has no sense of data_format and pre op is transpose, we greedily insert transpose
    // into this op, seeking more opportunities to eliminate transpose pattern.
    const bool greedily_transpose_flag = !op.IsNCHW() && IsInsertTransposeOpBefore(op, rewriter);

    if (op.IsNCHW() || greedily_transpose_flag) {
    
    
      // create transpose op for input operand
      SmallVector<Value, 4> tranposed_operands;
      llvm::DenseSet<Value> operand_transpose = op.OperandsToTranspose();
      int num_transposed_operand = 0;
      for (Value operand : op->getOperands()) {
    
    
        if (operand_transpose.find(operand) != operand_transpose.end()) {
    
    
          SmallVector<Value, 4> input_res = getInputOperandTransposeOp(
              op, operand, transpose_attributes, num_transposed_operand, rewriter);
          tranposed_operands.push_back(input_res[0]);
          num_transposed_operand += 1;
        }
      }
      // create NHWC op
      SmallVector<Value, 4> created_results = op.NchwToNhwc(tranposed_operands, rewriter);
      // create transpose op for results
      int num_transposed_result = 0;
      transpose_attributes.set(llvm::StringRef("perm"), getSI32ArrayAttr(rewriter, result_perm));
      llvm::DenseSet<Value> transpose_result = op.ResultsToTranspose();

      for (Value result : op->getOpResults()) {
    
    
        if (transpose_result.find(result) != transpose_result.end()) {
    
    
          if (auto result_transpose_op =
                  getResultTransposeOp(op, created_results[num_transposed_result],
                                       transpose_attributes, num_transposed_result, rewriter)) {
    
    
            result.replaceAllUsesWith(result_transpose_op);
            num_transposed_result += 1;
          } else {
    
    
            return failure();
          }
        }
      }
    }
    return success();
  }
};

まず、AutoNhwcPattern クラスは、Operation を書き換えるための基本クラスである OpInterfaceRewritePattern を継承します。AutoNhwcPatternは、NCHW互換インターフェースを実装したOperationを書き換えて、NCHWからNHWCへのフォーマット変換を実現します。次に、AutoNhwcPattern は matchAndRewrite メソッドをオーバーライドします。このメソッドは、NCHW から NHWC への変換を実現するために、NCHW互換インターフェイスの操作に遭遇したときに呼び出されます。次に、matchAndRewrite メソッドは、まずオペレーションが 4 次元かどうか、CPU デバイス上にあるかどうかなど、変換条件を満たすかどうかを確認します。満足できない場合は失敗を返します。そうである場合、matchAndRewrite メソッドは、NCHW から NHWC、および NHWC か​​ら NCHW への変換順序を取得します。そして、Transpose Operationのプロパティを初期化します。次に、現在の Op が NCHW 形式であるか、この Op の前の Op が Transpose Op である場合、より多くの最適化の機会を得るために、Transpose Op を挿入する操作がここで実行されます。

ここには関連するツール機能もいくつかあります。それらについて説明しましょう。

llvm::SmallVector<int32_t> getChannelLastTransposePerm() {
    
     return {
    
    0, 2, 3, 1}; }

llvm::SmallVector<int32_t> getChannelFirstTransposePerm() {
    
     return {
    
    0, 3, 1, 2}; }

llvm::SmallVector<mlir::Value, 4> getInputOperandTransposeOp(NCHWCompatible op, Value val,
                                                             NamedAttrList transpose_attributes,
                                                             int num_transposed_operand,
                                                             PatternRewriter& rewriter) {
    
    
  std::string transpose_name = OpTrait::IsOpConfCompatible<void>::getOpName(op).str()
                               + "_transpose_input_" + std::to_string(num_transposed_operand);
  transpose_attributes.set(llvm::StringRef(OpTrait::IsOpConfCompatible<void>::getOpNameAttr()),
                           rewriter.getStringAttr(transpose_name));
  SmallVector<Value, 4> input_operands;
  input_operands.push_back(val);
  auto res = rewriter
                 .create<oneflow::TransposeOp>(op.getLoc(), getNHWCType(val.getType()),
                                               input_operands, transpose_attributes)
                 ->getResults();
  return res;
}

TransposeOp getResultTransposeOp(NCHWCompatible op, Value val, NamedAttrList transpose_attributes,
                                 int num_transposed_result, PatternRewriter& rewriter) {
    
    
  std::string transpose_name = OpTrait::IsOpConfCompatible<void>::getOpName(op).str()
                               + "_transpose_output_" + std::to_string(num_transposed_result);
  transpose_attributes.set(llvm::StringRef(OpTrait::IsOpConfCompatible<void>::getOpNameAttr()),
                           rewriter.getStringAttr(transpose_name));
  SmallVector<Value, 4> operands;
  operands.push_back(val);
  TransposeOp transpose_op = rewriter.create<oneflow::TransposeOp>(
      op.getLoc(), getNCHWType(val.getType()), operands, transpose_attributes);
  return transpose_op;
}

bool IsInsertTransposeOpBefore(NCHWCompatible op, PatternRewriter& rewriter) {
    
    
  bool insert_transpose_op_flag = false;
  for (mlir::Value operand : op->getOperands()) {
    
    
    TransposeOp transposeInputOp = operand.getDefiningOp<TransposeOp>();
    if (!transposeInputOp) continue;
    const auto perm = transposeInputOp.getPermAttr();
    if (perm.size() == 4 && perm[0] == rewriter.getSI32IntegerAttr(0)
        && perm[1] == rewriter.getSI32IntegerAttr(3) && perm[2] == rewriter.getSI32IntegerAttr(1)
        && perm[3] == rewriter.getSI32IntegerAttr(2)) {
    
    
      insert_transpose_op_flag = true;
      break;
    }
  }
  return insert_transpose_op_flag;
}

このうち、getChannelLastTransposePerm メソッドと getChannelFirstTransposePerm メソッドは、それぞれ NHWC か​​ら NCHW へ、NCHW から NHWC への変換順序を返します。getInputOperandTransposeOp メソッドは、操作の入力に対して転置操作を作成します。入力値、Transpose プロパティ、およびオーバーライダーを使用して TransposeOp を作成し、その結果を返します。同様に、getResultTransposeOp メソッドは、操作の出力に対して転置操作を作成します。出力値、Transpose プロパティ、およびオーバーライダーを使用して TransposeOp を作成し、Operation を返します。IsInsertTransposeOpBefore メソッドは、操作の入力にすでに転置操作があるかどうかを確認します。「はい」であり、転置操作によって NHWC が NCHW に変換される場合は true を返し、それ以外の場合は false を返します。

0x1.3 冗長な転置ペアを削除します

次に、Transpose Op に挿入されたグラフ内のすべての隣接する Transpose ペアを可能な限り削除する必要があります。コードの実装は次のとおりです。

bool IsRedundantTransposeMatch(ArrayAttr pre, ArrayAttr afe, mlir::PatternRewriter& rewriter) {
    
    
  const auto prePerm = pre.getValue().vec();
  const auto afePerm = afe.getValue().vec();
  if (prePerm.size() == 4 && afePerm.size() == 4) {
    
    
    // handle nchw->nhwc->nchw: (0, 2, 3, 1) -> (0, 3, 1, 2)
    if (prePerm[0] == afePerm[0] && prePerm[1] == afePerm[3] && prePerm[2] == afePerm[1]
        && prePerm[3] == afePerm[2] && prePerm[0] == rewriter.getSI32IntegerAttr(0)
        && prePerm[1] == rewriter.getSI32IntegerAttr(2)
        && prePerm[2] == rewriter.getSI32IntegerAttr(3)
        && prePerm[3] == rewriter.getSI32IntegerAttr(1))
      return true;
    // handle nhwc->nchw->nhwc: (0, 3, 1, 2) -> (0, 2, 3, 1)
    if (prePerm[0] == afePerm[0] && prePerm[1] == afePerm[2] && prePerm[2] == afePerm[3]
        && prePerm[3] == afePerm[1] && prePerm[0] == rewriter.getSI32IntegerAttr(0)
        && prePerm[1] == rewriter.getSI32IntegerAttr(3)
        && prePerm[2] == rewriter.getSI32IntegerAttr(1)
        && prePerm[3] == rewriter.getSI32IntegerAttr(2))
      return true;
  }
  return false;
}

struct AutoNhwcEliminateRedundantTransposePattern : public mlir::OpRewritePattern<TransposeOp> {
    
    
  explicit AutoNhwcEliminateRedundantTransposePattern(mlir::MLIRContext* context)
      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {
    
    }
  mlir::LogicalResult matchAndRewrite(TransposeOp op,
                                      mlir::PatternRewriter& rewriter) const override {
    
    
    mlir::Value transposeInput = op.getOperand();
    TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();

    if (!transposeInputOp
        || !IsRedundantTransposeMatch(op.getPermAttr(), transposeInputOp.getPermAttr(), rewriter)) {
    
    
      return failure();
    }
    rewriter.replaceOp(op, {
    
    transposeInputOp.getOperand()});
    return success();
  }
};

IsRedundantTransposeMatch メソッドは、2 つの Transpose 操作の順序によって冗長性が生じるかどうかをチェックします。2つのTransposeのパーマ特性を比較して判定します。AutoNhwcPattern と同様に、AutoNhwcEliminateRedundantTransposePattern クラスは OpRewritePattern を継承します。TransposeOp をオーバーライドして、Transpose の削除を実装します。順序が NHWC->NCHW->NHWC または NCHW->NHWC->NCHW の場合、冗長 Transpose と判断されます。入力が TransposeOp からも来ており、2 つの Transpose シーケンスによって冗長性が生じている場合、matchAndRewrite メソッドは TransposeOp を TransposeOp の入力に置き換えます。トランスポーズの除去を実現します。matchAndRewrite メソッドは、まず TransposeOp の入力を取得し、その入力も TransposeOp からのものであるかどうかを確認します。そうでない場合、または 2 つの転置の順序が冗長性をもたらさない場合は、失敗が返されます。最後に success を返し、冗長な Transpose の削除が成功したことを示します。

最終的に、上記で紹介した 2 つのパスは AutoNhwcPass にカプセル化され、MLIR の計算グラフに作用してグローバルな最適化を完了します。以下のコードからわかるように、この最適化は ONEFLOW_MLIR_PREFER_NHWC 環境変数がオンになっている場合にのみ有効になります。

void populateAutoNhwcPatterns(::mlir::RewritePatternSet& patterns) {
    
    
  bool enable_nhwc = ::oneflow::ParseBooleanFromEnv("ONEFLOW_MLIR_PREFER_NHWC", false);
  if (enable_nhwc) {
    
    
    patterns.add<AutoNhwcPattern>(patterns.getContext());
    patterns.add<AutoNhwcEliminateRedundantTransposePattern>(patterns.getContext());
  }
}

class AutoNhwcPass : public AutoNhwcPassBase<AutoNhwcPass> {
    
    
  void runOnOperation() override {
    
    
    Operation* op = getOperation();
    RewritePatternSet patterns(op->getContext());
    oneflow::populateAutoNhwcPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
  }
};

補足:転置消去の0x1.4の重み

重みの転置をどのように扱うかについても簡単に説明する必要があります。0x1.2 では、重みの Transpose Op (定数演算) も挿入しました。その後、重みが定数であることがわかったので、重みの Transpose Op をコンパイル時に折りたたむことができます。このプロセスは、後で紹介する https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/ir/oneflow-translate/lib/OneFlow/MLIROneFlowTranslation.cpp#L808-L811 で行われます。 Constant Folding の実装を見てみましょう。

0x2. 結論

この記事では、OneFlow コンパイラの Layerout Transform について紹介します。このテクノロジは、Stable Diffusion の後の OneFlow バージョンでも重要な役割を果たし、推論速度を向上させました。TVM の Ansor にも同様の最適化があり、Op のスケジュールに影響を与える Op の戦略として異なる Layerout を設定することで、より大きな検索スペースとより良い結果を得るために検索時に Layerout Transform を考慮します。Transpose の余分なオーバーヘッドに対処する方法は唯一の方法ではありません。ここで紹介するのは、私が個人的に比較的簡単だと思う方法です。同様のニーズがある読者は自由に使用してください。

おすすめ

転載: blog.csdn.net/just_sort/article/details/130738580