Win7 Caffe usage notes - computing the image mean

1. Caffe's data transformer (DataTransformer) performs preprocessing on input images, such as cropping (crop_size), mirroring (mirror), amplitude scaling (scale), mean subtraction (mean_value), and grayscale conversion (force_gray).

Computing the mean of the raw training data produces a mean file. In general, subtracting this mean from the training set before training yields a better model.
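For reference, a minimal transform_param block exercising these options might look like the sketch below (field names follow caffe.proto's TransformationParameter; the values are illustrative, and mean_value should not be combined with a mean_file):

transform_param {
  crop_size: 227      # randomly crop 227x227 patches (center crop at test time)
  mirror: true        # random horizontal mirroring during training
  scale: 0.00390625   # multiply pixel values by 1/256 after mean subtraction
  mean_value: 104     # per-channel means (B, G, R), used instead of a mean_file
  mean_value: 117
  mean_value: 123
}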

2. The compute_image_mean.exe tool


The tool takes three parameters:

input_db: path to the converted input database

output_file: path to the output mean file

db_backend: database backend (leveldb or lmdb)
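A typical invocation from the Caffe root (using the paths from the script later in this post) looks like:

compute_image_mean.exe examples/Planthopper/planthopper_train_leveldb data/Planthopper/planthopper_train_mean.binaryproto leveldb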


compute_image_mean.cpp from Neil Z. Shao's Caffe port:

#include <glog/logging.h>
#include <leveldb/db.h>
//#include <lmdb.h>
#include <stdint.h>

#include <algorithm>
#include <string>

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"

using caffe::Datum;
using caffe::BlobProto;
using std::string;
using std::max;

#ifdef _MSC_VER
#define snprintf sprintf_s
#endif

int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  if (argc < 3 || argc > 4) {
    LOG(ERROR) << "Usage: compute_image_mean input_db output_file"
               << " db_backend[leveldb or lmdb]";
    return 1;
  }

  string db_backend = "lmdb";
  if (argc == 4) {
    db_backend = string(argv[3]);
  }

  // leveldb
  leveldb::DB* db = NULL;
  leveldb::Options options;
  options.create_if_missing = false;
  leveldb::Iterator* it = NULL;
  // lmdb
  //MDB_env* mdb_env = NULL;
  //MDB_dbi mdb_dbi;
  //MDB_val mdb_key, mdb_value;
  //MDB_txn* mdb_txn = NULL;
  //MDB_cursor* mdb_cursor = NULL;

  // Open db
  if (db_backend == "leveldb") {  // leveldb
    LOG(INFO) << "Opening leveldb " << argv[1];
    leveldb::Status status = leveldb::DB::Open(
        options, argv[1], &db);
    CHECK(status.ok()) << "Failed to open leveldb " << argv[1];
    leveldb::ReadOptions read_options;
    read_options.fill_cache = false;
    it = db->NewIterator(read_options);
    it->SeekToFirst();
  }
  //else if (db_backend == "lmdb") {  // lmdb
  //  LOG(INFO) << "Opening lmdb " << argv[1];
  //  CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed";
  //  CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS);  // 1TB
  //  CHECK_EQ(mdb_env_open(mdb_env, argv[1], MDB_RDONLY, 0664),
  //      MDB_SUCCESS) << "mdb_env_open failed";
  //  CHECK_EQ(mdb_txn_begin(mdb_env, NULL, MDB_RDONLY, &mdb_txn), MDB_SUCCESS)
  //      << "mdb_txn_begin failed";
  //  CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS)
  //      << "mdb_open failed";
  //  CHECK_EQ(mdb_cursor_open(mdb_txn, mdb_dbi, &mdb_cursor), MDB_SUCCESS)
  //      << "mdb_cursor_open failed";
  //  CHECK_EQ(mdb_cursor_get(mdb_cursor, &mdb_key, &mdb_value, MDB_FIRST),
  //      MDB_SUCCESS);
  //}
  else {
    LOG(FATAL) << "Unknown db backend " << db_backend;
  }

  Datum datum;
  BlobProto sum_blob;
  int count = 0;
  // load first datum
  if (db_backend == "leveldb") {
    datum.ParseFromString(it->value().ToString());
  }
  //else if (db_backend == "lmdb") {
  //  datum.ParseFromArray(mdb_value.mv_data, mdb_value.mv_size);
  //}
  else {
    LOG(FATAL) << "Unknown db backend " << db_backend;
  }

  sum_blob.set_num(1);
  sum_blob.set_channels(datum.channels());
  sum_blob.set_height(datum.height());
  sum_blob.set_width(datum.width());
  const int data_size = datum.channels() * datum.height() * datum.width();
  int size_in_datum = std::max<int>(datum.data().size(),
                                    datum.float_data_size());
  for (int i = 0; i < size_in_datum; ++i) {
    sum_blob.add_data(0.);
  }
  LOG(INFO) << "Starting Iteration";
  if (db_backend == "leveldb") {  // leveldb
    for (it->SeekToFirst(); it->Valid(); it->Next()) {
      // parse the current record into the Datum
      datum.ParseFromString(it->value().ToString());
      const string& data = datum.data();
      size_in_datum = std::max<int>(datum.data().size(),
          datum.float_data_size());
      CHECK_EQ(size_in_datum, data_size) << "Incorrect data field size " <<
          size_in_datum;
      if (data.size() != 0) {
        for (int i = 0; i < size_in_datum; ++i) {
          sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]);
        }
      } else {
        for (int i = 0; i < size_in_datum; ++i) {
          sum_blob.set_data(i, sum_blob.data(i) +
              static_cast<float>(datum.float_data(i)));
        }
      }
      ++count;
      if (count % 10000 == 0) {
        LOG(ERROR) << "Processed " << count << " files.";
      }
    }
  }
  //else if (db_backend == "lmdb") {  // lmdb
  //  CHECK_EQ(mdb_cursor_get(mdb_cursor, &mdb_key, &mdb_value, MDB_FIRST),
  //      MDB_SUCCESS);
  //  do {
  //    // just a dummy operation
  //    datum.ParseFromArray(mdb_value.mv_data, mdb_value.mv_size);
  //    const string& data = datum.data();
  //    size_in_datum = std::max<int>(datum.data().size(),
  //        datum.float_data_size());
  //    CHECK_EQ(size_in_datum, data_size) << "Incorrect data field size " <<
  //        size_in_datum;
  //    if (data.size() != 0) {
  //      for (int i = 0; i < size_in_datum; ++i) {
  //        sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]);
  //      }
  //    }
  //    else {
  //      for (int i = 0; i < size_in_datum; ++i) {
  //        sum_blob.set_data(i, sum_blob.data(i) +
  //            static_cast<float>(datum.float_data(i)));
  //      }
  //    }
  //    ++count;
  //    if (count % 10000 == 0) {
  //      LOG(ERROR) << "Processed " << count << " files.";
  //    }
  //  } while (mdb_cursor_get(mdb_cursor, &mdb_key, &mdb_value, MDB_NEXT)
  //      == MDB_SUCCESS);
  //}
  else {
    LOG(FATAL) << "Unknown db backend " << db_backend;
  }

  if (count % 10000 != 0) {
    LOG(ERROR) << "Processed " << count << " files.";
  }
  for (int i = 0; i < sum_blob.data_size(); ++i) {
    sum_blob.set_data(i, sum_blob.data(i) / count);
  }
  // Write to disk
  LOG(INFO) << "Write to " << argv[2];
  WriteProtoToBinaryFile(sum_blob, argv[2]);

  // Clean up
  if (db_backend == "leveldb") {
    delete db;
  }
  //else if (db_backend == "lmdb") {
  //  mdb_cursor_close(mdb_cursor);
  //  mdb_close(mdb_env, mdb_dbi);
  //  mdb_txn_abort(mdb_txn);
  //  mdb_env_close(mdb_env);
  //}
  else {
    LOG(FATAL) << "Unknown db backend " << db_backend;
  }
  return 0;
}

The string db_backend = "lmdb" sets the default database type, but computing the mean from an lmdb database failed in tests under Windows, so leveldb is passed explicitly as the db_backend argument.

leveldb::DB::Open opens the database read-only, ParseFromString parses each record into a Datum, sum_blob accumulates the per-pixel sums which are then divided by the record count, and finally WriteProtoToBinaryFile writes the mean to a binary protobuf file.
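To sanity-check the output, the mean file can be loaded back with the same io utilities. Below is a minimal sketch (check_mean is a hypothetical tool name), assuming ReadProtoFromBinaryFileOrDie from caffe/util/io.hpp is available as in mainline Caffe:

#include <glog/logging.h>

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"

// Load the BlobProto written by WriteProtoToBinaryFile and report
// its shape and the average over all mean entries.
int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  CHECK_EQ(argc, 2) << "Usage: check_mean mean_file.binaryproto";
  caffe::BlobProto mean_blob;
  caffe::ReadProtoFromBinaryFileOrDie(argv[1], &mean_blob);
  LOG(INFO) << "channels=" << mean_blob.channels()
            << " height=" << mean_blob.height()
            << " width=" << mean_blob.width();
  CHECK_GT(mean_blob.data_size(), 0) << "Empty mean file";
  double total = 0.0;
  for (int i = 0; i < mean_blob.data_size(); ++i) {
    total += mean_blob.data(i);
  }
  LOG(INFO) << "average of mean entries: " << total / mean_blob.data_size();
  return 0;
}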


Next, write the make_leveldb_mean.bat script:

set EXAMPLE=../../examples/Planthopper
set DATA=../../data/Planthopper
set TOOLS=../../tools/bin/Release

set GLOG_logtostderr=1

echo "Creating train leveldb mean..."

"%TOOLS%/compute_image_mean.exe"  %EXAMPLE%/planthopper_train_leveldb %DATA%/planthopper_train_mean.binaryproto leveldb

echo "Creating test leveldb mean..."
"%TOOLS%/compute_image_mean.exe"  %EXAMPLE%/planthopper_test_leveldb %DATA%/planthopper_test_mean.binaryproto leveldb

echo "Done."

pause


3. Put the mean files under caffe/data, modify train.prototxt by adding mean_file under transform_param, and train the network with mean subtraction. (The scale of 0.00390625 = 1/256 is applied after the mean is subtracted.)

layers {
  name: "Planthopper"
  type: DATA
  top: "data"
  top: "label"
  data_param {
    source: "examples/Planthopper/planthopper_train_leveldb"
    backend: LEVELDB
    batch_size: 64
  }
  transform_param {
  mean_file:"data/Planthopper/planthopper_train_mean.binaryproto"
    scale: 0.00390625
  }
  include: { phase: TRAIN }
}
layers {
  name: "PlanthopperNet"
  type: DATA
  top: "data"
  top: "label"
  data_param {
    source: "examples/Planthopper/planthopper_test_leveldb"
    backend: LEVELDB
    batch_size: 100
  }
  transform_param {
   mean_file:"data/Planthopper/planthopper_test_mean.binaryproto"
    scale: 0.00390625
  }
  include: { phase: TEST }
}



