神经网络中的卷积运算解析

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/yudiemiaomiao/article/details/73877519

卷积运算是利用卷积核对图像中的每个像素进行的操作,卷积核是用来做图像处理时的矩阵,图像处理时也称为掩模,是与原图像做运算的参数。卷积核通常是一个四方形的网格结构(例如3*3的矩阵或像素区域),该区域上每个方格都有一个权重值。
使用卷积进行计算时,需要将卷积核的中心位置放置在要计算的像素上,依次计算核中每个元素和其覆盖的图像像素值的乘积并求和,得到的结果就是该位置的新像素值。
在卷积神经网络的计算中,通常为了提高运算效率,会分为两步实现卷积运算:
(1)利用im2col将待卷积运算的(图像)矩阵重排;
(2)利用GEMM实现具体计算;
为了直观理解,可参考以下几张图片:

(图像)矩阵各位置的特征向量:


Cout为滤波器个数,C为通道数,Filter Matrix乘以Feature Matrix的转置,得到输出矩阵Cout x (H x W),就可以解释为输出的三维Blob(Cout x H x W)。

使用im2col的方法将卷积转为矩阵相乘,例图如下:

以下是 darknet 中的计算源码(其 im2col 思路参考了 Caffe),目前很多模型均采用这种方式计算卷积:

/* Fetch one pixel from a CHW-layout image, applying zero padding.
 * (row, col) are coordinates in the padded image; pad is subtracted
 * to map them back onto the real image.  Any coordinate that falls
 * outside the real image reads as 0 (implicit zero padding).
 * Note: `channels` is kept for interface compatibility but unused. */
float im2col_get_pixel(float *im, int height, int width, int channels,
                    int row, int col, int channel, int pad)
{
    int r = row - pad;
    int c = col - pad;

    int inside = (r >= 0) && (c >= 0) && (r < height) && (c < width);
    if (!inside) {
        return 0;
    }
    /* CHW indexing: channel plane, then row, then column. */
    return im[c + width * (r + height * channel)];
}

//Adapted from Berkeley Vision's Caffe
//https://github.com/BVLC/caffe/blob/master/LICENSE
/* Rearrange an image (channels x height x width) into a column matrix
 * data_col so that convolution reduces to a single GEMM.
 * data_col has (channels*ksize*ksize) rows; each row holds one kernel
 * element's value at every output position (out_h * out_w columns). */
void im2col_cpu(float* data_im,
 int channels,  int height,  int width,
 int ksize,  int stride, int pad, float* data_col) 
{
    /* Output spatial extent after padding and striding. */
    int out_h = (height + 2*pad - ksize) / stride + 1;
    int out_w = (width  + 2*pad - ksize) / stride + 1;
    int rows  = channels * ksize * ksize;   /* one row per kernel weight */

    /* Outer loop: every (channel, kernel-row, kernel-col) combination. */
    for (int c = 0; c < rows; ++c) {
        int kw = c % ksize;                 /* column offset inside kernel */
        int kh = (c / ksize) % ksize;       /* row offset inside kernel    */
        int im_channel = c / ksize / ksize; /* source image channel        */
        /* Inner loops: sweep this kernel element across the whole image. */
        for (int h = 0; h < out_h; ++h) {
            for (int w = 0; w < out_w; ++w) {
                int im_row = kh + h * stride;
                int im_col = kw + w * stride;
                int col_index = (c * out_h + h) * out_w + w;
                data_col[col_index] = im2col_get_pixel(data_im, height, width,
                        channels, im_row, im_col, im_channel, pad);
            }
        }
    }
}

/* Public GEMM entry point: C = ALPHA*op(A)*op(B) + BETA*C, where
 * op() optionally transposes (TA/TB flags).  Simply forwards every
 * argument to the CPU implementation. */
void gemm(int TA, int TB, int M, int N, int K, float ALPHA, 
    float *A, int lda, 
    float *B, int ldb,
    float BETA,
    float *C, int ldc)
{
    gemm_cpu(TA, TB, M, N, K, ALPHA, A, lda, B, ldb, BETA, C, ldc);
}

/* C += ALPHA * A * B  (neither matrix transposed).
 * A is M x K with leading dimension lda, B is K x N (ldb),
 * C is M x N (ldc).  The i-k-j loop order keeps the innermost
 * pass streaming sequentially through one row of B and C. */
void gemm_nn(int M, int N, int K, float ALPHA, 
    float *A, int lda, 
    float *B, int ldb,
    float *C, int ldc)
{
    for (int i = 0; i < M; ++i) {
        for (int k = 0; k < K; ++k) {
            /* Hoist the A element: it is invariant over j. */
            float scaled_a = ALPHA * A[i*lda + k];
            for (int j = 0; j < N; ++j) {
                C[i*ldc + j] += scaled_a * B[k*ldb + j];
            }
        }
    }
}

/* C += ALPHA * A * B^T.
 * A is M x K (lda); B is stored N x K (ldb) and read transposed;
 * C is M x N (ldc).  Each output element is one dot product of a
 * row of A with a row of B. */
void gemm_nt(int M, int N, int K, float ALPHA, 
    float *A, int lda, 
    float *B, int ldb,
    float *C, int ldc)
{
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            float dot = 0;
            for (int k = 0; k < K; ++k) {
                dot += ALPHA * A[i*lda + k] * B[j*ldb + k];
            }
            C[i*ldc + j] += dot;
        }
    }
}

/* C += ALPHA * A^T * B.
 * A is stored K x M (lda) and read transposed; B is K x N (ldb);
 * C is M x N (ldc).  Same i-k-j ordering as gemm_nn, only the A
 * index is flipped. */
void gemm_tn(int M, int N, int K, float ALPHA, 
    float *A, int lda, 
    float *B, int ldb,
    float *C, int ldc)
{
    for (int i = 0; i < M; ++i) {
        for (int k = 0; k < K; ++k) {
            /* Transposed read: column i of the stored A. */
            float scaled_a = ALPHA * A[k*lda + i];
            for (int j = 0; j < N; ++j) {
                C[i*ldc + j] += scaled_a * B[k*ldb + j];
            }
        }
    }
}

/* C += ALPHA * A^T * B^T.
 * A is stored K x M (lda), B is stored N x K (ldb); both are read
 * transposed.  C is M x N (ldc).  Each output element is a single
 * accumulated dot product. */
void gemm_tt(int M, int N, int K, float ALPHA, 
    float *A, int lda, 
    float *B, int ldb,
    float *C, int ldc)
{
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            float dot = 0;
            for (int k = 0; k < K; ++k) {
                dot += ALPHA * A[i + k*lda] * B[k + j*ldb];
            }
            C[i*ldc + j] += dot;
        }
    }
}


/* CPU GEMM driver: C = ALPHA*op(A)*op(B) + BETA*C.
 * First scales the existing contents of C by BETA, then dispatches
 * to the specialized kernel selected by the transpose flags TA/TB.
 * Note the kernels themselves fold ALPHA into the product and only
 * ever accumulate into C. */
void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA, 
    float *A, int lda, 
    float *B, int ldb,
    float BETA,
    float *C, int ldc)
{
    /* Apply the BETA*C term up front. */
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            C[i*ldc + j] *= BETA;
        }
    }

    /* Dispatch on the two transpose flags. */
    if (TA) {
        if (TB) {
            gemm_tt(M, N, K, ALPHA, A, lda, B, ldb, C, ldc);
        } else {
            gemm_tn(M, N, K, ALPHA, A, lda, B, ldb, C, ldc);
        }
    } else {
        if (TB) {
            gemm_nt(M, N, K, ALPHA, A, lda, B, ldb, C, ldc);
        } else {
            gemm_nn(M, N, K, ALPHA, A, lda, B, ldb, C, ldc);
        }
    }
}

猜你喜欢

转载自blog.csdn.net/yudiemiaomiao/article/details/73877519