Multi-Label Classification in Caffe: LMDB Input with Multi-Label Support

Caffe's built-in image-to-LMDB converter supports only a single label per image. For multi-label tasks you can either use the HDF5 format or modify the Caffe source. My earlier article "Caffe 实现多标签分类" (multi-label classification in Caffe) showed how to support multi-label tasks by modifying ImageDataLayer; this article shows how to modify DataLayer so that LMDB-format input carries multiple labels per image.

1. Modify the source code

        The following files need changes:

        $CAFFE_ROOT/src/caffe/proto/caffe.proto

        $CAFFE_ROOT/src/caffe/layers/data_layer.cpp

        $CAFFE_ROOT/src/caffe/util/io.cpp

        $CAFFE_ROOT/include/caffe/util/io.hpp

        $CAFFE_ROOT/tools/convert_imageset.cpp

  (1) Modify caffe.proto

        Inside message Datum { }, add a field to hold the labels:

    repeated float labels = 8;

        If your labels are all integers, repeated int32 labels = 8; works as well.
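For orientation, the tail of message Datum should end up looking roughly like this (fields 1 through 7 are Caffe's existing fields; check the next free field number against your own Caffe version before adding the field):

```protobuf
message Datum {
  optional int32 channels = 1;
  optional int32 height = 2;
  optional int32 width = 3;
  optional bytes data = 4;
  optional int32 label = 5;
  repeated float float_data = 6;
  optional bool encoded = 7 [default = false];
  // New: multiple labels for multi-label tasks
  repeated float labels = 8;
}
```

The generated C++ accessors `labels_size()`, `labels(i)`, `add_labels()`, and `clear_labels()` are what the modified code below relies on.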

 (2) Modify data_layer.cpp

        In DataLayerSetUp(), give the label blob a second dimension holding the number of labels per image. New code:

vector<int> label_shape(2);
label_shape[0] = batch_size;
label_shape[1] = datum.labels_size();

        In load_batch(), copy every label of each datum into the label blob. New code:

int labelSize = datum.labels_size();
for (int i = 0; i < labelSize; i++) {
  top_label[item_id * labelSize + i] = datum.labels(i);
}
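The indexing in load_batch() stores the labels row-major: the label blob is batch_size x labelSize, so sample item_id's i-th label lives at offset item_id * labelSize + i. A pure-Python sketch with hypothetical sizes:

```python
batch_size, label_size = 3, 4
labels_per_sample = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]]

# Flat buffer standing in for the top_label blob memory
top_label = [0.0] * (batch_size * label_size)
for item_id in range(batch_size):
    for i in range(label_size):
        top_label[item_id * label_size + i] = labels_per_sample[item_id][i]

print(top_label[1 * label_size + 2])  # sample 1, label index 2 -> 7
```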

  (3) Modify io.hpp

        The label parameter becomes a vector<float> in every declaration. New code:

bool ReadFileToDatum(const string& filename, const vector<float> label, Datum* datum);

inline bool ReadFileToDatum(const string& filename, Datum* datum) {
  return ReadFileToDatum(filename, vector<float>(), datum);
}

bool ReadImageToDatum(const string& filename, const vector<float> label,
    const int height, const int width, const bool is_color,
    const std::string& encoding, Datum* datum);

inline bool ReadImageToDatum(const string& filename, const vector<float> label,
    const int height, const int width, const bool is_color, Datum* datum) {
  return ReadImageToDatum(filename, label, height, width, is_color,
                          "", datum);
}

inline bool ReadImageToDatum(const string& filename, const vector<float> label,
    const int height, const int width, Datum* datum) {
  return ReadImageToDatum(filename, label, height, width, true, datum);
}

inline bool ReadImageToDatum(const string& filename, const vector<float> label,
    const bool is_color, Datum* datum) {
  return ReadImageToDatum(filename, label, 0, 0, is_color, datum);
}

inline bool ReadImageToDatum(const string& filename, const vector<float> label,
    Datum* datum) {
  return ReadImageToDatum(filename, label, 0, 0, true, datum);
}

inline bool ReadImageToDatum(const string& filename, const vector<float> label,
    const std::string& encoding, Datum* datum) {
  return ReadImageToDatum(filename, label, 0, 0, true, encoding, datum);
}
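The chain of inline overloads above only fills in default arguments before forwarding to the full six-argument version. A Python sketch of the same convention (the tuple return is purely for illustration):

```python
def read_image_to_datum(filename, label, height=0, width=0,
                        is_color=True, encoding=""):
    # In C++ this is the one "real" function; the inline overloads
    # correspond to omitting trailing arguments here.
    return (filename, list(label), height, width, is_color, encoding)

# Every shorter call site ends up with the defaults the C++ overloads supply:
print(read_image_to_datum("a.jpg", [3, 7, 2, 9]))
```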

  (4) Modify io.cpp

        Modify ReadImageToDatum(). New code:

bool ReadImageToDatum(const string& filename, const vector<float> label,
    const int height, const int width, const bool is_color,
    const std::string& encoding, Datum* datum) {
  cv::Mat cv_img = ReadImageToCVMat(filename, height, width, is_color);
  if (cv_img.data) {
    if (encoding.size()) {
      if ((cv_img.channels() == 3) == is_color && !height && !width &&
          matchExt(filename, encoding))
        return ReadFileToDatum(filename, label, datum);
      std::vector<uchar> buf;
      cv::imencode("." + encoding, cv_img, buf);
      datum->set_data(std::string(reinterpret_cast<char*>(&buf[0]),
                      buf.size()));

      datum->clear_labels();
      for (int i = 0; i < label.size(); i++) {
        datum->add_labels(label[i]);
      }
      datum->set_encoded(true);
      return true;
    }
    CVMatToDatum(cv_img, datum);

    datum->clear_labels();
    for (int i = 0; i < label.size(); i++) {
      datum->add_labels(label[i]);
    }
    return true;
  } else {
    return false;
  }
}

        Modify ReadFileToDatum(). New code:

bool ReadFileToDatum(const string& filename, const vector<float> label,
    Datum* datum) {
  std::streampos size;

  fstream file(filename.c_str(), ios::in | ios::binary | ios::ate);
  if (file.is_open()) {
    size = file.tellg();
    std::string buffer(size, ' ');
    file.seekg(0, ios::beg);
    file.read(&buffer[0], size);
    file.close();
    datum->set_data(buffer);

    datum->clear_labels();
    for (int i = 0; i < label.size(); i++) {
      datum->add_labels(label[i]);
    }
    datum->set_encoded(true);
    return true;
  } else {
    return false;
  }
}

  
  (5) Modify convert_imageset.cpp

        The list-file parser now reads all labels on each line. New code:

std::vector<std::pair<std::string, vector<float> > > lines;
std::string line, filename;

float label;
while (std::getline(infile, line)) {
  std::istringstream iss(line);
  iss >> filename;
  std::vector<float> labels;
  while (iss >> label) {
    labels.push_back(label);
  }
  lines.push_back(std::make_pair(filename, labels));
}
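A Python sketch of the parsing loop added to convert_imageset.cpp: each line of the list file is a filename followed by any number of labels.

```python
def parse_list_file(text_lines):
    entries = []
    for line in text_lines:
        parts = line.split()
        if not parts:
            continue  # skip blank lines
        # filename first, then every remaining token is a label
        entries.append((parts[0], [float(x) for x in parts[1:]]))
    return entries

print(parse_list_file(["img1.png 3 7 2 9", "img2.png 1 0 4 5"]))
```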


  

2. Rebuild Caffe

mark@ubuntu:~/caffe/build$ make all
mark@ubuntu:~/caffe/build$ sudo make install

3. Generate the LMDB files

Once the build succeeds, use the newly built convert_imageset to convert the training images into an LMDB file:

mark@ubuntu:~/caffe$ sudo ./build/tools/convert_imageset -shuffle=true  /home/mark/data/  /home/mark/data/train.txt  ./examples/captcha/captcha_train_lmdb

/home/mark/data/ is the root directory containing the training images.
/home/mark/data/train.txt lists one training image per line: the image filename followed by its labels. Concatenating the filename onto /home/mark/data/ gives the image's absolute path.
./examples/captcha/captcha_train_lmdb is the directory where the LMDB file is written.
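For a four-digit captcha task, the list file might look like this (filenames and labels are purely illustrative):

```
captcha_0001.jpg 3 7 2 9
captcha_0002.jpg 1 0 4 5
captcha_0003.jpg 8 8 2 6
```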

The test images are converted to LMDB the same way:

mark@ubuntu:~/caffe$ sudo ./build/tools/convert_imageset -shuffle=true  /home/mark/data/  /home/mark/data/test.txt  ./examples/captcha/captcha_test_lmdb

4. Network definition and solver

Network definition file captcha_train_test_lmdb.prototxt:

 
name: "captcha"
layer {
  name: "Input"
  type: "Data"
  top: "data"
  top: "label"
  include {
    phase: TRAIN
  }
  transform_param {
    scale: 0.00390625
  }
  data_param {
    source: "examples/captcha/captcha_train_lmdb"
    batch_size: 50
    backend: LMDB
  }
}

layer {
  name: "Input"
  type: "Data"
  top: "data"
  top: "label"
  include {
    phase: TEST
  }
  transform_param {
    scale: 0.00390625
  }
  data_param {
    source: "examples/captcha/captcha_test_lmdb"
    batch_size: 20
    backend: LMDB
  }
}

layer {
  name: "slice"
  type: "Slice"
  bottom: "label"
  top: "label_1"
  top: "label_2"
  top: "label_3"
  top: "label_4"
  slice_param {
    axis: 1
    slice_point: 1
    slice_point: 2
    slice_point: 3
  }
}

layer {
  name: "conv1"
  type: "Convolution"
  bottom: "data"
  top: "conv1"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  convolution_param {
    num_output: 20
    kernel_size: 5
    stride: 1
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "pool1"
  type: "Pooling"
  bottom: "conv1"
  top: "pool1"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "conv2"
  type: "Convolution"
  bottom: "pool1"
  top: "conv2"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  convolution_param {
    num_output: 50
    kernel_size: 5
    stride: 1
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "pool2"
  type: "Pooling"
  bottom: "conv2"
  top: "pool2"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "ip1"
  type: "InnerProduct"
  bottom: "pool2"
  top: "ip1"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 500
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "relu1"
  type: "ReLU"
  bottom: "ip1"
  top: "ip1"
}

layer {
  name: "ip2"
  type: "InnerProduct"
  bottom: "ip1"
  top: "ip2"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 100
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}

layer {
  name: "ip3_1"
  type: "InnerProduct"
  bottom: "ip2"
  top: "ip3_1"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 10
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}

layer {
  name: "ip3_2"
  type: "InnerProduct"
  bottom: "ip2"
  top: "ip3_2"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 10
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}

layer {
  name: "ip3_3"
  type: "InnerProduct"
  bottom: "ip2"
  top: "ip3_3"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 10
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}

layer {
  name: "ip3_4"
  type: "InnerProduct"
  bottom: "ip2"
  top: "ip3_4"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 10
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}

layer {
  name: "accuracy1"
  type: "Accuracy"
  bottom: "ip3_1"
  bottom: "label_1"
  top: "accuracy1"
  include {
    phase: TEST
  }
}
layer {
  name: "loss1"
  type: "SoftmaxWithLoss"
  bottom: "ip3_1"
  bottom: "label_1"
  top: "loss1"
}

layer {
  name: "accuracy2"
  type: "Accuracy"
  bottom: "ip3_2"
  bottom: "label_2"
  top: "accuracy2"
  include {
    phase: TEST
  }
}
layer {
  name: "loss2"
  type: "SoftmaxWithLoss"
  bottom: "ip3_2"
  bottom: "label_2"
  top: "loss2"
}

layer {
  name: "accuracy3"
  type: "Accuracy"
  bottom: "ip3_3"
  bottom: "label_3"
  top: "accuracy3"
  include {
    phase: TEST
  }
}
layer {
  name: "loss3"
  type: "SoftmaxWithLoss"
  bottom: "ip3_3"
  bottom: "label_3"
  top: "loss3"
}

layer {
  name: "accuracy4"
  type: "Accuracy"
  bottom: "ip3_4"
  bottom: "label_4"
  top: "accuracy4"
  include {
    phase: TEST
  }
}
layer {
  name: "loss4"
  type: "SoftmaxWithLoss"
  bottom: "ip3_4"
  bottom: "label_4"
  top: "loss4"
}
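The "slice" layer is what turns the single multi-label blob into four per-digit label blobs: the Data layer emits labels with shape (batch_size, 4), and slicing along axis 1 at points 1, 2, 3 yields four (batch_size, 1) blobs, one per captcha digit. A numpy sketch of the same operation:

```python
import numpy as np

batch_size = 50
# Stand-in for the label blob the Data layer produces
label = np.random.randint(0, 10, size=(batch_size, 4))

# np.split at indices [1, 2, 3] mirrors slice_point: 1 / 2 / 3 on axis 1
label_1, label_2, label_3, label_4 = np.split(label, [1, 2, 3], axis=1)

print(label_1.shape)  # (50, 1)
```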

Solver file captcha_solver_lmdb.prototxt:

 
# The train/test net protocol buffer definition
net: "examples/captcha/captcha_train_test_lmdb.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# With a test batch_size of 20, 200 iterations cover 4,000 test images.
test_iter: 200
# Carry out testing every 200 training iterations.
test_interval: 200
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.001
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 10000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "examples/captcha/captcha"
# solver mode: CPU or GPU
solver_mode: GPU
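With lr_policy "inv", Caffe decays the learning rate as base_lr * (1 + gamma * iter) ^ (-power). A quick sketch of how the rate evolves over this solver's 10,000 iterations:

```python
base_lr, gamma, power = 0.001, 0.0001, 0.75

def inv_lr(iteration):
    # Caffe's "inv" learning-rate policy
    return base_lr * (1.0 + gamma * iteration) ** (-power)

print(inv_lr(0))      # 0.001 at the start
print(inv_lr(10000))  # decayed by the end of training
```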


5. Start training

mark@ubuntu:~/caffe$ sudo ./build/tools/caffe train --solver=examples/captcha/captcha_solver_lmdb.prototxt

When training finishes, the model file captcha_iter_10000.caffemodel is generated.

6. Test with the trained model

First we need a deploy file. Starting from captcha_train_test_lmdb.prototxt, make the modifications below and save the result as captcha_deploy_lmdb.prototxt:

 
name: "captcha"

input: "data"
input_dim: 1   # batchsize
input_dim: 3   # number of channels - rgb
input_dim: 60  # height
input_dim: 160 # width

layer {
  name: "conv1"
  type: "Convolution"
  bottom: "data"
  top: "conv1"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  convolution_param {
    num_output: 20
    kernel_size: 5
    stride: 1
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "pool1"
  type: "Pooling"
  bottom: "conv1"
  top: "pool1"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "conv2"
  type: "Convolution"
  bottom: "pool1"
  top: "conv2"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  convolution_param {
    num_output: 50
    kernel_size: 5
    stride: 1
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "pool2"
  type: "Pooling"
  bottom: "conv2"
  top: "pool2"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "ip1"
  type: "InnerProduct"
  bottom: "pool2"
  top: "ip1"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 500
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "relu1"
  type: "ReLU"
  bottom: "ip1"
  top: "ip1"
}

layer {
  name: "ip2"
  type: "InnerProduct"
  bottom: "ip1"
  top: "ip2"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 100
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}

layer {
  name: "ip3_1"
  type: "InnerProduct"
  bottom: "ip2"
  top: "ip3_1"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 10
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}

layer {
  name: "ip3_2"
  type: "InnerProduct"
  bottom: "ip2"
  top: "ip3_2"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 10
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}

layer {
  name: "ip3_3"
  type: "InnerProduct"
  bottom: "ip2"
  top: "ip3_3"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 10
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}

layer {
  name: "ip3_4"
  type: "InnerProduct"
  bottom: "ip2"
  top: "ip3_4"
  param {
    lr_mult: 1
  }
  param {
    lr_mult: 2
  }
  inner_product_param {
    num_output: 10
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}

layer {
  name: "prob1"
  type: "Softmax"
  bottom: "ip3_1"
  top: "prob1"
}
layer {
  name: "prob2"
  type: "Softmax"
  bottom: "ip3_2"
  top: "prob2"
}
layer {
  name: "prob3"
  type: "Softmax"
  bottom: "ip3_3"
  top: "prob3"
}
layer {
  name: "prob4"
  type: "Softmax"
  bottom: "ip3_4"
  top: "prob4"
}
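Each probN layer applies a softmax over the 10 digit scores coming out of its ip3_N layer; the test script then takes the argmax of each probability vector as the predicted digit. A numpy sketch of that final step for one hypothetical score vector:

```python
import numpy as np

def softmax(scores):
    e = np.exp(scores - scores.max())  # shift for numerical stability
    return e / e.sum()

# Hypothetical ip3_N output: 10 scores, one per digit 0-9
scores = np.array([0.1, 2.0, -1.0, 0.5, 0.0, 3.0, 1.0, -0.5, 0.2, 0.3])
prob = softmax(scores)
print(prob.argmax())  # -> 5, the digit with the highest score
```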


Write the test script:

 
import numpy as np
import os
import sys
os.environ['GLOG_minloglevel'] = '3'
import caffe

CAFFE_ROOT = '/home/mark/caffe'
deploy_file_name = 'captcha_deploy_lmdb.prototxt'
model_file_name = 'captcha_iter_10000.caffemodel'

IMAGE_HEIGHT = 60
IMAGE_WIDTH = 160
IMAGE_CHANNEL = 3

def classify(imageFileName):
    deploy_file = CAFFE_ROOT + '/examples/captcha/' + deploy_file_name
    model_file = CAFFE_ROOT + '/examples/captcha/' + model_file_name
    # Initialize caffe
    net = caffe.Net(deploy_file, model_file, caffe.TEST)

    # Preprocessing
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_transpose('data', (2, 0, 1))  # pycaffe loads images as H x W x C; convert to C x H x W

    # pycaffe scales images to [0, 1]; if the model was trained on raw 0-255
    # input, enable the following conversion:
    # transformer.set_raw_scale('data', 255)

    transformer.set_channel_swap('data', (2, 1, 0))  # caffe expects BGR; the loaded image is RGB

    # Reshape the input blob to match the deploy file
    net.blobs['data'].reshape(1, IMAGE_CHANNEL, IMAGE_HEIGHT, IMAGE_WIDTH)

    # Load the image
    # color: True (default) for a color image, False for grayscale
    img = caffe.io.load_image(imageFileName, color=True)

    # Feed in the preprocessed image
    net.blobs['data'].data[...] = transformer.preprocess('data', img)

    # Forward pass, i.e. classification
    out = net.forward()

    # Take the argmax of each label's probability vector
    result = []
    predict1 = out['prob1'][0].argmax()
    result.append(predict1)

    predict2 = out['prob2'][0].argmax()
    result.append(predict2)

    predict3 = out['prob3'][0].argmax()
    result.append(predict3)

    predict4 = out['prob4'][0].argmax()
    result.append(predict4)

    return result

if __name__ == '__main__':

    imgList = sys.argv[1:]
    for captcha in imgList:
        predict = classify(captcha)
        print "captcha:", captcha, " predict:", predict
Run the script, passing one or more captcha image paths as arguments, to test the model.

Reprinted from blog.csdn.net/zlf19910726/article/details/81091928