caffe 源码分析【一】: Blob类

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/xiaoxu2050/article/details/82864019
  • Blob类的:    

//头文件: include\caffe\blob.hpp
//cpp文件: src\caffe\blob.cpp
//cu文件: src/caffe/blob.cu
//定义某layer的输入blobs
const vector<Blob<Dtype> *> bottom;

//定义某layer输出blobs
const vector<Blob<Dtype> *> top;

//获取blob中不可变的数据指针
const Dtype* bottom_data = bottom[0]->cpu_data();

//获取blob中可变数据指针
Dtype* top_data = top[0]->mutable_cpu_data();

//获取blob中不可变的梯度指针
const Dtype * top_diff = top[0]->cpu_diff();

//获取blob中可变的梯度指针
Dtype * bottom_diff = bottom[0]->bottom_diff();

//获取blob中数据单元的数量, 等于 Batch * C * H * W
const int count = bottom[0]->count();

//获取BatchSize大小
const int num = bottom[0]->num();

//获取通道数
const int channels = bottom[0]->channels();

//获取图片高度
const int height = bottom[0]->height();

//获取图片宽度
const int width = bottom[0]->width();

//获取指定维度的大小的通用方法,可以使用赋值index
num = bottom[0]->shape(0)       #第一个维度大小,通常为bath size
width = bottom[0]->shape(-1)    #最后一个维度大小,通常为width
const vector<int> &bottom_shape = bottom[0]->shape();

//得到有多少个维度(axes num)
const axes_num = bottom[0]->num_axes();


//读取具体的数据
const datum = bottom[0]->data_at(0,0,0,0); //获取batch 0 channel 0 height 0 width 0的数据值
const diff_datum = bottom[0]->diff_at(0,0,0,0)  //获取batch 0 channel 0 height 0 width 0的梯度    
vector<int> vector_index;
for(int i=0; i<bottom[0]->num_axes(); i++)
    vector_index.push_back(0);
datum = bottom[0]->data_at(vector_index);
diff_datum = bottom[0]->diff_at(vector_index);

/*
 * 修改blob的尺寸
 * 内存不够时会重新分配内存,存在多余的内存则不会释放
 * layer::reshape()后需要执行Net:Farward()或者Net:Reshape()调整整个网络结构,之后才可以调用 
 * Net:Backward(),
 *
*/
bottom[0]->Reshape(1,2,3,4);

/*
 *Update()是更新网络中参数设计的blob函数, 参数= 参数+ alpha*参数梯度
 *计算公式为: Y = alpha * X + Y
 *    Y= bottom[0]->mutable_cpu_data();
 *    X= bottom[0]->cpu_diff();
 *    alpha为梯度下降算法的超参数
*/

BLOB操作实例:

#include<vector>
#include<iostream>
#include<caffe/blob.hpp>
#include<caffe/util/io.hpp>

using namespace caffe;
using namespace std;
void print_blob(Blob<float> *a)
{
    for(int u = 0;u<a->num();u++)
        for(int v = 0;v<a->channels();v++)
            for(int w=0;w<a->height();w++)
                for(int x = 0;x<a->width();x++)
                    //输出blob的值
                    cout<<"a["<<u<<"]["<<v<<"]["<<w<<"]["<<x<<"]="<<a->data_at(u,v,w,x)                
                      <<endl;
}

  int main(void)
  {
          Blob<float> a;
          BlobProto bp;
          cout<<"size:"<<a.shape_string()<<endl;
           
          a.Reshape(1,2,3,4);
          cout<<"after:"<<a.shape_string()<<endl;

          float *p=a.mutable_cpu_data();
          float *q=a.mutable_cpu_diff();
          for(int i = 0;i<a.count();i++){
                  p[i]=i;
                  q[i]=a.count() - 1 - i;
          }
          //更新blob的数据,主要用于更新网络层参数
          a.Update();//diff data combine
          print_blob(&a);
          //计算blob data的L1范数
          cout<<"ASUM="<<a.asum_data()<<endl;
          //计算blob data的L2范数
          cout<<"SUMSQ="<<a.sumsq_data()<<endl;
          
          //保存网络参数数据
          a.ToProto(&bp,true);    //生成BlobProto对象
          WriteProtoToBinaryFile(bp,"a.blob");//写文件
          
          //从文件中读取网络参数
          BlobProto bp2;
          ReadProtoFromBinaryFileOrDie("a.blob",&bp2);
          Blob<float> b;
          b.FromProto(bp2,true);
          print_blob(&b);

          return 0;
  }

其中blobProto的定义如下:

//src/caffe/proto/caffe.proto文件重定义
message BlobProto {
  optional BlobShape shape = 7;
  repeated float data = 5 [packed = true];
  repeated float diff = 6 [packed = true];
  repeated double double_data = 8 [packed = true];
  repeated double double_diff = 9 [packed = true];

  // 4D dimensions -- deprecated.  Use "shape" instead.
  optional int32 num = 1 [default = 0];
  optional int32 channels = 2 [default = 0];
  optional int32 height = 3 [default = 0];
  optional int32 width = 4 [default = 0];
}

 

猜你喜欢

转载自blog.csdn.net/xiaoxu2050/article/details/82864019