Caffe source code analysis - InternalThread

Copyright notice: this is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/haluoluo211/article/details/82956589

InternalThread wraps a boost::thread and is mainly used for multi-threaded data fetching: while the solver runs the forward pass, a background thread keeps loading the next batch of data.

class InternalThread {
public:
    InternalThread() : thread_() {}
    virtual ~InternalThread();
    //Caffe's thread local state will be initialized.
    void StartInternalThread();

    /** Will not return until the internal thread has exited. */
    void StopInternalThread();
    bool is_started() const;

protected:
    virtual void InternalThreadEntry() {} // Subclasses implement this method.

    /* Should be tested when running loops to exit when requested. */
    bool must_stop();
private:
    void entry(int device, Caffe::Brew mode, ....);
private:
    shared_ptr<boost::thread> thread_;
};
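For readers without Boost at hand, the same design can be sketched with the C++11 standard library alone. This is an analogue under stated assumptions, not Caffe's code: `std::thread` stands in for `boost::thread`, and a `std::atomic<bool>` flag replaces Boost's interruption mechanism, so `must_stop()` simply polls that flag. The class names `InternalThreadSketch` and `CountingThread` are hypothetical.

```cpp
#include <atomic>
#include <cassert>
#include <memory>
#include <thread>

// Hypothetical std::thread analogue of Caffe's InternalThread (not Caffe code).
class InternalThreadSketch {
public:
    InternalThreadSketch() = default;
    // Like Caffe, stop and join the thread on destruction.
    virtual ~InternalThreadSketch() { StopInternalThread(); }

    void StartInternalThread() {
        assert(!is_started() && "Threads should persist and not be restarted.");
        stop_requested_ = false;
        thread_.reset(new std::thread(&InternalThreadSketch::entry, this));
    }

    // Will not return until the internal thread has exited.
    void StopInternalThread() {
        if (is_started()) {
            stop_requested_ = true;  // assumption: stands in for thread_->interrupt()
            thread_->join();
        }
    }

    bool is_started() const { return thread_ && thread_->joinable(); }

protected:
    // Subclasses put their work loop here.
    virtual void InternalThreadEntry() {}
    // Polled by the work loop; true once a stop has been requested.
    bool must_stop() const { return stop_requested_; }

private:
    void entry() { InternalThreadEntry(); }

    std::atomic<bool> stop_requested_{false};
    std::unique_ptr<std::thread> thread_;
};

// Hypothetical subclass: counts loop iterations until asked to stop.
class CountingThread : public InternalThreadSketch {
public:
    // Stop here, before the derived part is destroyed.
    ~CountingThread() override { StopInternalThread(); }
    long iterations() const { return iterations_; }

protected:
    void InternalThreadEntry() override {
        while (!must_stop()) {
            ++iterations_;
            std::this_thread::yield();
        }
    }

private:
    std::atomic<long> iterations_{0};
};
```

As Caffe's own data layers do, the subclass destructor calls StopInternalThread() itself: by the time the base destructor runs, the derived part is already gone while the thread might still be executing the overridden InternalThreadEntry().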

Now let's look at a few of the core functions:

void InternalThread::StartInternalThread() {
    CHECK(!is_started()) << "Threads should persist and not be restarted.";
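    // Capture this (parent) thread's Caffe state ......
    int device = 0;
    Caffe::Brew mode = Caffe::mode();
    int rand_seed = caffe_rng_rand();
    int solver_count = Caffe::solver_count();
    bool root_solver = Caffe::root_solver();
    // Create the new thread bound to entry(); boost::thread forwards the
    // remaining arguments, so every parameter of entry() must be passed.
    thread_.reset(new boost::thread(&InternalThread::entry, this, device, mode,
                                    rand_seed, solver_count, root_solver));
    //.....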
}
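boost::thread, like std::thread, takes a pointer to member function, the object pointer, and then the member's arguments, which are copied into the new thread. A minimal sketch of that binding with std::thread follows; `Worker`, `entry` and its parameters are hypothetical stand-ins loosely mirroring Caffe's `entry(device, mode, rand_seed, ...)`. Supplying fewer arguments than the member function declares does not compile, which is why every captured value has to be forwarded.

```cpp
#include <cassert>
#include <string>
#include <thread>

// Hypothetical stand-in for InternalThread: the thread body is a member
// function whose parameters carry state captured in the parent thread.
class Worker {
public:
    ~Worker() {
        if (thread_.joinable()) thread_.join();  // never destroy a joinable thread
    }

    void Start(int device, std::string mode, unsigned rand_seed) {
        // Pointer-to-member, object pointer, then the member's arguments;
        // std::thread copies each argument into the new thread.
        thread_ = std::thread(&Worker::entry, this, device, mode, rand_seed);
    }
    void Join() { thread_.join(); }

    int seen_device = -1;
    std::string seen_mode;
    unsigned seen_seed = 0;

private:
    void entry(int device, std::string mode, unsigned rand_seed) {
        seen_device = device;
        seen_mode = mode;
        seen_seed = rand_seed;
    }

    std::thread thread_;
};
```

After Join() returns, the values written by the thread are safely visible to the caller, because join() synchronizes with the thread's completion.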

void InternalThread::entry(int device, Caffe::Brew mode, int rand_seed,
                           int solver_count, bool root_solver) {
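    // Replicate the parent's Caffe state; it is thread-local, so a new
    // thread would otherwise start from the defaults.
    Caffe::set_mode(mode);
    Caffe::set_random_seed(rand_seed);
    Caffe::set_solver_count(solver_count);
    Caffe::set_root_solver(root_solver);
    // The function that actually runs; subclasses implement it as needed.
    InternalThreadEntry();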
}
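Caffe keeps its global settings (mode, random generator, solver count) in a per-thread Caffe singleton, so a freshly created thread sees default values rather than its parent's; that is why entry() calls the setters before anything else. The sketch below reproduces the effect with a `thread_local` variable; `g_mode` and the two helper functions are hypothetical and only stand in for `Caffe::mode()` / `Caffe::set_mode()`.

```cpp
#include <cassert>
#include <thread>

// Hypothetical stand-in for Caffe's per-thread mode (0 = CPU, 1 = GPU).
// Every thread gets its own copy, initialised to the default 0.
thread_local int g_mode = 0;

// Reads the mode from a new thread that does NOT replicate the parent's state.
int mode_seen_without_replication() {
    int seen = -1;
    std::thread t([&seen] { seen = g_mode; });
    t.join();
    return seen;
}

// Mirrors InternalThread::entry(): the parent passes its value in and the
// child sets its own thread_local copy before doing any work.
int mode_seen_with_replication() {
    int parent_mode = g_mode;  // captured in the parent thread
    int seen = -1;
    std::thread t([&seen](int mode) {
        g_mode = mode;         // plays the role of Caffe::set_mode(mode)
        seen = g_mode;
    }, parent_mode);
    t.join();
    return seen;
}
```

With the parent set to GPU mode, the first helper still reports the default CPU value, while the second reports GPU.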

Destructor: waits for the thread to exit:

InternalThread::~InternalThread() {
    StopInternalThread();
}
void InternalThread::StopInternalThread() {
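    if (is_started()) {
        thread_->interrupt();  // only requests a stop; must_stop() then returns true
        thread_->join();       // block until InternalThreadEntry() returns
    }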
}
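thread_->interrupt() does not kill the thread; it only sets a request flag. Boost then throws boost::thread_interrupted at the next interruption point (for example a wait on a boost::condition_variable), and Caffe's must_stop() reads the thread's interruption_requested() flag. The standard library has no interruption, so the sketch below emulates it under that assumption: a stop flag plus a condition-variable wait that throws a hypothetical `Interrupted` exception once the flag is set. This is what lets a thread blocked on an empty queue still exit during StopInternalThread(). `InterruptibleState` and `stop_blocked_worker` are hypothetical names.

```cpp
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <mutex>
#include <thread>

// Hypothetical exception standing in for boost::thread_interrupted.
struct Interrupted {};

// State shared between the worker and the thread that stops it.
struct InterruptibleState {
    std::mutex mutex;
    std::condition_variable cond;
    bool stop_requested = false;  // guarded by mutex
    bool has_work = false;        // guarded by mutex; never set in this demo

    // Emulated interruption point: waits for work, but throws Interrupted
    // as soon as a stop is requested, so a blocked thread can still exit.
    void wait_for_work() {
        std::unique_lock<std::mutex> lock(mutex);
        cond.wait(lock, [this] { return has_work || stop_requested; });
        if (stop_requested) throw Interrupted{};
    }

    // Emulated thread_->interrupt(): set the flag and wake any waiter.
    void interrupt() {
        {
            std::lock_guard<std::mutex> lock(mutex);
            stop_requested = true;
        }
        cond.notify_all();
    }
};

// Starts a worker that blocks on work that never arrives, then stops it the
// way StopInternalThread() does: interrupt, then join.
// Returns true if the worker left through the Interrupted exception.
bool stop_blocked_worker() {
    InterruptibleState state;
    std::atomic<bool> exited_via_interrupt{false};
    std::thread worker([&] {
        try {
            while (true) state.wait_for_work();  // blocks: no work ever arrives
        } catch (const Interrupted&) {
            exited_via_interrupt = true;         // expected on shutdown
        }
    });
    state.interrupt();
    worker.join();  // returns only because the wait was interrupted
    return exited_via_interrupt;
}
```

If interrupt() runs before the worker reaches the wait, the predicate is already true, so the worker throws immediately instead of missing the notification.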

Usage examples:

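template <typename Dtype>
void BasePrefetchingDataLayer<Dtype>::InternalThreadEntry() {
    try {
        while (!must_stop()) {
            // Take an empty batch; blocks while every batch is in use.
            Batch<Dtype>* batch = prefetch_free_.pop();
            load_batch(batch);  // filled in by subclasses such as DataLayer
            // Hand the loaded batch to the consumer (the Forward pass).
            prefetch_full_.push(batch);
        }
    } catch (boost::thread_interrupted&) {
        // Interrupted exception is expected on shutdown.
    }
}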

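void DataReader::Body::InternalThreadEntry() {
    shared_ptr<db::DB> db(db::GetDB(param_.data_param().backend()));
    db->Open(param_.data_param().source(), db::READ);
    shared_ptr<db::Cursor> cursor(db->NewCursor());
    vector<shared_ptr<QueuePair> > qps;
    try {
        // Only the TRAIN phase is shared between solvers.
        int solver_count = param_.phase() == TRAIN ? Caffe::solver_count() : 1;
        // Wait until every solver has registered its QueuePair, reading one
        // item for each so solvers can peek at it during initialization.
        for (int i = 0; i < solver_count; ++i) {
            shared_ptr<QueuePair> qp(new_queue_pairs_.pop());
            read_one(cursor.get(), qp.get());
            qps.push_back(qp);
        }
        // Main loop: hand one record to each solver in turn.
        while (!must_stop()) {
            for (int i = 0; i < solver_count; ++i) {
                read_one(cursor.get(), qps[i].get());
            }
            // No additional readers should have been created meanwhile.
            CHECK_EQ(new_queue_pairs_.size(), 0);
        }
    } catch (boost::thread_interrupted&) {
        // Interrupted exception is expected on shutdown.
    }
}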
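The main loop hands records to the solvers in turn: each pass of the inner for loop reads one record into every solver's QueuePair, so with N solvers solver i receives records i, i+N, i+2N, ..., and the cursor wraps back to the first record when it reaches the end. The single-threaded sketch below shows only that distribution; the int records, the function name `distribute_round_robin`, and plain vectors in place of QueuePair/BlockingQueue are simplifying assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch of DataReader's round-robin main loop: `records`
// stands in for the DB cursor and each inner vector for one solver's queue.
std::vector<std::vector<int>> distribute_round_robin(
        const std::vector<int>& records, int solver_count, int passes) {
    std::vector<std::vector<int>> queues(solver_count);
    if (records.empty()) return queues;
    std::size_t cursor = 0;
    for (int pass = 0; pass < passes; ++pass) {    // one iteration of while (!must_stop())
        for (int i = 0; i < solver_count; ++i) {   // one read_one() per solver
            queues[i].push_back(records[cursor]);
            ++cursor;
            if (cursor == records.size()) cursor = 0;  // wrap around to the first record
        }
    }
    return queues;
}
```

For example, five records shared by two solvers over three passes give solver 0 the records {0, 2, 4} and solver 1 the records {1, 3, 0}.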

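In the while loops above, it is the BlockingQueue that decides whether the current thread blocks: pop() waits while its queue is empty, and push() never blocks. The prefetch thread therefore sleeps in prefetch_free_.pop() once every batch has been loaded or is in use, and wakes up as soon as the consumer returns a used batch to prefetch_free_.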
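A minimal sketch of such a queue together with the free/full batch recycling, assuming C++11 std::mutex/std::condition_variable in place of Caffe's Boost-based BlockingQueue; `BlockingQueueSketch`, `BatchSketch` and `run_prefetch_demo` are hypothetical names. With a fixed pool of batches, the prefetch thread can run at most that many batches ahead of the consumer. Shutdown through must_stop() is left out for brevity: the producer simply stops after a fixed count.

```cpp
#include <cassert>
#include <condition_variable>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Hypothetical minimal blocking queue: push() never blocks, pop() waits
// until an element is available.
template <typename T>
class BlockingQueueSketch {
public:
    void push(const T& t) {
        {
            std::lock_guard<std::mutex> lock(mutex_);
            queue_.push(t);
        }
        cond_.notify_one();
    }
    T pop() {
        std::unique_lock<std::mutex> lock(mutex_);
        cond_.wait(lock, [this] { return !queue_.empty(); });
        T t = queue_.front();
        queue_.pop();
        return t;
    }

private:
    std::queue<T> queue_;
    std::mutex mutex_;
    std::condition_variable cond_;
};

struct BatchSketch { int id = -1; };  // hypothetical stand-in for Batch<Dtype>

// Producer/consumer over a fixed pool of batches (pool_size must be >= 1).
// Returns the batch ids seen by the consumer, in order.
std::vector<int> run_prefetch_demo(int pool_size, int num_batches) {
    std::vector<BatchSketch> pool(pool_size);
    BlockingQueueSketch<BatchSketch*> free_q, full_q;
    for (auto& b : pool) free_q.push(&b);

    // Prefetch thread: take a free batch, "load" it, hand it over.
    std::thread prefetch([&] {
        for (int i = 0; i < num_batches; ++i) {
            BatchSketch* batch = free_q.pop();  // blocks while every batch is in use
            batch->id = i;                      // stands in for load_batch(batch)
            full_q.push(batch);
        }
    });

    // Consumer (the Forward pass): take a loaded batch, use it, recycle it.
    std::vector<int> seen;
    for (int i = 0; i < num_batches; ++i) {
        BatchSketch* batch = full_q.pop();      // blocks until a batch is loaded
        seen.push_back(batch->id);
        free_q.push(batch);                     // return it to the free queue
    }
    prefetch.join();
    return seen;
}
```

Because both queues are FIFO and there is a single producer, the consumer receives the batches in the order they were loaded, even though only pool_size batch objects ever exist.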
