Caffe训练数据转换为HD5与LMDB的代码实现

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/m_buddy/article/details/84781584

1. 前言

一般来讲再caffe中经常的数据结构是LMDB以及HD5文件。再进行训练的时候需要将其转换为对应的格式,自然直接读取原始图像数据也是可以的,但是转换之后其读取的效率更高。那么这篇博客中就借着这两点来梳理一下这两种数据是怎么转换来的,在后面的文章中再讲网络训练过程中怎么从这些文件中读取数据。

2. LMDB文件

再caffe环境下怎么调用现有的接口实现训练数据集的转换可以参考我之前的文章:使用Caffe的convert_imageset生成lmdb文件。那么调用该接口之后caffe究竟干了一些什么事情呢?接下来就要梳理convert_imageset.cpp文件了。
首先,传入的命令行参数进行解析:

gflags::ParseCommandLineFlags(&argc, &argv, true);

if (argc < 4) {
    gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset");
    return 1;
}

const bool is_color = !FLAGS_gray;
const bool check_size = FLAGS_check_size;
const bool encoded = FLAGS_encoded;
const string encode_type = FLAGS_encode_type;

通过txt文件,读取所有的文件名与对应的label,指定resize的大小

 //读取标注文件
std::ifstream infile(argv[2]);
std::vector<std::pair<std::string, int> > lines; //标注的数据结构piar<文件名,标签类别>
std::string line; //图像名
size_t pos; //标注文件中空格的位置
int label; //分类标签
while (std::getline(infile, line)) {
    pos = line.find_last_of(' ');
    label = atoi(line.substr(pos + 1).c_str());
    lines.push_back(std::make_pair(line.substr(0, pos), label));
}
if (FLAGS_shuffle) { // 打乱训练数据
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    shuffle(lines.begin(), lines.end());
}
LOG(INFO) << "A total of " << lines.size() << " images."; //打印文件总数

if (encode_type.size() && !encoded)
    LOG(INFO) << "encode_type specified, assuming encoded=true.";

int resize_height = std::max<int>(0, FLAGS_resize_height);  //指定resize的高度
int resize_width = std::max<int>(0, FLAGS_resize_width);    //指定resize的宽度

接下来定义数据存储对象,并逐行读取图像数据,并添加图像像素数据与label信息。

// Create new DB
scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend)); //lmdb数据对象
db->Open(argv[3], db::NEW); //通过给定的保存文件夹来初始化对象
scoped_ptr<db::Transaction> txn(db->NewTransaction()); //获得数据库操作句柄

// Storing to db
std::string root_folder(argv[1]);
Datum datum; //存储图像数据与对应的label
int count = 0; //记录处理的图片数量
int data_size = 0; //记录图像像素的个数C*H*W
bool data_size_initialized = false;

for (int line_id = 0; line_id < lines.size(); ++line_id) { //遍历txt文件中所有的图像数据
    bool status;
    std::string enc = encode_type;
  if (encoded && !enc.size()) {
      // Guess the encoding type from the file name
      string fn = lines[line_id].first;
      size_t p = fn.rfind('.');
      if ( p == fn.npos )
        LOG(WARNING) << "Failed to guess the encoding of '" << fn << "'";
      enc = fn.substr(p+1);
      std::transform(enc.begin(), enc.end(), enc.begin(), ::tolower);
  }
  //往DB中添加图像数据并设置标签
  status = ReadImageToDatum(root_folder + lines[line_id].first,
        lines[line_id].second, resize_height, resize_width, is_color,
        enc, &datum);
  if (status == false) continue;
  if (check_size) {
    if (!data_size_initialized) {
        data_size = datum.channels() * datum.height() * datum.width();
        data_size_initialized = true;
    } else {
        const std::string& data = datum.data();
        CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
            << data.size();
      }
  }
  // sequential
  string key_str = caffe::format_int(line_id, 8) + "_" + lines[line_id].first;

  // Put in db 通过键值写到数据库
  string out;
  CHECK(datum.SerializeToString(&out));
  txn->Put(key_str, out);

  //1k个图像就提交一次
  if (++count % 1000 == 0) {
      // Commit db
      txn->Commit();
      txn.reset(db->NewTransaction());
      LOG(INFO) << "Processed " << count << " files.";
  }

其中,涉及到图像数据读取与存储:

bool ReadImageToDatum(const string& filename, const int label,
    const int height, const int width, const bool is_color,
    const std::string & encoding, Datum* datum) {
  //根据给定的文件名 需要的宽高 是否为彩色使用opencv图像读取,并作变换
  cv::Mat cv_img = ReadImageToCVMat(filename, height, width, is_color);
  if (cv_img.data) {
    if (encoding.size()) {
      if ( (cv_img.channels() == 3) == is_color && !height && !width &&
          matchExt(filename, encoding) )
        return ReadFileToDatum(filename, label, datum);
      std::vector<uchar> buf;
      cv::imencode("."+encoding, cv_img, buf);
      datum->set_data(std::string(reinterpret_cast<char*>(&buf[0]),
                      buf.size()));
      datum->set_label(label);
      datum->set_encoded(true);
      return true;
    }
    //创建一个同像素个数同样大小的字符串缓冲区,将图像数据填进去,再将其添加到DB中
    CVMatToDatum(cv_img, datum);
    datum->set_label(label); //设置标签
    return true;
  } else {
    return false;
  }
}

其中涉及到Google Protobuf的Datum数据结构,可以看作是一个三通道( C H W C*H*W )的Matrix。这里给出它的一个使用例子:

Datum datum;
datum.set_width(3); // our data has three inputs
datum.set_height(1); // our data is one-dimensional
datum.set_channels(1);

google::protobuf::RepeatedField<float>* datumFloatData = datum.mutable_float_data();
datumFloatData->Add(0.0f);
datumFloatData->Add(1.0f);
datumFloatData->Add(0.0f);

datum.set_label(2);

对应LMDB文件的数据读取层定义:

layer {
  name: "mnist"
  type: "Data"
  top: "data"
  top: "label"
  include {
    phase: TRAIN
  }
  transform_param {
    scale: 0.00390625
  }
  data_param {
    source: "examples/mnist/mnist_train_lmdb"
    batch_size: 64
    backend: LMDB
  }
}

3. HD5文件

import h5py

with h5py.File(file_name, 'w') as f:
    f.create_dataset('data', data = img_data) # 图像数据,是一个N*C*H*W的numpy数组
    f.create_dataset('label', data = label_seq) # 标签数据

caffe下的layer读取:

layer {
  name: "data"
  type: "HDF5Data"
  top: "data"
  top: "label"
  hdf5_data_param {
    source: "examples/hdf5_classification/data/train.txt"
    batch_size: 10
  }
}

猜你喜欢

转载自blog.csdn.net/m_buddy/article/details/84781584