H.266/VVC代码学习:xRecurIntraCodingLumaQT函数

xRecurIntraCodingLumaQT函数主要用来遍历上层函数传来的候选变换集:对每个候选变换调用xIntraCodingTUBlock计算失真,再结合码率计算RD Cost,最终选出最佳变换。主要流程如下:

  • 确定当前分区是否可以继续划分,如果需要划分则递归调用自身。
  • 确定候选模式变换集
  • 遍历候选模式变换集,调用xIntraCodingTUBlock函数计算残差、进行变换、量化、重建。
  • 计算码率和RD Cost

具体代码注释如下:

// Recursively evaluates luma intra coding for the partitioner's current area.
//
// If the current area cannot be split any further, every candidate transform mode is
// coded with xIntraCodingTUBlock, its rate is estimated with xGetIntraFracBitsQT and the
// mode with the lowest RD cost is kept. If the area can be split (TU_MAX_TR_SPLIT, or
// ispType when cu.ispMode is set), the function calls itself for each sub-partition.
//
// Parameters:
//   cs                  - coding structure that is evaluated and updated (TUs, cost, dist, fracBits).
//   partitioner         - provides the current area / transform depth; used to split and to step
//                         through the sub-partitions.
//   bestCostSoFar       - cost bound used by the ISP early-termination checks and by the fast-ISP
//                         pruning of transform candidates.
//   subTuIdx            - starting value of the sub-TU counter handed to xGetIntraFracBitsQT; the
//                         counter is advanced after each sub-partition unless it is -1.
//   ispType             - split type used with canSplit / splitCurrArea when cu.ispMode is set.
//   ispIsCurrentWinner  - together with getUseFastISP() and !cu.ispMode, enables the fast-ISP
//                         threshold test on non-DCT2 candidates.
//   mtsCheckRangeFlag   - with LFNST enabled in the SPS and cu.mtsFlag set, restricts the mode loop
//                         to the range derived from mtsFirstCheckId / mtsLastCheckId.
//   mtsFirstCheckId     - first mode id of that restricted range.
//   mtsLastCheckId      - last mode id of that restricted range (transform skip may be appended).
//   moreProbMTSIdxFirst - when set (LFNST enabled, cu.mtsFlag set), mode ids 1 and 2 are mapped to
//                         an MTS index chosen from the intra prediction direction.
//
// Returns true when cs.cost has been set from cs.fracBits and cs.dist. Returns false when the
// result is rejected: with LFNST enabled and neither the full nor the split evaluation valid,
// or on ISP early termination (cs.cost is then set to MAX_DOUBLE).
bool IntraSearch::xRecurIntraCodingLumaQT( CodingStructure &cs, Partitioner &partitioner, const double bestCostSoFar, const int subTuIdx, const PartSplit ispType, const bool ispIsCurrentWinner, bool mtsCheckRangeFlag, int mtsFirstCheckId, int mtsLastCheckId, bool moreProbMTSIdxFirst )
{
        int   subTuCounter = subTuIdx;
  const UnitArea &currArea = partitioner.currArea();
  const CodingUnit     &cu = *cs.getCU( currArea.lumaPos(), partitioner.chType );
        bool  earlySkipISP = false;
  uint32_t currDepth       = partitioner.currTrDepth;
  const SPS &sps           = *cs.sps;
  const PPS &pps           = *cs.pps;
  const bool keepResi      = pps.getPpsRangeExtension().getCrossComponentPredictionEnabledFlag() || KEEP_PRED_AND_RESI_SIGNALS;
  // Decide whether the area is coded as a whole (bCheckFull) or has to be split (bCheckSplit).
  bool bCheckFull          = true;
  bool bCheckSplit         = false;
  bCheckFull               = !partitioner.canSplit( TU_MAX_TR_SPLIT, cs );
  bCheckSplit              = partitioner.canSplit( TU_MAX_TR_SPLIT, cs );

  // With ISP the decision is taken from the ISP split type instead.
  if( cu.ispMode )
  {
    bCheckSplit = partitioner.canSplit( ispType, cs );
    bCheckFull = !bCheckSplit;
  }
  uint32_t    numSig           = 0;

  double     dSingleCost                        = MAX_DOUBLE;
  Distortion uiSingleDistLuma                   = 0;
  uint64_t   singleFracBits                     = 0;
  bool       checkTransformSkip                 = sps.getTransformSkipEnabledFlag();// whether transform skip is to be checked; starts from the SPS flag
  int        bestModeId[ MAX_NUM_COMPONENT ]    = { 0, 0, 0 };
  // NOTE(review): the original author's note says cu.mtsFlag is controlled in xCheckRDCostIntra
  // (0 = DCT-2 primary transform, otherwise the four DST7/DCT8 combinations) - not visible here, confirm.
  uint8_t    nNumTransformCands                 = cu.mtsFlag ? 4 : 1;// number of transform types to check: 1 when cu.mtsFlag is 0, 4 otherwise
  uint8_t    numTransformIndexCands             = nNumTransformCands;
 
  // Snapshot the CABAC estimator contexts at entry; ctxBest receives the contexts of the best mode.
  const TempCtx ctxStart  ( m_CtxCache, m_CABACEstimator->getCtx() );
  TempCtx       ctxBest   ( m_CtxCache );

  CodingStructure *csSplit = nullptr;
  CodingStructure *csFull  = nullptr;

#if JVET_P1026_ISP_LFNST_COMBINATION
  CUCtx cuCtx;
  cuCtx.isDQPCoded = true;
  cuCtx.isChromaQpAdjCoded = true;
#endif

  // Both pointers alias cs; only one of them is set.
  if( bCheckSplit )
  {
    csSplit = &cs;
  }
  else if( bCheckFull )
  {
    csFull = &cs;
  }

  bool validReturnFull = false;

  if( bCheckFull )// the area cannot be split any further
  {
    csFull->cost = 0.0;// reset the cost to 0

    TransformUnit &tu = csFull->addTU( CS::getArea( *csFull, currArea, partitioner.chType ), partitioner.chType );// add a TU covering the current area
    tu.depth = currDepth;

    const bool tsAllowed  = TU::isTSAllowed( tu, COMPONENT_Y );// whether transform skip is allowed
#if JVET_P1026_MTS_SIGNALLING
    const bool mtsAllowed = CU::isMTSAllowed( cu, COMPONENT_Y );// whether MTS is allowed
#else
    const bool mtsAllowed = TU::isMTSAllowed( tu, COMPONENT_Y );
#endif
    // Candidate list; .first is used as the MTS index, .second as an "enabled" flag.
    std::vector<TrMode> trModes;

    if( sps.getUseLFNST() )// LFNST enabled in the SPS
    {
      // decide whether transform skip has to be checked
      checkTransformSkip &= tsAllowed;
      checkTransformSkip &= !cu.mtsFlag;// only when cu.mtsFlag is 0
      checkTransformSkip &= !cu.lfnstIdx;// only when cu.lfnstIdx is 0 (presumably LFNST not applied)

      if( !cu.mtsFlag && checkTransformSkip )
      {
        trModes.push_back( TrMode( 0, true ) ); //DCT2
        trModes.push_back( TrMode( 1, true ) ); //TS
      }
    }
    else// LFNST disabled in the SPS
    {
      nNumTransformCands = 1 + ( tsAllowed ? 1 : 0 ) + ( mtsAllowed ? 4 : 0 ); // DCT + TS + 4 MTS = 6 tests

      trModes.push_back( TrMode( 0, true ) ); //DCT2
      if( tsAllowed )// transform skip allowed
      {
        trModes.push_back( TrMode( 1, true ) );//TS
      }
      if( mtsAllowed )// MTS allowed
      {
        for( int i = 2; i < 6; i++ )
        {
          trModes.push_back( TrMode( i, true ) );// the four MTS candidates (DST7/DCT8 combinations per the original note)
        }
      }
    }

    CHECK( !tu.Y().valid(), "Invalid TU" );

    // Scratch coding structure used to back up the best mode's buffers and TU data.
    CodingStructure &saveCS = *m_pSaveCS[0];

    TransformUnit *tmpTU = nullptr;

    Distortion singleDistTmpLuma = 0;// luma distortion of the mode under test
    uint64_t     singleTmpFracBits = 0;
    double     singleCostTmp     = 0;// RD cost of the mode under test
    // first mode id to be checked
    int        firstCheckId      = ( sps.getUseLFNST() && mtsCheckRangeFlag && cu.mtsFlag ) ? mtsFirstCheckId : 0;

    //we add the MTS candidates to the loop. TransformSkip will still be the last one to be checked (when modeId == lastCheckId) as long as checkTransformSkip is true
	int        lastCheckId       = sps.getUseLFNST() ? ( ( mtsCheckRangeFlag && cu.mtsFlag ) ? ( mtsLastCheckId + ( int ) checkTransformSkip ) : ( numTransformIndexCands - ( firstCheckId + 1 ) + ( int ) checkTransformSkip ) ) :
                                   trModes[ nNumTransformCands - 1 ].first;
    // true when more than one mode will be checked
    bool isNotOnlyOneMode        = sps.getUseLFNST() ? lastCheckId != firstCheckId : nNumTransformCands != 1;

    // The backup structure and its TU are only prepared when several modes compete.
    if( isNotOnlyOneMode )
    {
      saveCS.pcv     = cs.pcv;
      saveCS.picture = cs.picture;
      saveCS.area.repositionTo(cs.area);
      saveCS.clearTUs();
      tmpTU = &saveCS.addTU(currArea, partitioner.chType);
    }

    bool    cbfBestMode      = false;
    bool    cbfBestModeValid = false;
    bool    cbfDCT2  = true;

    double bestDCT2cost = MAX_DOUBLE;
    // Scale factor applied to bestCostSoFar in the fast-ISP pruning test below; 1 when that test is not active.
    double threshold = m_pcEncCfg->getUseFastISP() && !cu.ispMode && ispIsCurrentWinner && nNumTransformCands > 1 ? 1 + 1.4 / sqrt( cu.lwidth() * cu.lheight() ) : 1;
    
    // loop over all candidate transform modes
	for( int modeId = firstCheckId; modeId <= ( sps.getUseLFNST() ? lastCheckId : ( nNumTransformCands - 1 ) ); modeId++ )
    {
      uint8_t transformIndex = modeId;

      if( sps.getUseLFNST() )// LFNST enabled in the SPS
      {
        if( ( transformIndex < lastCheckId ) || ( ( transformIndex == lastCheckId ) && !checkTransformSkip ) ) //we avoid this if the mode is transformSkip
        {
          // Skip checking other transform candidates if zero CBF is encountered and it is the best transform so far
			if( m_pcEncCfg->getUseFastLFNST() && transformIndex && !cbfBestMode && cbfBestModeValid )
          {
            continue;
          }
        }
      }
      else// LFNST disabled in the SPS
      {
#if JVET_AHG14_LOSSLESS
        if( !( m_pcEncCfg->getCostMode() == COST_LOSSLESS_CODING ) )
        {
#endif
        // Stop the search once DCT2 produced a zero CBF, or (fast TS) once transform skip is the best mode.
#if JVET_P0058_CHROMA_TS
        if( !cbfDCT2 || ( m_pcEncCfg->getUseTransformSkipFast() && bestModeId[ COMPONENT_Y ] == MTS_SKIP))
#else
        if( !cbfDCT2 || ( m_pcEncCfg->getUseTransformSkipFast() && bestModeId[ COMPONENT_Y ] == 1 ) )
#endif
        {
          break;
        }
        // Candidates whose flag has been cleared are not tested.
        if( !trModes[ modeId ].second )
        {
          continue;
        }
        //we compare the DCT-II cost against the best ISP cost so far (except for TS)
#if JVET_P0058_CHROMA_TS
        if (m_pcEncCfg->getUseFastISP() && !cu.ispMode && ispIsCurrentWinner && trModes[modeId].first != MTS_DCT2_DCT2 && (trModes[modeId].first != MTS_SKIP || !tsAllowed) && bestDCT2cost > bestCostSoFar * threshold)
#else
        if( m_pcEncCfg->getUseFastISP() && !cu.ispMode && ispIsCurrentWinner && trModes[ modeId ].first != 0 && ( trModes[ modeId ].first != 1 || !tsAllowed ) && bestDCT2cost > bestCostSoFar * threshold )
#endif
        {
          continue;
        }
#if JVET_AHG14_LOSSLESS
        }
#endif
#if JVET_P0058_CHROMA_TS
        tu.mtsIdx[COMPONENT_Y] = trModes[modeId].first;
#else
        tu.mtsIdx = trModes[ modeId ].first;
#endif
      }// end of the LFNST-disabled branch


      // Every mode after the first starts from the entry CABAC contexts.
      if ((modeId != firstCheckId) && isNotOnlyOneMode)
      {
        m_CABACEstimator->getCtx() = ctxStart;
      }

      // Flag forwarded to xIntraCodingTUBlock. Going by its name: 0 = default, 1 = save, 2 = load
      // (presumably of data reusable across modes - confirm in xIntraCodingTUBlock).
      int default0Save1Load2 = 0;
      singleDistTmpLuma = 0;

      if( modeId == firstCheckId && ( sps.getUseLFNST() ? ( modeId != lastCheckId ) : ( nNumTransformCands > 1 ) ) )
      {
        default0Save1Load2 = 1;
      }
      else if (modeId != firstCheckId)
      {
        if( sps.getUseLFNST() && !cbfBestModeValid )
        {
          default0Save1Load2 = 1;
        }
        else
        {
          default0Save1Load2 = 2;
        }
      }
      // ISP always uses the default behaviour.
      if( cu.ispMode )
      {
        default0Save1Load2 = 0;
      }

      if( sps.getUseLFNST() )// LFNST enabled in the SPS
      {
        if( cu.mtsFlag )// MTS: map the mode id to an MTS index
        {
          if( moreProbMTSIdxFirst )
          {
            const ChannelType     chType      = toChannelType( COMPONENT_Y );
            const CompArea&       area        = tu.blocks[ COMPONENT_Y ];
            const PredictionUnit& pu          = *cs.getPU( area.pos(), chType );
            uint32_t              uiIntraMode = pu.intraDir[ chType ];// intra prediction direction of the PU

            // Mode ids 1 and 2 pick their MTS index depending on whether the intra direction is below 34.
            if( transformIndex == 1 )
            {
#if JVET_P0058_CHROMA_TS
              tu.mtsIdx[COMPONENT_Y] = (uiIntraMode < 34) ? MTS_DST7_DCT8 : MTS_DCT8_DST7;
#else
              tu.mtsIdx = ( uiIntraMode < 34 ) ? MTS_DST7_DCT8 : MTS_DCT8_DST7;
#endif
            }
            else if( transformIndex == 2 )
            {
#if JVET_P0058_CHROMA_TS
              tu.mtsIdx[COMPONENT_Y] = (uiIntraMode < 34) ? MTS_DCT8_DST7 : MTS_DST7_DCT8;
#else
              tu.mtsIdx = ( uiIntraMode < 34 ) ? MTS_DCT8_DST7 : MTS_DST7_DCT8;
#endif
            }
            else
            {
#if JVET_P0058_CHROMA_TS
              tu.mtsIdx[COMPONENT_Y] = MTS_DST7_DST7 + transformIndex;
#else
              tu.mtsIdx = MTS_DST7_DST7 + transformIndex;
#endif
            }
          }//if( moreProbMTSIdxFirst )
          else
          {
#if JVET_P0058_CHROMA_TS
            tu.mtsIdx[COMPONENT_Y] = MTS_DST7_DST7 + transformIndex;
#else
            tu.mtsIdx = MTS_DST7_DST7 + transformIndex;
#endif
          }
        }//if( cu.mtsFlag )
        else// no MTS: the mode id is used directly as the MTS index
        {
#if JVET_P0058_CHROMA_TS
          tu.mtsIdx[COMPONENT_Y] = transformIndex;
#else
          tu.mtsIdx = transformIndex;
#endif
        }

        if( !cu.mtsFlag && checkTransformSkip )// no MTS and transform skip is to be checked (DCT2 vs. TS)
        {
          // On mode 0 the candidate list is handed to xIntraCodingTUBlock (presumably so that it can
          // clear the flag of candidates it rules out - confirm in xIntraCodingTUBlock).
          xIntraCodingTUBlock( tu, COMPONENT_Y, false, singleDistTmpLuma, default0Save1Load2, &numSig, modeId == 0 ? &trModes : nullptr, true );
          if( modeId == 0 )
          {
            // lastCheckId becomes the index of the last candidate that is still flagged.
            for( int i = 0; i < 2; i++ )
            {
              if( trModes[ i ].second )
              {
                lastCheckId = trModes[ i ].first;
              }
            }
          }
        }
        else// code the TU block directly; singleDistTmpLuma is used as its distortion afterwards
        {
          xIntraCodingTUBlock( tu, COMPONENT_Y, false, singleDistTmpLuma, default0Save1Load2, &numSig );
        }
      }// end of the LFNST-enabled branch
      else
      {
        if( nNumTransformCands > 1 )// more than one transform candidate
        {
          xIntraCodingTUBlock( tu, COMPONENT_Y, false, singleDistTmpLuma, default0Save1Load2, &numSig, modeId == 0 ? &trModes : nullptr, true );
          if( modeId == 0 )
          {
            // lastCheckId becomes the index of the last candidate that is still flagged.
            for( int i = 0; i < nNumTransformCands; i++ )
            {
              if( trModes[ i ].second )
              {
                lastCheckId = trModes[ i ].first;
              }
            }
          }
        }
        else
        {
          xIntraCodingTUBlock( tu, COMPONENT_Y, false, singleDistTmpLuma, default0Save1Load2, &numSig );
        }
      }

      //----- determine rate and r-d cost -----
      if( ( sps.getUseLFNST() ? ( modeId == lastCheckId && modeId != 0 && checkTransformSkip ) : ( trModes[ modeId ].first != 0 ) ) && !TU::getCbfAtDepth( tu, COMPONENT_Y, currDepth ) )
      {
        //In order not to code TS flag when cbf is zero, the case for TS with cbf being zero is forbidden.
        singleCostTmp = MAX_DOUBLE;
      }
      else
      {
        // ISP early termination: the accumulated cost already exceeds the best cost so far.
        if( cu.ispMode && m_pcRdCost->calcRdCost( csFull->fracBits, csFull->dist + singleDistTmpLuma ) > bestCostSoFar )
        {
          earlySkipISP = true;
        }
        else
        {
#if JVET_P1026_ISP_LFNST_COMBINATION
          singleTmpFracBits = xGetIntraFracBitsQT( *csFull, partitioner, true, false, subTuCounter, ispType, &cuCtx );
#else
          singleTmpFracBits = xGetIntraFracBitsQT( *csFull, partitioner, true, false, subTuCounter, ispType );
#endif
        }
        // NOTE(review): on the early-skip path singleTmpFracBits keeps its previous value here.
        singleCostTmp     = m_pcRdCost->calcRdCost( singleTmpFracBits, singleDistTmpLuma );// compute the RD cost
      }

      if ( !cu.ispMode && nNumTransformCands > 1 && modeId == firstCheckId )
      {
        bestDCT2cost = singleCostTmp;// cost of the first checked mode, used as the DCT2 reference in the fast-ISP test
      }

      if (singleCostTmp < dSingleCost)// the current RD cost beats the best one so far
      {
        dSingleCost       = singleCostTmp;
        uiSingleDistLuma  = singleDistTmpLuma;
        singleFracBits    = singleTmpFracBits;

        if( sps.getUseLFNST() )
        {// LFNST enabled
          bestModeId[ COMPONENT_Y ] = modeId;// remember the current mode as the best one
          cbfBestMode = TU::getCbfAtDepth( tu, COMPONENT_Y, currDepth );
          cbfBestModeValid = true;
          validReturnFull = true;
        }
        else
        {
          bestModeId[ COMPONENT_Y ] = trModes[ modeId ].first;// remember the current mode as the best one
          if( trModes[ modeId ].first == 0 )
          {
            cbfDCT2 = TU::getCbfAtDepth( tu, COMPONENT_Y, currDepth );
          }
        }

        // A best mode equal to lastCheckId is not backed up; it is left in csFull as it is.
        if( bestModeId[COMPONENT_Y] != lastCheckId )
        {
          saveCS.getPredBuf( tu.Y() ).copyFrom( csFull->getPredBuf( tu.Y() ) );// back up the prediction samples
          saveCS.getRecoBuf( tu.Y() ).copyFrom( csFull->getRecoBuf( tu.Y() ) );// back up the reconstruction samples

          if( keepResi )
          {
            saveCS.getResiBuf   ( tu.Y() ).copyFrom( csFull->getResiBuf   ( tu.Y() ) );
            saveCS.getOrgResiBuf( tu.Y() ).copyFrom( csFull->getOrgResiBuf( tu.Y() ) );
          }

          tmpTU->copyComponentFrom( tu, COMPONENT_Y );

          ctxBest = m_CABACEstimator->getCtx();
        }
      }//if (singleCostTmp < dSingleCost)
    }//modeId loop

    // With LFNST enabled and no mode accepted, the full evaluation is marked as invalid.
    if( sps.getUseLFNST() && !validReturnFull )
    {
      csFull->cost = MAX_DOUBLE;

      if( bCheckSplit )
      {
        ctxBest = m_CABACEstimator->getCtx();
      }
    }
    else
    {
      // Restore the backed-up best mode unless the best mode is the last one tested.
      if( bestModeId[COMPONENT_Y] != lastCheckId )
      {
        csFull->getPredBuf( tu.Y() ).copyFrom( saveCS.getPredBuf( tu.Y() ) );
        csFull->getRecoBuf( tu.Y() ).copyFrom( saveCS.getRecoBuf( tu.Y() ) );

        if( keepResi )
        {
          csFull->getResiBuf   ( tu.Y() ).copyFrom( saveCS.getResiBuf   ( tu.Y() ) );
          csFull->getOrgResiBuf( tu.Y() ).copyFrom( saveCS.getOrgResiBuf( tu.Y() ) );
        }

        tu.copyComponentFrom( *tmpTU, COMPONENT_Y );

        if( !bCheckSplit )
        {
          m_CABACEstimator->getCtx() = ctxBest;
        }
      }
      else if( bCheckSplit )
      {
        ctxBest = m_CABACEstimator->getCtx();
      }

      // Accumulate the best mode's cost, distortion and bits into the coding structure.
      csFull->cost     += dSingleCost;
      csFull->dist     += uiSingleDistLuma;
      csFull->fracBits += singleFracBits;
    }
  }

  bool validReturnSplit = false;
  if( bCheckSplit )// the area can be split further
  {
    //----- store full entropy coding status, load original entropy coding status -----
    if( bCheckFull )
    {
      m_CABACEstimator->getCtx() = ctxStart;
    }
    //----- code splitted block -----
    csSplit->cost = 0;

    bool uiSplitCbfLuma  = false;
    bool splitIsSelected = true;
    if( partitioner.canSplit( TU_MAX_TR_SPLIT, cs ) )
    {
      partitioner.splitCurrArea( TU_MAX_TR_SPLIT, cs );
    }

    if( cu.ispMode )
    {
      partitioner.splitCurrArea( ispType, *csSplit );
    }
    do// loop over all sub-partitions
    {
      bool tmpValidReturnSplit = xRecurIntraCodingLumaQT( *csSplit, partitioner, bestCostSoFar, subTuCounter, ispType, false, mtsCheckRangeFlag, mtsFirstCheckId, mtsLastCheckId );
      // Advance the sub-TU counter; a counter of -1 is left untouched.
      subTuCounter += subTuCounter != -1 ? 1 : 0;
      // With LFNST enabled, one invalid sub-partition rejects the whole split.
      if( sps.getUseLFNST() && !tmpValidReturnSplit )
      {
        splitIsSelected = false;
        break;
      }

      if( !cu.ispMode )
      {
        csSplit->setDecomp( partitioner.currArea().Y() );
      }
      else if( CU::isISPFirst( cu, partitioner.currArea().Y(), COMPONENT_Y ) )
      {
        csSplit->setDecomp( cu.Y() );
      }

      // Collect the luma CBF of the sub-partition that has just been coded.
      uiSplitCbfLuma |= TU::getCbfAtDepth( *csSplit->getTU( partitioner.currArea().lumaPos(), partitioner.chType, subTuCounter - 1 ), COMPONENT_Y, partitioner.currTrDepth );
      if( cu.ispMode )
      {
        //exit condition if the accumulated cost is already larger than the best cost so far (no impact in RD performance)
        if( csSplit->cost > bestCostSoFar )
        {
          earlySkipISP    = true;
          splitIsSelected = false;
          break;
        }
        else
        {
          //more restrictive exit condition
          bool tuIsDividedInRows = CU::divideTuInRows( cu );
          int nSubPartitions = tuIsDividedInRows ? cu.lheight() >> floorLog2(cu.firstTU->lheight()) : cu.lwidth() >> floorLog2(cu.firstTU->lwidth());
          // The fraction of bestCostSoFar tolerated depends on the partition count and on how many are done.
          double threshold = nSubPartitions == 2 ? 0.95 : subTuCounter == 1 ? 0.83 : 0.91;
          if( subTuCounter < nSubPartitions && csSplit->cost > bestCostSoFar*threshold )
          {
            earlySkipISP    = true;
            splitIsSelected = false;
            break;
          }
        }
      }



    } while( partitioner.nextPart( *csSplit ) );

    partitioner.exitCurrSplit();

    if( splitIsSelected )
    {
      // Set the CBF at the current depth of every TU inside the current area to the OR of the sub-partition CBFs.
      for( auto &ptu : csSplit->tus )
      {
        if( currArea.Y().contains( ptu->Y() ) )
        {
          TU::setCbfAtDepth( *ptu, COMPONENT_Y, currDepth, uiSplitCbfLuma ? 1 : 0 );
        }
      }

      //----- restore context states -----
      m_CABACEstimator->getCtx() = ctxStart;

#if JVET_P1026_ISP_LFNST_COMBINATION
      cuCtx.violatesLfnstConstrained[CHANNEL_TYPE_LUMA] = false;
      cuCtx.violatesLfnstConstrained[CHANNEL_TYPE_CHROMA] = false;
      cuCtx.lfnstLastScanPos = false;
#if JVET_P1026_MTS_SIGNALLING
      cuCtx.violatesMtsCoeffConstraint = false;
#endif
#endif

      //----- determine rate and r-d cost -----
#if JVET_P1026_ISP_LFNST_COMBINATION
      csSplit->fracBits = xGetIntraFracBitsQT( *csSplit, partitioner, true, false, cu.ispMode ? 0 : -1, ispType, &cuCtx );
#else
      csSplit->fracBits = xGetIntraFracBitsQT( *csSplit, partitioner, true, false, cu.ispMode ? 0 : -1, ispType );
#endif

      //--- update cost ---
      csSplit->cost     = m_pcRdCost->calcRdCost(csSplit->fracBits, csSplit->dist);

      validReturnSplit = true;
    }
  }//if( bCheckSplit )

  bool retVal = false;
  if( csFull || csSplit )
  {
    // With LFNST enabled, at least one of the two evaluations must have produced a valid result.
    if( !sps.getUseLFNST() || validReturnFull || validReturnSplit )
    {
      {
        // otherwise this would've happened in useSubStructure
        cs.picture->getRecoBuf( currArea.Y() ).copyFrom( cs.getRecoBuf( currArea.Y() ) );// copy the reconstruction into the picture buffer
        cs.picture->getPredBuf( currArea.Y() ).copyFrom( cs.getPredBuf( currArea.Y() ) );// copy the prediction into the picture buffer
      }

      // An ISP candidate that was terminated early is reported as invalid with maximum cost.
      if( cu.ispMode && earlySkipISP )
      {
        cs.cost = MAX_DOUBLE;
      }
      else
      {
        cs.cost = m_pcRdCost->calcRdCost( cs.fracBits, cs.dist );// store the RD cost
        retVal = true;
      }
    }
  }
  return retVal;
}
发布了80 篇原创文章 · 获赞 101 · 访问量 5万+

猜你喜欢

转载自blog.csdn.net/BigDream123/article/details/104215730