纯C++超分辨率重建LapSRN --改编--(三)转置卷积vl_nnconvt函数

接前文,

转置卷积vl_nnconvt:

// Transposed-convolution entry point (the original note on this line reads
// "residual upscaling").
//
// indata          input layer (read only)
// out             output layer, written by vl_nnconvt_forward
// filters_biases  layer data holding both the kernel weights and the biases
// upsampleX/Y     upsampling factors, default 1
// crop*           crop amounts on each side, default 0
void vl_nnconvt(卷积层 const *indata, 卷积层 *out, 层数据 *filters_biases,
                int upsampleX = 1, int upsampleY = 1,
                int cropLeft = 0, int cropRight = 0,
                int cropTop = 0, int cropBottom = 0)
{
	// vl_nnconvt_forward takes the output first and wants Y before X and
	// top/bottom before left/right, so the arguments are reordered here.
	vl_nnconvt_forward(out, indata, filters_biases,
	                   upsampleY, upsampleX,
	                   cropTop, cropBottom, cropLeft, cropRight) ;
}

vl_nnconvt_forward 分两步,分别处理卷积核和偏置:

// Forward pass of the transposed convolution, done in two stages:
//   1. the kernel stage, delegated to vl_impl_nnconv_backward_blas_CPU_float;
//   2. the bias stage, delegated to vl_nnbias_forward, only when the layer
//      actually carries biases.
//
// output          destination layer
// data            source layer (read only)
// filters_biases  layer data holding the kernel weights and the biases
// upsampleY/X     upsampling factors, passed on as the strides of stage 1
// crop*           crop amounts, passed on as the paddings of stage 1
void vl_nnconvt_forward(
                    卷积层 * output,
                    卷积层 const *data,
                    层数据 *filters_biases,
                    int upsampleY, int upsampleX,
                    int cropTop, int cropBottom,
                    int cropLeft, int cropRight)
{
	// Stage 1: kernel.
	vl_impl_nnconv_backward_blas_CPU_float(output, data, filters_biases,
	                                       upsampleY, upsampleX,
	                                       cropTop, cropBottom,
	                                       cropLeft, cropRight) ;

	// Stage 2: bias. Nothing to add when the layer has no bias values.
	if (!(filters_biases->偏移长度 > 0)) {
		return ;
	}
	vl_nnbias_forward(output, 1, 0, filters_biases, 1) ;
}

这个 vl_impl_nnconv_backward_blas_CPU_float 就是前面的那个反向传播函数:

// Kernel stage of the transposed convolution, implemented as the "backward
// w.r.t. data" pass of an ordinary convolution:
//   1. gemm() multiplies derOutput->data by the layer weights into a
//      temporary buffer;
//   2. vl_impl_row2im() folds that buffer into derData->data.
//
// derData         destination layer; its data buffer is overwritten
// derOutput       source layer (read only)
// filters_biases  layer data; only the weight count and weight buffer are used
// strideY/X       strides forwarded to vl_impl_row2im
// pad*            paddings forwarded to vl_impl_row2im
//
// The function returns without touching derData when an argument is NULL,
// when the weight count does not match a supported filter shape, or when
// derData has fewer channels than one filter. The original code divided by
// zero in those cases.
void vl_impl_nnconv_backward_blas_CPU_float(

		卷积层 * derData,
		卷积层 const *derOutput,
		层数据 *filters_biases,
		int strideY, int strideX,
		int padTop, int padBottom,
		int padLeft, int padRight)
{
  if (derData == NULL || derOutput == NULL || filters_biases == NULL) {
    return ;
  }

  // The filter geometry is recovered from the total weight count. Only the
  // two shapes below are recognised.
  int filters_getDepth = 0, filters_getHeight = 0, filters_getWidth = 0 ;
  if (filters_biases->权重长度 == 36864) {
    // 36864 / (3 * 3 * 64) = 64
    filters_getHeight = 3 ;
    filters_getWidth = 3 ;
    filters_getDepth = 64 ;
  } else if (filters_biases->权重长度 == 16) {
    // 16 / (4 * 4 * 1) = 1
    filters_getHeight = 4 ;
    filters_getWidth = 4 ;
    filters_getDepth = 1 ;
  } else {
    printf("vl_impl_nnconv_backward_blas_CPU_float: unsupported weight count\n") ;
    return ;
  }

  int numGroups = derData->depth / filters_getDepth ;
  if (numGroups <= 0) {
    // derData has fewer channels than one filter; dividing by numGroups
    // below would be a division by zero.
    printf("vl_impl_nnconv_backward_blas_CPU_float: invalid group count\n") ;
    return ;
  }

  int filtersVolume = filters_getHeight * filters_getWidth * filters_getDepth ;
  int numFiltersPerGroup = derOutput->depth / numGroups ;
  int numOutputPixels = derOutput->height * derOutput->width ;

  // Scratch space for the gemm result. tempVolume already counts floats,
  // so it must not be multiplied by sizeof(float) again.
  int tempVolume = numOutputPixels * filtersVolume * numGroups ;
  float *tempMemory = NULL ;
  if (tempVolume) {
    tempMemory = new float[tempVolume] ;
    // A standard operator new[] throws on failure; this check only fires on
    // toolchains where it returns NULL instead.
    if (tempMemory == NULL) {
      printf("分配内存错误!\n");
      return ;
    }
  }

  /* compute derData dz/dx */
  float alpha = 1 ;
  float beta = 0 ;
  gemm('n', 't',
       numOutputPixels, filtersVolume, numFiltersPerGroup,
       alpha,
       (float*)derOutput->data, numOutputPixels,
       (float*)filters_biases->权重_数据, filtersVolume,
       beta,
       tempMemory, numOutputPixels) ;

  // Original note: the data is transposed inside vl_impl_row2im and has to
  // be restored here, hence width and height are passed swapped.
  vl_impl_row2im((float*)derData->data,
                 tempMemory,
                 derData->width, derData->height, derData->depth,
                 filters_getHeight, filters_getWidth,
                 strideY, strideX,
                 padTop, padBottom, padLeft, padRight) ;

  delete [] tempMemory ;
}

vl_impl_row2im:

// Thin wrapper that forwards every argument, unchanged and in the same
// position, to row2im_cpu.
//
// data            destination buffer
// stacked         source buffer (read only)
// height, width, depth        extents handed to row2im_cpu
// windowHeight, windowWidth   window extents handed to row2im_cpu
// strideY, strideX            strides handed to row2im_cpu
// padTop/Bottom/Left/Right    paddings handed to row2im_cpu
void vl_impl_row2im(
                                 float* data,
                                 float const* stacked,
                                 size_t height, size_t width, size_t depth,
                                 size_t windowHeight, size_t windowWidth,
                                 size_t strideY, size_t strideX,
                                 size_t padTop, size_t padBottom, size_t padLeft, size_t padRight)
{
  // Original note: row2im_cpu was designed for MATLAB (column-major data), so
  // `data` would have to be transposed first. That step is skipped because
  // `data` holds no values yet.
  row2im_cpu(data, stacked,
             height, width, depth,
             windowHeight, windowWidth,
             strideY, strideX,
             padTop, padBottom, padLeft, padRight) ;
}

row2im_cpu:

// Accumulates the stacked patch buffer `stacked` into the image buffer `data`
// (the opposite direction to im2col, per the note inside the body).
//
// data                       destination, width * height * depth floats; it is
//                            zeroed first and then summed into
// stacked                    source; each of the windowWidth * windowHeight *
//                            depth rows consumes numPatchesX * numPatchesY floats
// width, height, depth       extents of `data`
// windowWidth, windowHeight  extents of the sliding window
// strideX, strideY           step between consecutive patches
// padLeft/Right/Top/Bottom   padding assumed around `data`
//
// NOTE(review): the pads and extents are size_t while u, v are int, so
// expressions such as `padLeft - u` are evaluated as unsigned. This presumably
// relies on ceil_divide / floor_divide taking signed arguments; confirm against
// their definitions, which are not part of this snippet.
static inline void
row2im_cpu(float* data,
           float const* stacked,
           size_t width,
           size_t height,
           size_t depth,
           size_t windowWidth,
           size_t windowHeight,
           size_t strideX,
           size_t strideY,
           size_t padLeft,
           size_t padRight,
           size_t padTop,
           size_t padBottom)// transposed parameter order (X / width first)

   //         size_t height, 
			//size_t width, 
			//size_t depth,
   //         size_t windowHeight, 
			//size_t windowWidth,
   //         size_t strideY, 
			//size_t strideX,
   //         size_t padTop, 
			//size_t padBottom, 
			//size_t padLeft, 
			//size_t padRight) // non-transposed parameter order, kept for reference
{//printf("row2im_cpu\n");
  // Number of window positions along each axis.
  int numPatchesX = (width + (padLeft + padRight) - windowWidth)/strideX + 1 ;
  int numPatchesY = (height + (padTop + padBottom) - windowHeight)/strideY + 1 ;
  // One stacked row per (window x, window y, channel) combination.
  int numRows = windowWidth * windowHeight * depth ;

  // The loops below only add into data, so it has to start from zero.
  memset(data, 0, sizeof(float) * width * height * depth) ;

  /*
	In the opposite direction to im2col, but still scanning the rows of the
	stacked image. See the comments of im2col for an explanation of the
	algorithm.
  */
  for (int row = 0; row < numRows ; ++row) {
    // Split the row index into window offset (u, v) and channel z.
    int u = row ;
    int v = u / windowWidth ;
    int z = v / windowHeight ;
    u %= windowWidth ;
    v %= windowHeight ;

    // Patch index range [x0, x1) x [y0, y1) used for this (u, v); presumably
    // the patches whose sample at offset (u, v) lands inside the unpadded
    // image. Helper semantics are assumed from their names.
    int x0 = static_min(numPatchesX, ceil_divide(padLeft - u, strideX)) ;
    int y0 = static_min(numPatchesY, ceil_divide(padTop - v, strideY)) ;
    int x1 = static_min(numPatchesX, floor_divide(width-1 + padLeft - u, strideX) + 1) ;
    int y1 = static_min(numPatchesY, floor_divide(height-1 + padTop - v, strideY) + 1) ;
    int x ;
    int y ;

    y = static_max(0, y0) ;
    // Skip the stacked lines that lie before the first used patch row.
    stacked += numPatchesX * static_max(y, 0) ;
    for ( ; y < y1 ; ++y) {
      x = static_max(0, x0) ;
      // Image coordinates hit by the first used patch of this line.
      int y_data = y * strideY + v - padTop ;
      int x_data = x * strideX + u - padLeft ;
      float * b = data + (z * height + y_data) * width + x_data ;
      // Skip the stacked entries before the first used patch column.
      stacked += x ;
      for ( ; x < x1 ; ++x) {
        // Overlapping windows contribute to the same pixel, hence +=.
        *b += *stacked++ ;
        b += strideX ;
      }
      // Skip the rest of this stacked line.
      stacked += numPatchesX - x ;
    }
    // Skip the remaining stacked lines so that every row advances `stacked`
    // by exactly numPatchesX * numPatchesY.
    stacked += numPatchesX * (numPatchesY - y) ;
  }
}

vl_nnconvt函数已完成。

猜你喜欢

转载自blog.csdn.net/juebai123/article/details/81539478