ECM technology learning: template matching (TM)

Template matching (TM) is a decoder-side derivation method that refines the motion information of the current CU so that its MV becomes more accurate.

TM searches for the MV that minimizes the matching error between the template of the current picture (the top and/or left neighbouring blocks of the current CU) and the corresponding template in the reference picture. As shown in the figure below, a better MV is searched around the initial MV of the current CU within a [–8, +8] pixel search range. TM determines the search step size based on the AMVR mode, and in Merge mode TM can be cascaded with the bilateral matching (BM) process.

 In AMVP mode, only one specific MVP candidate is refined. The candidate is chosen by its template matching error: the MVP candidate with the smallest difference between the current block template and the reference block template is selected for TM refinement. TM refines this MVP candidate with an iterative diamond search, starting from full-pixel MVD precision (or 4-pixel precision in 4-pel AMVR mode) within the [–8, +8] pixel search range. The candidate can then be further refined by a cross search at full-pixel MVD precision (or 4-pixel precision in 4-pel AMVR mode), followed by half-pixel and quarter-pixel searches depending on the AMVR mode, as specified in Table 1. This search process ensures that, after TM, the MVP candidate still has the MV precision indicated by the AMVR mode.

 In Merge mode, a similar search method is applied to the Merge candidate indicated by the Merge index. As shown in Table 1, TM may either search all the way down to 1/8-pixel MVD precision or skip the precisions finer than half-pixel, depending on whether the alternative interpolation filter (AltIF, the filter used when AMVR is in half-pixel mode) is used according to the merged motion information. Furthermore, when TM mode is enabled, template matching can act either as an independent process or as an additional MV refinement step between the block-based and sub-block-based bilateral matching (BM) methods, depending on whether BM can be enabled according to its enabling condition check.

Related code

In ECM, the entry function for TM-based MV refinement is deriveTMMv. Note that in Merge mode TM refinement is performed on all MV candidates, while in AMVP mode only the MV with the smallest template matching error in the candidate list is refined. The two modes call different overloads of the function for TM refinement, as follows:

#if TM_MRG
  // Entry point called in Merge mode: refines the MVs stored in pu in place.
  void       deriveTMMv         (PredictionUnit& pu);
#endif
  // Refines one specific MV (mv is both the initial MV and the refined result)
  // and returns the resulting template matching cost.
  // otherMvf is nullptr for uni-prediction; otherwise it carries the motion of the other reference list.
  Distortion deriveTMMv         (const PredictionUnit& pu, bool fillCurTpl, Distortion curBestCost, RefPicList eRefList, int refIdx, int maxSearchRounds, Mv& mv, const MvField* otherMvf = nullptr);

 The code of these two functions, with comments, is as follows:

// Merge-mode template matching refinement of the MVs stored in pu.
// Each reference list used by pu is refined on its own first. For a
// bi-predicted PU (interDir == 3) that does not satisfy the BDMVR condition,
// the list with the larger uni cost is then refined again with the other list
// as the fixed reference, and the PU falls back to uni-prediction when the
// bi-prediction cost exceeds 1.125x the better uni cost.
void InterPrediction::deriveTMMv(PredictionUnit& pu)
{
  // Nothing to do unless the PU is coded with TM merge.
  if( !pu.tmMergeFlag )
  {
    return;
  }

  const Distortion costUnset = std::numeric_limits<Distortion>::max();
  const bool       isBSlice  = pu.cu->slice->isInterB();
  const int        numLists  = isBSlice ? NUM_REF_PIC_LIST_01 : 1;

  Distortion uniCost[NUM_REF_PIC_LIST_01] = { costUnset, costUnset };

  // Stage 1: independent refinement of every list the PU predicts from.
  for (int list = 0; list < numLists; ++list)
  {
    const bool listUsed = (pu.interDir & (list + 1)) != 0;
    if (!listUsed)
    {
      continue;
    }
    uniCost[list] = deriveTMMv(pu, true, costUnset, (RefPicList)list, pu.refIdx[list], TM_MAX_NUM_OF_ITERATIONS, pu.mv[list]);
  }

  // Stage 2 applies to bi-prediction only.
  if (!isBSlice || pu.interDir != 3)
  {
    return;
  }
#if MULTI_PASS_DMVR
  if (PU::checkBDMVRCondition(pu))
  {
    return;
  }
#endif
  // A cost left at its initial value means no template was available for that list.
  if (uniCost[0] == costUnset || uniCost[1] == costUnset)
  {
    return;
  }

  // Keep the list with the smaller uni cost fixed and re-refine the other one against it.
  const RefPicList targetList = (uniCost[0] <= uniCost[1]) ? REF_PIC_LIST_1 : REF_PIC_LIST_0;
  const int        anchorList = 1 - targetList;
  MvField          anchorMvf(pu.mv[anchorList], pu.refIdx[anchorList]);
  const Distortion biCost     = deriveTMMv(pu, true, costUnset, targetList, pu.refIdx[targetList], TM_MAX_NUM_OF_ITERATIONS, pu.mv[targetList], &anchorMvf);

  // Bi-prediction is dropped when its cost exceeds the anchor's uni cost by more than 1/8.
  const Distortion biLimit = uniCost[anchorList] + (uniCost[anchorList] >> 3);
  if (biCost > biLimit)
  {
    pu.interDir = 1 + anchorList;
    pu.mv    [targetList] = Mv();
    pu.refIdx[targetList] = NOT_VALID;
  }
}

 

#if TM_AMVP || TM_MRG
// Template-matching refinement of a single MV.
//   pu              prediction unit the MV belongs to
//   fillCurTpl      forwarded to TplMatchingCtrl; when false the current-template buffers are not refilled
//   curBestCost     initial best cost forwarded to TplMatchingCtrl
//   eRefList/refIdx reference picture of the MV being refined (refIdx must be >= 0)
//   maxSearchRounds maximum number of search rounds; 0 means no search is performed and only
//                   the cost of the template at the initial MV is computed
//   mv              in: initial MV; out: refined MV (left unchanged when no template is available)
//   otherMvf        nullptr for uni-prediction; otherwise the motion of the other reference list,
//                   whose prediction is removed from the current template before the search
// Returns the minimum template cost, or the maximum Distortion value when neither the above
// nor the left template is available. For bi-prediction the cost is scaled by the BCW weight
// of eRefList and normalized by the BCW weight base.
Distortion InterPrediction::deriveTMMv(const PredictionUnit& pu, bool fillCurTpl, Distortion curBestCost, RefPicList eRefList, int refIdx, int maxSearchRounds, Mv& mv, const MvField* otherMvf)
{
  CHECK(refIdx < 0, "Invalid reference index for TM");
  const CodingUnit& cu   = *pu.cu;
  const Picture& refPic  = *cu.slice->getRefPic(eRefList, refIdx)->unscaledPic;
  // The other list's MV is only handed to the controller when both lists point to the same POC.
  const bool doSimilarityCheck = otherMvf != nullptr
                              && cu.slice->getRefPOC((RefPicList)eRefList, refIdx) == cu.slice->getRefPOC((RefPicList)(1 - eRefList), otherMvf->refIdx);

  InterPredResources interRes(m_pcReshape, m_pcRdCost, m_if, m_filteredBlockTmp[0][COMPONENT_Y]
                           ,  m_filteredBlock[3][1][0], m_filteredBlock[3][0][0]
  );
  // The constructor obtains the current template and the reference template.
  TplMatchingCtrl tplCtrl(pu, interRes, refPic, fillCurTpl, COMPONENT_Y, true, maxSearchRounds, m_pcCurTplAbove, m_pcCurTplLeft, m_pcRefTplAbove, m_pcRefTplLeft, mv, (doSimilarityCheck ? &(otherMvf->mv) : nullptr), curBestCost);

  if (!tplCtrl.getTemplatePresentFlag())
  {
    // Neither the above nor the left template exists.
    return std::numeric_limits<Distortion>::max();
  }

  if (otherMvf == nullptr) // uni prediction
  {
    tplCtrl.deriveMvUni<TM_TPL_SIZE>();
    mv = tplCtrl.getFinalMv();   // refined MV
    return tplCtrl.getMinCost(); // minimum cost
  }
  else // bi prediction
  {
    const Picture& otherRefPic = *cu.slice->getRefPic((RefPicList)(1-eRefList), otherMvf->refIdx)->unscaledPic; // reference picture of the other list
    // Subtract the other list's reference template from the current template.
    tplCtrl.removeHighFreq<TM_TPL_SIZE>(otherRefPic, otherMvf->mv, getBcwWeight(cu.BcwIdx, eRefList));
    tplCtrl.deriveMvUni<TM_TPL_SIZE>();
    mv = tplCtrl.getFinalMv();

    const int8_t intWeight = getBcwWeight(cu.BcwIdx, eRefList);
    // g_BcwWeightBase is the weight denominator (its half is the rounding offset), so the
    // weighted cost is normalized by dividing by the base. Shifting right by the base value
    // itself, as before, over-scales the result.
    return (tplCtrl.getMinCost() * intWeight + (g_BcwWeightBase >> 1)) / g_BcwWeightBase;
  }
}

The template acquisition and the search process in TM are controlled by the TplMatchingCtrl class; its code is as follows:

// Controller for one template matching (TM) refinement of a single MV.
// The constructor obtains the current and reference templates; deriveMvUni()
// then runs the search, and getFinalMv() / getMinCost() expose the result.
class TplMatchingCtrl
{
  // Search patterns, used as the searchPattern template argument of xRefineMvSearch.
  enum TMSearchMethod
  {
    TMSEARCH_DIAMOND,
    TMSEARCH_CROSS,
    TMSEARCH_NUMBER_OF_METHODS
  };

  const CodingUnit&         m_cu;       // CU owning m_pu (bound to *pu.cu in the constructor)
  const PredictionUnit&     m_pu;
        InterPredResources& m_interRes; // buffers and filters borrowed from InterPrediction

  const Picture&    m_refPic;           // reference picture searched in
  const Mv          m_mvStart;          // initial MV
        Mv          m_mvFinal;          // current best MV; starts equal to m_mvStart
  const Mv*         m_otherRefListMv;   // MV of the other reference list, or nullptr
        Distortion  m_minCost;          // best cost so far; starts at the caller's curBestCost
        bool        m_useWeight;
        int         m_maxSearchRounds;
        ComponentID m_compID;

  // Default-constructed (buf == nullptr) when the corresponding template is unavailable.
  PelBuf m_curTplAbove;
  PelBuf m_curTplLeft;
  PelBuf m_refTplAbove;
  PelBuf m_refTplLeft;
  PelBuf m_refSrAbove; // pre-filled samples on search area
  PelBuf m_refSrLeft;  // pre-filled samples on search area

#if JVET_X0056_DMVD_EARLY_TERMINATION
  Distortion m_earlyTerminateTh; // set in the constructor to TM_TPL_SIZE times the summed length of the available templates
#endif
#if MULTI_PASS_DMVR
  // NOTE(review): presumably the per-position costs of the last diamond / cross search,
  // consumed by xDeriveCostBasedMv; confirm against the definitions.
  Distortion m_tmCostArrayDiamond[9];
  Distortion m_tmCostArrayCross[5];
#endif

public:
  // Constructor: obtains the current template and the reference template.
  TplMatchingCtrl(const PredictionUnit&     pu,
                        InterPredResources& interRes, // Bridge required resource from InterPrediction
                  const Picture&            refPic,
                  const bool                fillCurTpl,
                  const ComponentID         compID,
                  const bool                useWeight,
                  const int                 maxSearchRounds,
                        Pel*                curTplAbove,
                        Pel*                curTplLeft,
                        Pel*                refTplAbove,
                        Pel*                refTplLeft,
                  const Mv&                 mvStart,
                  const Mv*                 otherRefListMv,
                  const Distortion          curBestCost
  );
  // Returns whether a template exists (true if the above or the left current template is set).
  bool       getTemplatePresentFlag() { return m_curTplAbove.buf != nullptr || m_curTplLeft.buf != nullptr; }
  Distortion getMinCost            () { return m_minCost; } // returns the minimum cost
  Mv         getFinalMv            () { return m_mvFinal; } // returns the final refined MV
  static int getDeltaMean          (const PelBuf& bufCur, const PelBuf& bufRef, const int rowSubShift, const int bd);

  template <int tplSize> void deriveMvUni    (); // derives the uni-directional MV
  template <int tplSize> void removeHighFreq (const Picture& otherRefPic, const Mv& otherRefMv, const uint8_t curRefBcwWeight);

private:
  // In the helpers below, TrueA_FalseL selects the above (true) or the left (false) template.
  template <int tplSize, bool TrueA_FalseL>         bool       xFillCurTemplate   (Pel* tpl);
  template <int tplSize, bool TrueA_FalseL, int sr> PelBuf     xGetRefTemplate    (const PredictionUnit& curPu, const Picture& refPic, const Mv& _mv, PelBuf& dstBuf);
  template <int tplSize, bool TrueA_FalseL>         void       xRemoveHighFreq    (const Picture& otherRefPic, const Mv& otherRefMv, const uint8_t curRefBcwWeight);
  template <int tplSize, int searchPattern>         void       xRefineMvSearch    (int maxSearchRounds, int searchStepShift);
#if MULTI_PASS_DMVR
  template <int searchPattern>                      void       xNextTmCostAarray  (int bestDirect);
  template <int searchPattern>                      void       xDeriveCostBasedMv ();
  template <bool TrueX_FalseY>                      void       xDeriveCostBasedOffset (Distortion costLorA, Distortion costCenter, Distortion costRorB, int log2StepSize);
                                                    int        xBinaryDivision    (int64_t numerator, int64_t denominator, int fracBits);
#endif
  template <int tplSize>                            Distortion xGetTempMatchError (const Mv& mv);
  template <int tplSize, bool TrueA_FalseL>         Distortion xGetTempMatchError (const Mv& mv);
};

 In TM mode, the current template and the reference template are obtained in the constructor of the TplMatchingCtrl class, which calls the xFillCurTemplate function and the xGetRefTemplate function, respectively.

#if TM_AMVP || TM_MRG
// Sets up one TM refinement: checks which templates (above / left) are available,
// optionally fills the current-template buffers, wraps the caller-provided buffers
// as PelBufs, and, when a search will run (maxSearchRounds > 0), pre-interpolates
// the reference samples of the whole search area around mvStart.
TplMatchingCtrl::TplMatchingCtrl( const PredictionUnit&     pu,
                                        InterPredResources& interRes,
                                  const Picture&            refPic,
                                  const bool                fillCurTpl,
                                  const ComponentID         compID,
                                  const bool                useWeight,
                                  const int                 maxSearchRounds,
                                        Pel*                curTplAbove,
                                        Pel*                curTplLeft,
                                        Pel*                refTplAbove,
                                        Pel*                refTplLeft,
                                  const Mv&                 mvStart,
                                  const Mv*                 otherRefListMv,
                                  const Distortion          curBestCost
)
: m_cu              (*pu.cu)
, m_pu              (pu)
, m_interRes        (interRes)
, m_refPic          (refPic)
, m_mvStart         (mvStart)
, m_mvFinal         (mvStart)
, m_otherRefListMv  (otherRefListMv)
, m_minCost         (curBestCost)
, m_useWeight       (useWeight)
, m_maxSearchRounds (maxSearchRounds)
, m_compID          (compID)
{
  // Initialization
  // Fill the current templates. With fillCurTpl == false a nullptr is passed,
  // so only availability is checked and the buffers are left untouched.
  const bool tplAvalableAbove = xFillCurTemplate<TM_TPL_SIZE, true >((fillCurTpl ? curTplAbove : nullptr)); // whether the above template is available
  const bool tplAvalableLeft  = xFillCurTemplate<TM_TPL_SIZE, false>((fillCurTpl ? curTplLeft  : nullptr)); // whether the left template is available
  m_curTplAbove = tplAvalableAbove ? PelBuf(curTplAbove, pu.lwidth(),   TM_TPL_SIZE ) : PelBuf();
  m_curTplLeft  = tplAvalableLeft  ? PelBuf(curTplLeft , TM_TPL_SIZE,   pu.lheight()) : PelBuf();
  // Reference templates, with the same dimensions as the current templates
  m_refTplAbove = tplAvalableAbove ? PelBuf(refTplAbove, m_curTplAbove              ) : PelBuf();
  m_refTplLeft  = tplAvalableLeft  ? PelBuf(refTplLeft , m_curTplLeft               ) : PelBuf();
#if JVET_X0056_DMVD_EARLY_TERMINATION
  m_earlyTerminateTh = TM_TPL_SIZE * ((tplAvalableAbove ? m_pu.lwidth() : 0) + (tplAvalableLeft ? m_pu.lheight() : 0));
#endif

  // Pre-interpolate samples on search area
  // Above reference template, extended by TM_SEARCH_RANGE samples on every side
  m_refSrAbove = tplAvalableAbove && maxSearchRounds > 0 ? PelBuf(interRes.m_preFillBufA, m_curTplAbove.width + 2 * TM_SEARCH_RANGE, m_curTplAbove.height + 2 * TM_SEARCH_RANGE) : PelBuf();
  if (m_refSrAbove.buf != nullptr)
  {
    m_refSrAbove = xGetRefTemplate<TM_TPL_SIZE, true, TM_SEARCH_RANGE>(m_pu, m_refPic, mvStart, m_refSrAbove);
    m_refSrAbove = m_refSrAbove.subBuf(Position(TM_SEARCH_RANGE, TM_SEARCH_RANGE), m_curTplAbove); // move to the reference template position of the initial MV
  }
  // Left reference template, extended the same way
  m_refSrLeft  = tplAvalableLeft  && maxSearchRounds > 0 ? PelBuf(interRes.m_preFillBufL, m_curTplLeft .width + 2 * TM_SEARCH_RANGE, m_curTplLeft .height + 2 * TM_SEARCH_RANGE) : PelBuf();
  if (m_refSrLeft.buf != nullptr)
  {
    m_refSrLeft = xGetRefTemplate<TM_TPL_SIZE, false, TM_SEARCH_RANGE>(m_pu, m_refPic, mvStart, m_refSrLeft);
    m_refSrLeft = m_refSrLeft.subBuf(Position(TM_SEARCH_RANGE, TM_SEARCH_RANGE), m_curTplLeft);
  }
}

The xFillCurTemplate function gets the current template: 

// Checks whether the above (TrueA_FalseL == true) or left (false) template of
// the current PU is available and, if tpl is non-null, copies its reconstructed
// samples into tpl as a densely packed array (row stride == template width).
// Returns false when no CU covers the template position or, for GEO CUs, when
// this side is excluded by the GEO template type. A null tpl returns true right
// after the neighbour check, without copying anything.
template <int tplSize, bool TrueA_FalseL>
bool TplMatchingCtrl::xFillCurTemplate(Pel* tpl)
{
  // The template lies tplSize samples above or to the left of the block origin.
  const Position tplShift  = TrueA_FalseL ? Position(0, -tplSize) : Position(-tplSize, 0);
  const Position blkOrigin = m_pu.blocks[m_compID].pos();

  // Neighbouring CU not available: no template on this side.
  if (m_cu.cs->getCU(blkOrigin.offset(tplShift), toChannelType(m_compID)) == nullptr)
  {
    return false;
  }

  // No destination buffer: the caller only wants the availability answer.
  if (tpl == nullptr)
  {
    return true;
  }

  const Picture&          curPicture    = *m_cu.cs->picture;                                        // current picture
  const CPelBuf           recoPlane     = curPicture.getRecoBuf(m_cu.cs->picture->blocks[m_compID]); // its reconstructed component
        std::vector<Pel>& inverseLut    = m_interRes.m_pcReshape->getInvLUT();
  const bool              mapThroughLut = isLuma(m_compID) && m_cu.cs->picHeader->getLmcsEnabledFlag() && m_interRes.m_pcReshape->getCTUFlag();
#if JVET_W0097_GPM_MMVD_TM & TM_MRG
  if (m_cu.geoFlag)
  {
    CHECK(m_pu.geoTmType == GEO_TM_OFF, "invalid geo template type value");
    // SHAPE_A excludes the left template, SHAPE_L excludes the above template.
    const bool sideExcluded = TrueA_FalseL ? (m_pu.geoTmType == GEO_TM_SHAPE_L) : (m_pu.geoTmType == GEO_TM_SHAPE_A);
    if (sideExcluded)
    {
      return false;
    }
  }
#endif
  const Size tplDim  = TrueA_FalseL ? Size(m_pu.lwidth(), tplSize) : Size(tplSize, m_pu.lheight());
  const int  numRows = (int)tplDim.height;
  const int  numCols = (int)tplDim.width;

  Pel* out = tpl;
  for (int row = 0; row < numRows; ++row, out += numCols)
  {
    const Position rowOffset = TrueA_FalseL ? Position(0, row - tplSize) : Position(-tplSize, row);
    const Pel*     in        = recoPlane.bufAt(blkOrigin.offset(rowOffset));

    for (int col = 0; col < numCols; ++col)
    {
      const int sample = in[col];
      if (mapThroughLut)
      {
        out[col] = inverseLut[sample];
      }
      else
      {
        out[col] = sample;
      }
    }
  }

  return true;
}

The xGetRefTemplate function gets the reference template: 

// Produces the reference template addressed by _mv for the above
// (TrueA_FalseL == true) or left (false) side, sized like dstBuf.
// With sr == 0 and a matching pre-interpolated search-area buffer, an
// integer-sample offset within TM_SEARCH_RANGE is served as a sub-buffer view
// of that buffer (dstBuf is not written). Otherwise the samples are
// interpolated from refPic into dstBuf, which is returned.
// With sr > 0 the MV is moved by -sr integer samples in both directions before
// fetching; the constructor uses this with an enlarged dstBuf to pre-fill the search area.
template <int tplSize, bool TrueA_FalseL, int sr>
PelBuf TplMatchingCtrl::xGetRefTemplate(const PredictionUnit& curPu, const Picture& refPic, const Mv& _mv, PelBuf& dstBuf)
{
  // read from pre-interpolated buffer
  PelBuf& refSrBuf = TrueA_FalseL ? m_refSrAbove : m_refSrLeft;
  // sr == 0: read the samples directly from the pre-interpolated buffer when possible
  if (sr == 0 && refPic.getPOC() == m_refPic.getPOC() && refSrBuf.buf != nullptr)
  {
    Mv mvDiff = _mv - m_mvStart;
    // Only offsets with no fractional part in either direction can be served from the buffer.
    if ((mvDiff.getAbsHor() & ((1 << MV_FRACTIONAL_BITS_INTERNAL) - 1)) == 0 && (mvDiff.getAbsVer() & ((1 << MV_FRACTIONAL_BITS_INTERNAL) - 1)) == 0)
    {
      mvDiff >>= MV_FRACTIONAL_BITS_INTERNAL;
      // ...and only when they stay inside the pre-filled search range.
      if (mvDiff.getAbsHor() <= TM_SEARCH_RANGE && mvDiff.getAbsVer() <= TM_SEARCH_RANGE)
      {
        return refSrBuf.subBuf(Position(mvDiff.getHor(), mvDiff.getVer()), dstBuf);
      }
    }
  }

  // Do interpolation on the fly
  Position blkPos  = ( TrueA_FalseL ? Position(curPu.lx(), curPu.ly() - tplSize) : Position(curPu.lx() - tplSize, curPu.ly()) );
  Size     blkSize = Size(dstBuf.width, dstBuf.height);
  Mv       mv      = _mv - Mv(sr << MV_FRACTIONAL_BITS_INTERNAL, sr << MV_FRACTIONAL_BITS_INTERNAL);
  clipMv( mv, blkPos, blkSize, *m_cu.cs->sps, *m_cu.cs->pps );

  // Split the clipped MV into integer and fractional parts for this component.
  const int lumaShift = 2 + MV_FRACTIONAL_BITS_DIFF;
  const int horShift  = (lumaShift + ::getComponentScaleX(m_compID, m_cu.chromaFormat));
  const int verShift  = (lumaShift + ::getComponentScaleY(m_compID, m_cu.chromaFormat));

  const int xInt  = mv.getHor() >> horShift;
  const int yInt  = mv.getVer() >> verShift;
  const int xFrac = mv.getHor() & ((1 << horShift) - 1);
  const int yFrac = mv.getVer() & ((1 << verShift) - 1);

  const CPelBuf refBuf = refPic.getRecoBuf(refPic.blocks[m_compID]);
  const Pel* ref       = refBuf.bufAt(blkPos.offset(xInt, yInt));
        Pel* dst       = dstBuf.buf;
        int  refStride = refBuf.stride;
        int  dstStride = dstBuf.stride;
        int  bw        = (int)blkSize.width;
        int  bh        = (int)blkSize.height;

  // NOTE(review): nFilterIdx == 1 presumably selects the bilinear filter
  // (the luma tap count used below is NTAPS_BILINEAR); confirm.
  const int  nFilterIdx   = 1;
  const bool useAltHpelIf = false;
  const bool biMCForDMVR  = false;

  if ( yFrac == 0 )
  {
    // No vertical fraction: a single horizontal pass is enough.
    m_interRes.m_if.filterHor( m_compID, (Pel*) ref, refStride, dst, dstStride, bw, bh, xFrac, true, m_cu.chromaFormat, m_cu.slice->clpRng(m_compID), nFilterIdx, biMCForDMVR, useAltHpelIf );
  }
  else if ( xFrac == 0 )
  {
    // No horizontal fraction: a single vertical pass is enough.
    m_interRes.m_if.filterVer( m_compID, (Pel*) ref, refStride, dst, dstStride, bw, bh, yFrac, true, true, m_cu.chromaFormat, m_cu.slice->clpRng(m_compID), nFilterIdx, biMCForDMVR, useAltHpelIf );
  }
  else
  {
    // Both fractions set: horizontal pass into a temporary buffer with
    // vFilterSize - 1 extra rows, then vertical pass into dst.
    const int vFilterSize = isLuma(m_compID) ? NTAPS_BILINEAR : NTAPS_CHROMA;
    PelBuf tmpBuf = PelBuf(m_interRes.m_ifBuf, Size(bw, bh+vFilterSize-1));

    m_interRes.m_if.filterHor( m_compID, (Pel*)ref - ((vFilterSize>>1) -1)*refStride, refStride, tmpBuf.buf, tmpBuf.stride, bw, bh+vFilterSize-1, xFrac, false, m_cu.chromaFormat, m_cu.slice->clpRng(m_compID), nFilterIdx, biMCForDMVR, useAltHpelIf );
    JVET_J0090_SET_CACHE_ENABLE( false );
    m_interRes.m_if.filterVer( m_compID, tmpBuf.buf + ((vFilterSize>>1) -1)*tmpBuf.stride, tmpBuf.stride, dst, dstStride, bw, bh, yFrac, false, true, m_cu.chromaFormat, m_cu.slice->clpRng(m_compID), nFilterIdx, biMCForDMVR, useAltHpelIf );
    JVET_J0090_SET_CACHE_ENABLE( true );
  }

  return dstBuf;
}

 Uni-directional MV refinement is performed in the deriveMvUni function:

// Runs the TM search for one MV, updating the controller's best cost and MV.
// If no best cost was supplied (m_minCost is still the maximum value), the
// cost at the initial MV is computed first. With m_maxSearchRounds <= 0 nothing
// else is done. Otherwise a diamond search is followed by cross searches with
// a decreasing step shift.
template <int tplSize>
void TplMatchingCtrl::deriveMvUni()
{
  if (m_minCost == std::numeric_limits<Distortion>::max())
  {
    m_minCost = xGetTempMatchError<tplSize>(m_mvStart); // compute the template cost at the initial position
  }

  if (m_maxSearchRounds <= 0)
  {
    return;
  }
  // Search step: the initial step shift is 2 larger when the CU uses IMV_4PEL
  int searchStepShift = (m_cu.imv == IMV_4PEL ? MV_FRACTIONAL_BITS_INTERNAL + 2 : MV_FRACTIONAL_BITS_INTERNAL);
  xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_DIAMOND>(m_maxSearchRounds, searchStepShift);
  xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_CROSS  >(                1, searchStepShift);
  xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_CROSS  >(                1, searchStepShift - 1);
#if MULTI_PASS_DMVR
  // When the PU has bdmvrRefine set, the two finest cross searches are replaced
  // by xDeriveCostBasedMv.
  if (!m_pu.bdmvrRefine)
  {
#endif
  xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_CROSS  >(                1, searchStepShift - 2);
  xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_CROSS  >(                1, searchStepShift - 3);
#if MULTI_PASS_DMVR
  }
  else
  {
    xDeriveCostBasedMv<TplMatchingCtrl::TMSEARCH_CROSS>();
  }
#endif
}

Guess you like

Origin blog.csdn.net/BigDream123/article/details/122030869