版权声明:本文为博主原创文章,转载请注明。 https://blog.csdn.net/elaine_bao/article/details/79433733
caffe中video_data_layer.cpp的解析,直接看代码中的注释。
// include的部分就不介绍了
#include <fstream>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "caffe/data_layers.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/rng.hpp"
#ifdef USE_MPI
#include "mpi.h"
#include <boost/filesystem.hpp>
using namespace boost::filesystem;
#endif
namespace caffe{
// Destructor: calls JoinPrefetchThread() before the layer object goes away.
template <typename Dtype>
VideoDataLayer<Dtype>::~VideoDataLayer() {
  this->JoinPrefetchThread();
}
// Sets up the layer: loads the video list, optionally shuffles it, resolves
// the frame file-name pattern, reads one sample to learn the datum shape, and
// reshapes the top blobs and the prefetch buffers accordingly.
//
// bottom: not used by this function.
// top:    top[0] is reshaped to hold one batch of data, top[1] one batch of
//         labels.
// Aborts (via CHECK) when the list file cannot be opened or is empty, when
// num_segments is not positive, or when the first sample cannot be read.
template <typename Dtype>
void VideoDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  const VideoDataParameter& video_param = this->layer_param_.video_data_param();
  const int new_height = video_param.new_height();
  const int new_width = video_param.new_width();
  // Number of consecutive frames / flow pairs taken from each segment.
  const int new_length = video_param.new_length();
  // Number of equal parts each video is split into for sampling.
  const int num_segments = video_param.num_segments();
  // num_segments is used as a divisor below.
  CHECK_GT(num_segments, 0) << "num_segments must be positive";
  // List file; each line is read as: <frame directory> <frame count> <label>.
  const string& source = video_param.source();

  // 1. Read the list file.
  LOG(INFO) << "Opening file: " << source;
  std::ifstream infile(source.c_str());
  CHECK(infile.good()) << "Could not open video list file: " << source;
  string filename;
  int label;
  int length;
  while (infile >> filename >> length >> label) {
    // lines_ keeps (directory, label); lines_duration_ keeps the frame count
    // at the same index.
    lines_.push_back(std::make_pair(filename, label));
    lines_duration_.push_back(length);
  }
  // The first entry is dereferenced below to probe the sample shape.
  CHECK(!lines_.empty()) << "No videos could be read from: " << source;

  // 2. Shuffle the list order if requested.
  if (video_param.shuffle()) {
    const unsigned int prefetch_rng_seed = caffe_rng_rand();
    // Both generators get the same seed so that ShuffleVideos() can permute
    // lines_ and lines_duration_ with matching random streams.
    prefetch_rng_1_.reset(new Caffe::RNG(prefetch_rng_seed));
    prefetch_rng_2_.reset(new Caffe::RNG(prefetch_rng_seed));
    ShuffleVideos();
  }
  LOG(INFO) << "A total of " << lines_.size() << " videos.";
  lines_id_ = 0;

  // 3. Resolve the file-name pattern: fall back to a per-modality default
  //    when the configured pattern is empty.
  if (video_param.name_pattern() == "") {
    if (video_param.modality() == VideoDataParameter_Modality_RGB) {
      name_pattern_ = "image_%04d.jpg";
    } else if (video_param.modality() == VideoDataParameter_Modality_FLOW) {
      name_pattern_ = "flow_%c_%04d.jpg";
    }
  } else {
    name_pattern_ = video_param.name_pattern();
  }

  // 4. Pick a random start offset inside each segment of the first video.
  //    The list only names whole videos, so the frames to read are sampled.
  Datum datum;
  const bool is_color = !video_param.grayscale();
  const unsigned int frame_prefetch_rng_seed = caffe_rng_rand();
  frame_prefetch_rng_.reset(new Caffe::RNG(frame_prefetch_rng_seed));
  // Length of one segment, in frames (integer division).
  const int average_duration =
      static_cast<int>(lines_duration_[lines_id_]) / num_segments;
  vector<int> offsets;
  for (int i = 0; i < num_segments; ++i) {
    if (average_duration >= new_length) {
      caffe::rng_t* frame_rng =
          static_cast<caffe::rng_t*>(frame_prefetch_rng_->generator());
      const int offset = (*frame_rng)() % (average_duration - new_length + 1);
      offsets.push_back(offset + i * average_duration);
    } else {
      // Segment shorter than new_length: the modulus above would be zero or
      // negative. Use offset 0, the same fallback InternalThreadEntry uses.
      offsets.push_back(0);
    }
  }

  // 5. Read the probe sample into datum.
  if (video_param.modality() == VideoDataParameter_Modality_FLOW) {
    CHECK(ReadSegmentFlowToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
        offsets, new_height, new_width, new_length, &datum, name_pattern_.c_str()));
  } else {
    CHECK(ReadSegmentRGBToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
        offsets, new_height, new_width, new_length, &datum, is_color, name_pattern_.c_str()));
  }

  // 6. Shape the data blobs: crop_size x crop_size when cropping is enabled,
  //    otherwise the probe sample's own height and width.
  const int crop_size = this->layer_param_.transform_param().crop_size();
  const int batch_size = video_param.batch_size();
  if (crop_size > 0) {
    top[0]->Reshape(batch_size, datum.channels(), crop_size, crop_size);
    this->prefetch_data_.Reshape(batch_size, datum.channels(), crop_size, crop_size);
  } else {
    top[0]->Reshape(batch_size, datum.channels(), datum.height(), datum.width());
    this->prefetch_data_.Reshape(batch_size, datum.channels(), datum.height(), datum.width());
  }
  LOG(INFO) << "output data size: " << top[0]->num() << "," << top[0]->channels() << "," << top[0]->height() << "," << top[0]->width();

  // 7. Shape the label blobs and the per-item transform buffer.
  top[1]->Reshape(batch_size, 1, 1, 1);
  this->prefetch_label_.Reshape(batch_size, 1, 1, 1);
  vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);
  this->transformed_data_.Reshape(top_shape);
}
// 2.2 Shuffle the video order.
// lines_ and lines_duration_ are shuffled separately, each with its own
// generator (prefetch_rng_1_ / prefetch_rng_2_). DataLayerSetUp seeds both
// generators with the same value; the intent is that both vectors receive the
// same permutation so entry i of one still describes entry i of the other.
template <typename Dtype>
void VideoDataLayer<Dtype>::ShuffleVideos() {
  caffe::rng_t* lines_rng =
      static_cast<caffe::rng_t*>(prefetch_rng_1_->generator());
  caffe::rng_t* durations_rng =
      static_cast<caffe::rng_t*>(prefetch_rng_2_->generator());

  shuffle(lines_.begin(), lines_.end(), lines_rng);
  shuffle(lines_duration_.begin(), lines_duration_.end(), durations_rng);
}
// InternalThreadEntry: fills prefetch_data_ and prefetch_label_ with one
// batch.
//
// For every batch slot a video is taken from lines_ at lines_id_, one start
// offset per segment is chosen (random in TRAIN phase, centred otherwise,
// 0 when a segment is shorter than new_length), the frames are read into a
// Datum and transformed straight into the slot's region of prefetch_data_.
//
// If a video cannot be read, the cursor moves on to the next video and the
// SAME batch slot is retried, so no slot is left with stale data/label and an
// unreadable video is not retried for every remaining slot. After lines_size
// consecutive failures the function aborts via CHECK.
template <typename Dtype>
void VideoDataLayer<Dtype>::InternalThreadEntry() {
  Datum datum;
  CHECK(this->prefetch_data_.count());
  Dtype* top_data = this->prefetch_data_.mutable_cpu_data();
  Dtype* top_label = this->prefetch_label_.mutable_cpu_data();
  VideoDataParameter video_data_param = this->layer_param_.video_data_param();
  const int batch_size = video_data_param.batch_size();
  const int new_height = video_data_param.new_height();
  const int new_width = video_data_param.new_width();
  const int new_length = video_data_param.new_length();
  const int num_segments = video_data_param.num_segments();
  const int lines_size = lines_.size();
  const bool is_color = !video_data_param.grayscale();
  const bool is_flow =
      video_data_param.modality() == VideoDataParameter_Modality_FLOW;

  for (int item_id = 0; item_id < batch_size; ++item_id) {
    bool loaded = false;
    // Bounded retry: give up once as many videos as the list holds have
    // failed in a row.
    for (int attempt = 0; attempt < lines_size && !loaded; ++attempt) {
      CHECK_GT(lines_size, lines_id_);

      // One start offset per segment.
      vector<int> offsets;
      const int average_duration =
          static_cast<int>(lines_duration_[lines_id_]) / num_segments;
      for (int i = 0; i < num_segments; ++i) {
        if (average_duration < new_length) {
          // Segment too short to place a window of new_length frames.
          offsets.push_back(0);
        } else if (this->phase_ == TRAIN) {
          // Training: random window position inside the segment.
          caffe::rng_t* frame_rng =
              static_cast<caffe::rng_t*>(frame_prefetch_rng_->generator());
          const int offset =
              (*frame_rng)() % (average_duration - new_length + 1);
          offsets.push_back(offset + i * average_duration);
        } else {
          // Other phases: deterministic, centred window.
          offsets.push_back(int((average_duration - new_length + 1) / 2 + i * average_duration));
        }
      }

      if (is_flow) {
        loaded = ReadSegmentFlowToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
            offsets, new_height, new_width, new_length, &datum, name_pattern_.c_str());
      } else {
        loaded = ReadSegmentRGBToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
            offsets, new_height, new_width, new_length, &datum, is_color, name_pattern_.c_str());
      }

      if (loaded) {
        // Point transformed_data_ at this slot so Transform() writes in place.
        const int offset1 = this->prefetch_data_.offset(item_id);
        this->transformed_data_.set_cpu_data(top_data + offset1);
        this->data_transformer_->Transform(datum, &(this->transformed_data_));
        top_label[item_id] = lines_[lines_id_].second;
      }

      // Advance the cursor whether or not the read succeeded; wrap around
      // (and reshuffle if configured) at the end of the list.
      lines_id_++;
      if (lines_id_ >= lines_size) {
        DLOG(INFO) << "Restarting data prefetching from start.";
        lines_id_ = 0;
        if (video_data_param.shuffle()) {
          ShuffleVideos();
        }
      }
    }
    CHECK(loaded) << "Could not load any video for batch item " << item_id
                  << " after " << lines_size << " attempts.";
  }
}
// Both macros are defined elsewhere in the project (not visible in this file).
// NOTE(review): presumably INSTANTIATE_CLASS emits the explicit template
// instantiations of VideoDataLayer and REGISTER_LAYER_CLASS registers the
// layer under the type name "VideoData" -- confirm against the macro
// definitions.
INSTANTIATE_CLASS(VideoDataLayer);
REGISTER_LAYER_CLASS(VideoData);
}
在io.cpp中有ReadSegmentRGBToDatum函数和ReadSegmentFlowToDatum函数,我们来看一下。
// 5.1 ReadSegmentRGBToDatum函数
bool ReadSegmentRGBToDatum(const string& filename, const int label,
const vector<int> offsets, const int height, const int width, const int length, Datum* datum, bool is_color,
const char* name_pattern ){
cv::Mat cv_img;
string* datum_string;
char tmp[30];
int cv_read_flag = (is_color ? CV_LOAD_IMAGE_COLOR :
CV_LOAD_IMAGE_GRAYSCALE);
for (int i = 0; i < offsets.size(); ++i){
//表示在第i个segment的哪个位置开始取数据
int offset = offsets[i];
//每个segment从offset位置起读取length长度个帧
for (int file_id = 1; file_id < length+1; ++file_id){
sprintf(tmp, name_pattern, int(file_id+offset));
string filename_t = filename + "/" + tmp;
//读入图片
cv::Mat cv_img_origin = cv::imread(filename_t, cv_read_flag);
//如果图片不存在,返回false
if (!cv_img_origin.data){
LOG(ERROR) << "Could not load file " << filename_t;
return false;
}
//如果需要resize,则resize
if (height > 0 && width > 0){
cv::resize(cv_img_origin, cv_img, cv::Size(width, height));
}else{
cv_img = cv_img_origin;
}
int num_channels = (is_color ? 3 : 1);
//初始化数据大小
if (file_id==1 && i==0){
//一个数据的大小 = 视频segment个数 * 每个segment取length的长度 * 每帧的channel数
datum->set_channels(num_channels*length*offsets.size());
datum->set_height(cv_img.rows);
datum->set_width(cv_img.cols);
datum->set_label(label);
datum->clear_data();
datum->clear_float_data();
datum_string = datum->mutable_data();
}
//写入数据到datum
if (is_color) {
for (int c = 0; c < num_channels; ++c) {
for (int h = 0; h < cv_img.rows; ++h) {
for (int w = 0; w < cv_img.cols; ++w) {
datum_string->push_back(
static_cast<char>(cv_img.at<cv::Vec3b>(h, w)[c]));
}
}
}
} else { // Faster than repeatedly testing is_color for each pixel w/i loop
for (int h = 0; h < cv_img.rows; ++h) {
for (int w = 0; w < cv_img.cols; ++w) {
datum_string->push_back(
static_cast<char>(cv_img.at<uchar>(h, w)));
}
}
}
}
}
return true;
}
// 5.2 ReadSegmentFlowToDatum函数
bool ReadSegmentFlowToDatum(const string& filename, const int label,
const vector<int> offsets, const int height, const int width, const int length, Datum* datum,
const char* name_pattern ){
cv::Mat cv_img_x, cv_img_y;
string* datum_string;
char tmp[30];
for (int i = 0; i < offsets.size(); ++i){
//表示在第i个segment的哪个位置开始取数据
int offset = offsets[i];
//每个segment从offset位置起读取length长度个帧
for (int file_id = 1; file_id < length+1; ++file_id){
//flow的数据包含x,y两个部分,所以需要同时读取两帧图片stack在一起。channel数为2
sprintf(tmp,name_pattern, 'x', int(file_id+offset));
string filename_x = filename + "/" + tmp;
cv::Mat cv_img_origin_x = cv::imread(filename_x, CV_LOAD_IMAGE_GRAYSCALE);
sprintf(tmp, name_pattern, 'y', int(file_id+offset));
string filename_y = filename + "/" + tmp;
cv::Mat cv_img_origin_y = cv::imread(filename_y, CV_LOAD_IMAGE_GRAYSCALE);
//如果flow_x或flow_y不存在,则返回false
if (!cv_img_origin_x.data || !cv_img_origin_y.data){
LOG(ERROR) << "Could not load file " << filename_x << " or " << filename_y;
return false;
}
//如果需要resize,则resize
if (height > 0 && width > 0){
cv::resize(cv_img_origin_x, cv_img_x, cv::Size(width, height));
cv::resize(cv_img_origin_y, cv_img_y, cv::Size(width, height));
}else{
cv_img_x = cv_img_origin_x;
cv_img_y = cv_img_origin_y;
}
if (file_id==1 && i==0){
//一个数据的大小 = 视频segment个数 * 每个segment取length的长度 * 每帧的channel数(2)
int num_channels = 2;
datum->set_channels(num_channels*length*offsets.size());
datum->set_height(cv_img_x.rows);
datum->set_width(cv_img_x.cols);
datum->set_label(label);
datum->clear_data();
datum->clear_float_data();
datum_string = datum->mutable_data();
}
//写入数据到datum
for (int h = 0; h < cv_img_x.rows; ++h){
for (int w = 0; w < cv_img_x.cols; ++w){
datum_string->push_back(static_cast<char>(cv_img_x.at<uchar>(h,w)));
}
}
for (int h = 0; h < cv_img_y.rows; ++h){
for (int w = 0; w < cv_img_y.cols; ++w){
datum_string->push_back(static_cast<char>(cv_img_y.at<uchar>(h,w)));
}
}
}
}
return true;
}