【caffe源码笔记】VideoDataLayer解析

版权声明:本文为博主原创文章,转载请注明。 https://blog.csdn.net/elaine_bao/article/details/79433733

caffe中video_data_layer.cpp的解析,直接看代码中的注释。

// include的部分就不介绍了
#include <fstream>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include "caffe/data_layers.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/rng.hpp"

#ifdef USE_MPI
#include "mpi.h"
#include <boost/filesystem.hpp>
using namespace boost::filesystem;
#endif

namespace caffe{
// Destructor: joins the prefetch thread via the base-class helper
// JoinPrefetchThread(), so InternalThreadEntry is no longer writing into
// prefetch_data_ / prefetch_label_ when the layer is torn down.
template <typename Dtype>
VideoDataLayer<Dtype>::~VideoDataLayer() {
    this->JoinPrefetchThread();
}
// DataLayerSetUp: reads the video list, optionally shuffles it, resolves the
// frame-file name pattern, and decodes one sample so the shapes of top[0]
// (data), top[1] (label) and the prefetch/transform buffers can be set.
template <typename Dtype>
void VideoDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top){
    const VideoDataParameter& video_data_param = this->layer_param_.video_data_param();
    const int new_height = video_data_param.new_height();
    const int new_width = video_data_param.new_width();
    // Number of consecutive frames (or flow pairs) read from each segment.
    const int new_length = video_data_param.new_length();
    // Number of segments each video is split into; one snippet of new_length
    // frames is taken from every segment.
    const int num_segments = video_data_param.num_segments();
    // List file; each line is "<frame directory> <number of frames> <label>".
    const string& source = video_data_param.source();
    // num_segments is used as a divisor below.
    CHECK_GT(num_segments, 0) << "num_segments must be positive";

    // 1. Read the list file.
    LOG(INFO) << "Opening file: " << source;
    std::ifstream infile(source.c_str());
    CHECK(infile.is_open()) << "Could not open source file: " << source;
    string filename;
    int label;
    int length;
    while (infile >> filename >> length >> label){
        // lines_ and lines_duration_ are kept index-aligned: entry k of both
        // describes the same video.
        lines_.push_back(std::make_pair(filename, label));
        lines_duration_.push_back(length);
    }
    // lines_duration_[lines_id_] is read below; an empty list would be an
    // out-of-bounds access.
    CHECK(!lines_.empty()) << "No videos listed in " << source;

    // 2. Optionally shuffle. Both generators get the same seed so that
    //    ShuffleVideos permutes lines_ and lines_duration_ identically
    //    (assuming shuffle() consumes the generator the same way for
    //    equal-length ranges), keeping the two vectors aligned.
    if (video_data_param.shuffle()){
        const unsigned int prefectch_rng_seed = caffe_rng_rand();
        prefetch_rng_1_.reset(new Caffe::RNG(prefectch_rng_seed));
        prefetch_rng_2_.reset(new Caffe::RNG(prefectch_rng_seed));
        ShuffleVideos();
    }

    LOG(INFO) << "A total of " << lines_.size() << " videos.";
    lines_id_ = 0;

    // 3. Resolve the frame-file name pattern. Defaults: "image_%04d.jpg" for
    //    RGB and "flow_%c_%04d.jpg" for flow (%c receives 'x' or 'y').
    if (video_data_param.name_pattern() == ""){
        if (video_data_param.modality() == VideoDataParameter_Modality_RGB){
            name_pattern_ = "image_%04d.jpg";
        }else if (video_data_param.modality() == VideoDataParameter_Modality_FLOW){
            name_pattern_ = "flow_%c_%04d.jpg";
        }
    }else{
        name_pattern_ = video_data_param.name_pattern();
    }

    // 4. Choose a start offset in every segment of the first video. The list
    //    only names whole videos, so which frames are read is picked here.
    Datum datum;
    const bool is_color = !video_data_param.grayscale();
    const unsigned int frame_prefectch_rng_seed = caffe_rng_rand();
    frame_prefetch_rng_.reset(new Caffe::RNG(frame_prefectch_rng_seed));
    // Average number of frames per segment.
    const int average_duration = lines_duration_[lines_id_] / num_segments;
    vector<int> offsets;
    for (int i = 0; i < num_segments; ++i){
        if (average_duration >= new_length){
            // Random start so that new_length frames fit inside the segment.
            caffe::rng_t* frame_rng = static_cast<caffe::rng_t*>(frame_prefetch_rng_->generator());
            const int offset = (*frame_rng)() % (average_duration - new_length + 1);
            offsets.push_back(offset + i * average_duration);
        } else {
            // Video too short: the modulus above would be zero or negative
            // (division by zero / UB). Fall back to 0, as InternalThreadEntry does.
            offsets.push_back(0);
        }
    }

    // 5. Decode the chosen frames into a Datum; only its shape is used here.
    if (video_data_param.modality() == VideoDataParameter_Modality_FLOW)
        CHECK(ReadSegmentFlowToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
                                     offsets, new_height, new_width, new_length, &datum, name_pattern_.c_str()));
    else
        CHECK(ReadSegmentRGBToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
                                    offsets, new_height, new_width, new_length, &datum, is_color, name_pattern_.c_str()));

    // 6. Shape the data blobs; a positive crop_size replaces the spatial size.
    const int crop_size = this->layer_param_.transform_param().crop_size();
    const int batch_size = video_data_param.batch_size();
    if (crop_size > 0){
        top[0]->Reshape(batch_size, datum.channels(), crop_size, crop_size);
        this->prefetch_data_.Reshape(batch_size, datum.channels(), crop_size, crop_size);
    } else {
        top[0]->Reshape(batch_size, datum.channels(), datum.height(), datum.width());
        this->prefetch_data_.Reshape(batch_size, datum.channels(), datum.height(), datum.width());
    }
    LOG(INFO) << "output data size: " << top[0]->num() << "," << top[0]->channels() << "," << top[0]->height() << "," << top[0]->width();

    // 7. Label blobs: one scalar label per sample.
    top[1]->Reshape(batch_size, 1, 1, 1);
    this->prefetch_label_.Reshape(batch_size, 1, 1, 1);

    vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);
    this->transformed_data_.Reshape(top_shape);
}
// ShuffleVideos: permutes the video list in place. lines_ and lines_duration_
// must receive the same permutation so entry k of each still refers to the
// same video; DataLayerSetUp seeds prefetch_rng_1_ and prefetch_rng_2_ with the
// same value for this reason.
// NOTE(review): relies on shuffle() drawing from the generator identically for
// ranges of equal length.
template <typename Dtype>
void VideoDataLayer<Dtype>::ShuffleVideos(){
    caffe::rng_t* names_rng = static_cast<caffe::rng_t*>(prefetch_rng_1_->generator());
    shuffle(lines_.begin(), lines_.end(), names_rng);

    caffe::rng_t* durations_rng = static_cast<caffe::rng_t*>(prefetch_rng_2_->generator());
    shuffle(lines_duration_.begin(), lines_duration_.end(), durations_rng);
}
// InternalThreadEntry: runs on the prefetch thread and fills one batch of
// prefetch_data_ / prefetch_label_, walking lines_ sequentially from lines_id_
// (wrapping around, and reshuffling if enabled, at the end of the list).
template <typename Dtype>
void VideoDataLayer<Dtype>::InternalThreadEntry(){
    Datum datum;
    CHECK(this->prefetch_data_.count());
    Dtype* top_data = this->prefetch_data_.mutable_cpu_data();
    Dtype* top_label = this->prefetch_label_.mutable_cpu_data();
    // Bind by reference; copying the protobuf message every batch is wasted work.
    const VideoDataParameter& video_data_param = this->layer_param_.video_data_param();
    const int batch_size = video_data_param.batch_size();
    const int new_height = video_data_param.new_height();
    const int new_width = video_data_param.new_width();
    const int new_length = video_data_param.new_length();
    const int num_segments = video_data_param.num_segments();
    const int lines_size = lines_.size();
    const bool is_color = !video_data_param.grayscale();
    const bool is_flow = (video_data_param.modality() == VideoDataParameter_Modality_FLOW);

    for (int item_id = 0; item_id < batch_size; ++item_id){
        // Keep trying videos until one decodes, so every slot of the batch is
        // filled. Previously a failed read did `continue` without advancing
        // lines_id_: the slot kept stale data/label and the same unreadable
        // video was retried for every following slot and batch.
        int failed_reads = 0;
        bool loaded = false;
        while (!loaded){
            CHECK_GT(lines_size, lines_id_);
            vector<int> offsets;
            const int average_duration = lines_duration_[lines_id_] / num_segments;
            for (int i = 0; i < num_segments; ++i){
                if (average_duration < new_length){
                    // Video too short for the requested snippet length.
                    offsets.push_back(0);
                } else if (this->phase_ == TRAIN){
                    // Training: random start inside each segment.
                    caffe::rng_t* frame_rng = static_cast<caffe::rng_t*>(frame_prefetch_rng_->generator());
                    const int offset = (*frame_rng)() % (average_duration - new_length + 1);
                    offsets.push_back(offset + i * average_duration);
                } else {
                    // Testing: deterministic, centred snippet in each segment.
                    offsets.push_back((average_duration - new_length + 1) / 2 + i * average_duration);
                }
            }

            if (is_flow){
                loaded = ReadSegmentFlowToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
                                                offsets, new_height, new_width, new_length, &datum, name_pattern_.c_str());
            } else {
                loaded = ReadSegmentRGBToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
                                               offsets, new_height, new_width, new_length, &datum, is_color, name_pattern_.c_str());
            }

            if (loaded){
                // Transform straight into this item's slice of the prefetch buffer.
                const int offset1 = this->prefetch_data_.offset(item_id);
                this->transformed_data_.set_cpu_data(top_data + offset1);
                this->data_transformer_->Transform(datum, &(this->transformed_data_));
                top_label[item_id] = lines_[lines_id_].second;
            } else {
                ++failed_reads;
                LOG(WARNING) << "Skipping unreadable video " << lines_[lines_id_].first;
                // Stop instead of spinning forever when nothing can be read.
                CHECK_LT(failed_reads, lines_size) << "Failed to read " << failed_reads
                    << " videos in a row; check the source list and name_pattern.";
            }

            // Advance to the next video (also after a failure, so a bad entry
            // is skipped rather than retried).
            lines_id_++;
            if (lines_id_ >= lines_size) {
                DLOG(INFO) << "Restarting data prefetching from start.";
                lines_id_ = 0;
                if (video_data_param.shuffle()){
                    ShuffleVideos();
                }
            }
        }
    }
}

// Caffe registration macros: explicitly instantiate the VideoDataLayer class
// template and register it with the layer factory under the name "VideoData"
// (presumably what a prototxt `type: "VideoData"` resolves to — confirm in
// caffe/layer_factory.hpp).
INSTANTIATE_CLASS(VideoDataLayer);
REGISTER_LAYER_CLASS(VideoData);
}

在io.cpp中有ReadSegmentRGBToDatum函数和ReadSegmentFlowToDatum函数,我们来看一下。

// 5.1 ReadSegmentRGBToDatum函数
bool ReadSegmentRGBToDatum(const string& filename, const int label,
    const vector<int> offsets, const int height, const int width, const int length, Datum* datum, bool is_color,
    const char* name_pattern ){
    cv::Mat cv_img;
    string* datum_string;
    char tmp[30];
    int cv_read_flag = (is_color ? CV_LOAD_IMAGE_COLOR :
        CV_LOAD_IMAGE_GRAYSCALE);
    for (int i = 0; i < offsets.size(); ++i){
      //表示在第i个segment的哪个位置开始取数据
        int offset = offsets[i];
        //每个segment从offset位置起读取length长度个帧
        for (int file_id = 1; file_id < length+1; ++file_id){
            sprintf(tmp, name_pattern, int(file_id+offset));
            string filename_t = filename + "/" + tmp;
            //读入图片
            cv::Mat cv_img_origin = cv::imread(filename_t, cv_read_flag);
            //如果图片不存在,返回false
            if (!cv_img_origin.data){
                LOG(ERROR) << "Could not load file " << filename_t;
                return false;
            }
            //如果需要resize,则resize
            if (height > 0 && width > 0){
                cv::resize(cv_img_origin, cv_img, cv::Size(width, height));
            }else{
                cv_img = cv_img_origin;
            }
            int num_channels = (is_color ? 3 : 1);
            //初始化数据大小
            if (file_id==1 && i==0){
              //一个数据的大小 = 视频segment个数 * 每个segment取length的长度 * 每帧的channel数
                datum->set_channels(num_channels*length*offsets.size());
                datum->set_height(cv_img.rows);
                datum->set_width(cv_img.cols);
                datum->set_label(label);
                datum->clear_data();
                datum->clear_float_data();
                datum_string = datum->mutable_data();
            }
            //写入数据到datum
            if (is_color) {
                for (int c = 0; c < num_channels; ++c) {
                  for (int h = 0; h < cv_img.rows; ++h) {
                    for (int w = 0; w < cv_img.cols; ++w) {
                      datum_string->push_back(
                        static_cast<char>(cv_img.at<cv::Vec3b>(h, w)[c]));
                    }
                  }
                }
              } else {  // Faster than repeatedly testing is_color for each pixel w/i loop
                for (int h = 0; h < cv_img.rows; ++h) {
                  for (int w = 0; w < cv_img.cols; ++w) {
                    datum_string->push_back(
                      static_cast<char>(cv_img.at<uchar>(h, w)));
                    }
                  }
              }
        }
    }
    return true;
}
// 5.2 ReadSegmentFlowToDatum函数
bool ReadSegmentFlowToDatum(const string& filename, const int label,
    const vector<int> offsets, const int height, const int width, const int length, Datum* datum,
    const char* name_pattern ){
    cv::Mat cv_img_x, cv_img_y;
    string* datum_string;
    char tmp[30];
    for (int i = 0; i < offsets.size(); ++i){
      //表示在第i个segment的哪个位置开始取数据
        int offset = offsets[i];
        //每个segment从offset位置起读取length长度个帧
        for (int file_id = 1; file_id < length+1; ++file_id){
          //flow的数据包含x,y两个部分,所以需要同时读取两帧图片stack在一起。channel数为2
            sprintf(tmp,name_pattern, 'x', int(file_id+offset));
            string filename_x = filename + "/" + tmp;
            cv::Mat cv_img_origin_x = cv::imread(filename_x, CV_LOAD_IMAGE_GRAYSCALE);
            sprintf(tmp, name_pattern, 'y', int(file_id+offset));
            string filename_y = filename + "/" + tmp;
            cv::Mat cv_img_origin_y = cv::imread(filename_y, CV_LOAD_IMAGE_GRAYSCALE);
            //如果flow_x或flow_y不存在,则返回false
            if (!cv_img_origin_x.data || !cv_img_origin_y.data){
                LOG(ERROR) << "Could not load file " << filename_x << " or " << filename_y;
                return false;
            }
            //如果需要resize,则resize
            if (height > 0 && width > 0){
                cv::resize(cv_img_origin_x, cv_img_x, cv::Size(width, height));
                cv::resize(cv_img_origin_y, cv_img_y, cv::Size(width, height));
            }else{
                cv_img_x = cv_img_origin_x;
                cv_img_y = cv_img_origin_y;
            }
            if (file_id==1 && i==0){
             //一个数据的大小 = 视频segment个数 * 每个segment取length的长度 * 每帧的channel数(2)
                int num_channels = 2;
                datum->set_channels(num_channels*length*offsets.size());
                datum->set_height(cv_img_x.rows);
                datum->set_width(cv_img_x.cols);
                datum->set_label(label);
                datum->clear_data();
                datum->clear_float_data();
                datum_string = datum->mutable_data();
            }
            //写入数据到datum
            for (int h = 0; h < cv_img_x.rows; ++h){
                for (int w = 0; w < cv_img_x.cols; ++w){
                    datum_string->push_back(static_cast<char>(cv_img_x.at<uchar>(h,w)));
                }
            }
            for (int h = 0; h < cv_img_y.rows; ++h){
                for (int w = 0; w < cv_img_y.cols; ++w){
                    datum_string->push_back(static_cast<char>(cv_img_y.at<uchar>(h,w)));
                }
            }
        }
    }
    return true;
}

猜你喜欢

转载自blog.csdn.net/elaine_bao/article/details/79433733