在github上找到FashionLandmark论文的demo后,由于作者只公布了测试的源码,对于训练时需要的loss层等结构就无从所知了。有幸找到这篇文章《caffemodel解析,caffemodel里面到底记录了什么?》
文章里面先用ReadProtoFromBinaryFile函数将二进制文件读入proto里面,再将proto文件读入txt文件,便得到训练时的网络结构和网络参数。下面是测试代码:
#include <caffe/caffe.hpp>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include <algorithm>
#include <iosfwd>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <iostream>
#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"
using namespace caffe;
using namespace std;
using google::protobuf::io::FileInputStream;
using google::protobuf::io::FileOutputStream;
using google::protobuf::io::ZeroCopyInputStream;
using google::protobuf::io::CodedInputStream;
using google::protobuf::io::ZeroCopyOutputStream;
using google::protobuf::io::CodedOutputStream;
using google::protobuf::Message;
int main()
{
NetParameter proto;
ReadProtoFromBinaryFile("/home/harry/spyderWorkPace/fld/deep_landmark/model/1_F/_iter_50000.caffemodel", &proto);
WriteProtoToTextFile(proto, "/home/harry/test.txt");
return 0;
}
使用CMake工具生成Makefile后,make时却发生了错误:
说的是找不到caffe/proto/caffe.pb.h头文件,于是去caffe安装路径下果然没找到。网上找到这篇文章:
没有用CMake,我在caffe_root/src/caffe/proto/目录下使用protoc重新生成caffe.pb.h和caffe.pb.cc:
protoc ./caffe.proto –cpp_out=/home/harry/work/caffe/include/proto/
(提前在/home/harry/work/caffe/include/下建好proto目录)
然后make仍然报错:
意思是说我编译caffe.pb.h用的protoc版本過新。查了一下我的protoc版本:
于是使用笔记本上ubuntu的protoc(版本为2.5.0)生成caffe.pb.h,然后拷到台式机ubuntu上,这下make成功了:
成功生成了test.txt:
当然,这样得到的文件由于包含了模型中所有的参数,数据量巨大(十几万行),要想人工从中筛选出网络结构,几乎是不太可能的。观察生成的txt文件,不难发现数据是有规律的,可以写个脚本自动筛选出需要的信息:
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 12 15:44:35 2018
@author: harry
"""
import numpy as np
if __name__ == '__main__':
file_in = '/home/harry/spyderWorkPace/fashion-landmarks/test.txt'
file_out = '/home/harry/spyderWorkPace/fashion-landmarks/t.txt'
fp_in = open(file_in,'r')
fp_out = open(file_out,'w')
line_out = []
tmp = []
for line in fp_in.readlines():
tmp = line.strip().split(' ')
if tmp[0] != 'data:' and tmp[0] != 'blobs':
line_out.append(line)
fp_in.close()
fp_out.writelines(line_out)
fp_out.close()
可见,筛选后的网络没有了中间参数,可以拿去使用了。