【TRT】C++多线程

1. thread

1.1 启动线程

#include <stdio.h>
#include <chrono>
#include <thread>

using namespace std;

void worker(int a, std::string& output) {
    
    
    printf("hello thread!\n");
    this_thread::sleep_for(chrono::milliseconds(1000));
    output = "work output";
    printf("worker done.\n");
}

int main() {
    
    

    std::string output;
    thread t(worker, 567, std::ref(output));

    if (t.joinable()) {
    
    
        t.join();
    }
    printf("output: %s\n", output);
    printf("main done.\n");
    return 0;
}

注意点:

  • 如果主线程不join.等待子线程。在析构的时候会报错。
  • 如果join一个没有启动的线程变量,也会报错
  • 在join时判断线程变量是否可以join

1.2 detach

  • 分离线程,取消管理权,使得线程成为野线程。一般不建议使用,主要用在,不需要知道线程何时结束的场景,比如多线程拷贝文件操作时。
  • detach 过后,不需要join。

1.3 参数传递

注意传引用的时候

thread t(worker, 567, std::ref(output));

1.4 类成员函数作为线程函数

class Infer {
    
    
public:
    Infer() {
    
    
        worker_thread_ = thread(&Infer::infer_worker, this);
    }
    ~Infer() {
    
    
        if(worker_thread_.joinable()) {
    
    
            worker_thread_.join();
        }
    }
private:
    void infer_worker() {
    
    
        for (size_t i = 0; i < 100; i++)
        {
    
    
            printf("hello thread!\n");
            this_thread::sleep_for(chrono::milliseconds(2000));
        }
    }

private:
    thread worker_thread_;
};

将类成员函数作为线程函数:

worker_thread_ = thread(&Infer::infer_worker, this);

2. 生产者消费者

2.1 最简单实现

#include <stdio.h>
#include <thread>
#include <mutex>
#include <string>
#include <queue>
#include <chrono>

using namespace std;
queue<string> qjobs_;

void video_capture() {
    
    
    int pic_id = 0;
    while(true) {
    
    
        char name[100];
        sprintf(name, "PIC_%d", pic_id++);
        printf("生产了一个新图片: %s\n", name);
        qjobs_.push(name);
        this_thread::sleep_for(chrono::microseconds(1000));
    }
}

void infer_worker() {
    
    
    while (true)
    {
    
    
        if(!qjobs_.empty()) {
    
    
            auto pic = qjobs_.front();
            qjobs_.pop();
            printf("消费掉一个图片: %s\n", pic);
            this_thread::sleep_for(chrono::microseconds(1000));
        }
        // 强制当前线程交出时间片,防止一直占用cpu资源
        this_thread::yield();
    }
    
}

int main() {
    
    
    thread t1(video_capture);
    thread t2(infer_worker);
    if (t1.joinable()) {
    
    
        t1.join();
    }
    if (t2.joinable()) {
    
    
        t2.join();
    }
    printf("Done!");
    return 0;
}

因为queue 不是线程安全的,所以需要对共享资源加锁.
这里使用加锁的逻辑

#include <mutex>
mutex lock_;
{
    
    
	lock_guard l(lock_);
	/*code*/
}
#include <stdio.h>
#include <thread>
#include <mutex>
#include <string>
#include <queue>
#include <chrono>
#include <mutex>

using namespace std;
mutex lock_;
queue<string> qjobs_;

void video_capture() {
    
    
    int pic_id = 0;
    while(true) {
    
    
        {
    
    
            lock_guard lock(lock_);
            char name[100];
            sprintf(name, "PIC_%d", pic_id++);
            printf("生产了一个新图片: %s\n", name);
            qjobs_.push(name);
        }
        this_thread::sleep_for(chrono::microseconds(1000));
    }
}

void infer_worker() {
    
    
    while (true)
    {
    
    
        if(!qjobs_.empty()) {
    
    
            {
    
    
                lock_guard lock(lock_);
                auto pic = qjobs_.front();
                qjobs_.pop();
                printf("消费掉一个图片: %s\n", pic);
            }
            this_thread::sleep_for(chrono::microseconds(1000));
        }
        // 强制当前线程交出时间片,防止一直占用cpu资源
        this_thread::yield();
    }
    
}

int main() {
    
    
    thread t1(video_capture);
    thread t2(infer_worker);
    if (t1.joinable()) {
    
    
        t1.join();
    }
    if (t2.joinable()) {
    
    
        t2.join();
    }
    printf("Done!");
    return 0;
}

2.2. 队列溢出问题

当生产太快,消费太慢,如何实现溢出限制。当队列满的时候,不生产,等待队列有空间再生产。需求描述:

  • 当队列满的时候,停止生产,生产线程阻塞,并且释放对队列的锁。
  • 消费者消费后,通知生产线程,生产线程重新获得锁,继续生产。
    为了满足上述的需求,这里可以使用c++ 的 condition_variable来实现
// 生产者线程
{
    
    
   unique_lock lock(lock_);
     // cv_.wait() 
     // 当条件满足时, 继续执行, 获得锁的占有权
     // 当条件不满足时,线程阻塞等待,释放锁
     cv_.wait(lock, [&](){
    
    
         return qjobs_.size() < limit_;
     });
     char name[100];
     sprintf(name, "PIC_%d", pic_id++);
     printf("生产了一个新图片: %s qjob.size(): %d\n", name, qjobs_.size());
     qjobs_.push(name);
 }
// 消费者线程
{
    
    
   unique_lock lock(lock_);
    auto pic = qjobs_.front();
    qjobs_.pop();
    printf("消费掉一个图片: %s\n", pic);
    // 消费一个后通知 cv_ 并释放锁
    cv_.notify_one();
}
  • condition_variable 必须配合unique_lock 使用
  • cv_.wait() 当条件满足时 获得锁的占有权,继续执行
  • 当条件不满足时,释放锁,线程阻塞等待
  • cv_.notify_one() 消费完后通知生产者线程继续生产,并释放锁。

2.3. 跨线程结果传输

为了在消费者线程外拿到推理结果,需要借助一些机制来实现跨线程结果传输。
这里利用c++ 标准线程的 futurepromise 实现
构建 Job 结构体

struct Job {
    
    
	shared_ptr<promise<string>> pro;
	string input;
}

qjobs_

struct Job {
    
    
    shared_ptr<promise<string>> pro;
    string input;
};

生产者线程

void video_capture() {
    
    
    int pic_id = 0;
    while(true) {
    
    
        Job job;
        {
    
    
            unique_lock lock(lock_);
            // cv_.wait() 当条件满足时 获得锁的占有权,继续执行
            // 当条件不满足时,释放锁,线程阻塞等待
            cv_.wait(lock, [&](){
    
    
                return qjobs_.size() < limit_;
            });
            char name[100];
            sprintf(name, "PIC_%d", pic_id++);
            printf("生产了一个新图片: %s qjob.size(): %d\n", name, qjobs_.size());
            job.pro.reset(new promise<string>());
            job.input = name;
            qjobs_.push(job);
        }
        // 等待结果
        auto result = job.pro->get_future().get();
        printf("JOB %s -> %s\n", job.input.c_str(), result);
        this_thread::sleep_for(chrono::milliseconds(2000));
    }
}

消费者

void infer_worker() {
    
    
    while (true)
    {
    
    
        if(!qjobs_.empty()) {
    
    
            {
    
    
                unique_lock lock(lock_);
                auto pjob = qjobs_.front();
                qjobs_.pop();
                printf("消费掉一个图片: %s\n", pjob.input);
                // 消费一个后通知 cv_ 并释放锁
                auto result = pjob.input + "--infer";
                // 存放值
                pjob.pro->set_value(result);
                cv_.notify_one();
            }
            this_thread::sleep_for(chrono::milliseconds(4000));
        }
        // 强制当前线程交出时间片,防止一直占用cpu资源
        this_thread::yield();
    }
    
}

3. 封装多线程Infer类

3.1 Infer类实现

#include <string>
using namespace std;

class Infer {
    
    
public:
    bool load_model(const string& file) {
    
    
        context_ = file;
        return true;
    }

    void forward() {
    
    
        if(context_.empty()) {
    
    
            printf("模型没有加载.\n");
            return;
        }
        /*forward logic*/
    }

	bool destory() {
    
    
		context_.clear();
}

private:
    string context_;
};

Infer infer;
infer.forward();

常见的Infer类将模型状态管理和模型推理放在一个Infer 类里面,这样会导致一些问题:

  • Model 没有正常加载时,推理异常判断。
  • Model 被释放了,会导致推理异常。
  • 重复load逻辑
  • 需要手动释放模型。

Note: Infer中必须考虑模型状态相关的异常情况。

扫描二维码关注公众号,回复: 16850304 查看本文章
  • 正常代码中存在大量异常情况的处理
  • 异常逻辑如果没有写好,后者没有考虑到,将会导致程序崩溃。

3.2 使用RAII实现Infer对象的创建

shared_ptr<Infer> create_infer(const string& file) {
    
    
    shared_ptr<Infer> instance(new Infer());
    if(!instance->load_model(file)) {
    
    
        instance.reset();
    }
    return instance;
}
  • 获取Infer 实例,立即加载模型
  • 加载模型失败,表示资源获取失败
  • 加载模型成功,则资源获取成功

调用代码

auto infer = create_infer("a");
if (infer == nullptr) {
    
    
	printf("failed.\n");
	return -1;
}
infer->forward();
  • 避免了外部执行load_model. 保证只调用一次,避免了重复调用的逻辑
  • 获取模型一定初始化成功,因此forward函数,不必判断模型是否加载成功

3.3 接口模式

解决问题:

  • 解决load_model 还能被外部看到的问题,拒绝外面调用load_model
  • 解决成员变量对外可见的问题。对于成员函数是特殊类型,比如说cudaStream_t, 那么使用者必定包含cuda_runtime.h,否则语法解析失败,造成头文件污染。

头文件接口

#ifndef __INFER_H__
#define __INFER_H__
#include <memory>
#include <string>

// 接口类,纯虚类
// 原则是: 只暴露调用者需要的函数,其他一概不暴露
// 比如load_model, 通过RAII 做定义,因此load_model不需要
// 内部如果有启动线程等,start, stop 也不需要暴露,而是初始化的时候就自动启动,都是RAII的定义
class Interface {
    
    
public:
    virtual void forward() = 0;
};

std::shared_ptr<Interface> create_infer(const std::string& file);

#endif // __INFER_H__
  • 头文件只暴露最简单的功能,避免造成头文件污染。有可能用户找不到依赖库的头文件。
  • 隐藏了内部实现。方便算法库的封装
  • 接口模式对编译友好,防止修改头文件,导致引入头文件的文件重新编译。

实现类

#include "infer.h"
using namespace std;

class InferImpl: public Interface{
    
    
public:
    bool load_model(const string& file) {
    
    
        context_ = file;
        return true;
    }

    virtual void forward() override{
    
    
        /*forward logic*/
    }

private:
    string context_;
};

shared_ptr<Interface> create_infer(const string& file) {
    
    
    shared_ptr<InferImpl> instance(new InferImpl());
    if(!instance->load_model(file)) {
    
    
        instance.reset();
    }
    return instance;
}

使用算法接口

#include "infer.h"

auto infer = create_infer("a");
if (infer == nullptr) {
    
    
	printf("Create infer failed!\n";
	return -1;
}
auto result = infer->forward("xxx");
return 0;

4. 多batch生产者消费者类的封装

4.1 封装

class InferImpl: public Interface{
    
    
public:
    bool load_model(const string& file) {
    
    
        // 尽量保证资源哪里分配哪里释放,哪里使用。这样使得程序足够简单
        context_ = file;
        worker_thread_ = thread(&InferImpl::worker, this);
        return true;
    }

    virtual void forward() override{
    
    
        /*forward logic*/
    }
private:
    void worker() {
    
    

    }
private:
    thread worker_thread_;
    string context_;
};
  bool load_model(const string& file) {
    
    
       // 尽量保证资源哪里分配哪里释放,哪里使用。这样使得程序足够简单
       context_ = file;
       worker_thread_ = thread(&InferImpl::worker, this);
       return true;
   }

在这个代码里,context 在load_model中分配,但是在worker线程中使用。这样不够好,应该保证资源在同一个地方分配,统一个地方释放,同一个地方使用。这样做得目的是保证程序足够简单,并且防止资源泄露。

4.2 修改模型加载

修改:

class InferImpl: public Interface{
    
    
public:
    bool load_model(const string& file) {
    
    
        worker_thread_ = thread(&InferImpl::worker, this, file);
        return true;
    }

    virtual void forward() override{
    
    
        /*forward logic*/
    }
private:
    void worker(string f) {
    
    
        context_ = f;
    }
private:
    thread worker_thread_;
    string context_;
};

这样将模型的context_ 加载放在了worker线程里面,解决了模型和推理模块不在同一个线程的问题。但是这个代码也有如下问题:

  • 不知道最后模型加载的状态
    这里使用 futurepromise 拉解决获取模型加载状态获取的问题。
class InferImpl: public Interface{
    
    
public:
    bool load_model(const string& file) {
    
    
        // 尽量保证资源哪里分配哪里释放,哪里使用。这样使得程序足够简单
        promise<bool> pro;
        worker_thread_ = thread(&InferImpl::worker, this, file, std::ref(pro));
        return pro.get_future().get();
    }

    virtual void forward() override{
    
    
        /*forward logic*/
    }
private:
    void worker(string f, promise<bool>& pro) {
    
    
        string context = f;
        pro.set_value(context_.empty());
        while(true) {
    
    
			/*customer*/
		}
    }
private:
    thread worker_thread_;
    //string context_;
};

由于contex_只在worker内部使用,所以也就不需要context_的成员变量了,context_的生命周期只在 worker() 线程内。

4.3 任务提交与结果返回

4.3.1 直接返回结果

 virtual string forward(string& input) override{
    
    
     /*forward logic
         往队列丢任务
     */
     Job job;
     job.pro.reset(new promise<string>());
     job.input = input;
     qjobs_.push(job);

     /*如何返回结果?*/
     return job.pro->get_future().get();
 }

外部调用代码:

string result1 = infer->forward(input1);
string result2 = infer->forward(input2);
string result3 = infer->forward(input3);

4.3.2 使用shared_future 返回结果

上述方式可以解决结果获取的问题,但是每次提交任务后,必须等待任务处理过后才能提价下一个任务,这本质上还是串行的所以可以使用 shared_future 直接返回 一个 future 让调用方决定什么时候去等待结果。代码如下:

 virtual shared_future<string> forward(string& input) override{
    
    
    Job job;
	job.pro.reset(new std::promise<std::string>);
	job.input = input;
	//std::this_thread::sleep_for(std::chrono::milliseconds(100));

	std::shared_future<std::string> fut = job.pro->get_future();
	{
    
    
		std::lock_guard<std::mutex> l(lock_);
		qjobs_.emplace(job);
	}
	// 被动通知,有任务发送给worker
	cv_.notify_one();
	return fut;
 }

调用代码:

/*提交任务*/
auto rst_future1 = infer->forward(input1);
auto rst_future2 = infer->forward(input2);
auto rst_future3 = infer->forward(input3);

/*在需要的时候获取结果*/
rst_future1.get();
rst_future2.get();
rst_future3.get()

4.4 多batch 推理

void worker(std::string file, std::promise<bool>& pro) {
    
    

		//std::string file = "aaa";
		// worker 内实现,模型的加载、使用、释放
		std::string context_ = file;
		if (context_.empty()) {
    
    
			pro.set_value(false);
			return;
		}
		else {
    
    
			is_running_ = true;
			pro.set_value(true);
		}
		int max_batch_size = 5;
		std::vector<Job> jobs;
		int batch_ids = 0;
		while (true) {
    
    
			//  在队列取任务并执行的过程
			{
    
    
				std::unique_lock<std::mutex> l(lock_);
				cv_.wait(l, [&]() {
    
    
					return !qjobs_.empty();
				});
				while (jobs.size() < max_batch_size && !qjobs_.empty()) {
    
    
					jobs.emplace_back(std::move(qjobs_.front()));
					qjobs_.pop();
				}

				// batch process
				for (auto& job : jobs) {
    
    
					char buff[100];
					sprintf_s(buff, "%s ---processed[%d]", job.input.c_str(), batch_ids);
					job.pro->set_value(buff);
				}
				std::this_thread::sleep_for(std::chrono::milliseconds(1500));
				batch_ids++;
				jobs.clear();

			} // end unique_lock
		}
		printf("[%s] Infer worker done. \n", file.c_str());
	}

一次从队列中取出多个任务,组成一个batch进行推理。

4.5 程序推出机制

在上面的程序中,并没有程序推出的机制即:
当进程终止时,worker()线程仍然在运行,这会导致程序推出而子线程没有推出,导致报错。这里设置一个标示线程是否退出的变量 ** atomic is_running_**。

std::atomic<bool>			is_running_{
    
    false};
void stop() {
    
    
	if (is_running_) {
    
    
		is_running_ = false;
		// 退出worker线程的等待
		cv_.notify_one();
	}

	//  保证推理线程结束,防止成为孤儿线程
	if (this->worker_thread_.joinable()) {
    
    
		worker_thread_.join();
	}
}

析构时或者在需要停止推理时调用

virtual ~InferImpl() {
    
    
	stop();
}

消费者逻辑修改

void worker(std::string file, std::promise<bool>& pro) {

	//std::string file = "aaa";
	// worker 内实现,模型的加载、使用、释放
	std::string context_ = file;
	if (context_.empty()) {
		pro.set_value(false);
		return;
	}
	else {
		is_running_ = true;
		pro.set_value(true);
	}
	int max_batch_size = 5;
	std::vector<Job> jobs;
	int batch_ids = 0;
	while (is_running_) {
		//  在队列取任务并执行的过程
		{
			std::unique_lock<std::mutex> l(lock_);
			cv_.wait(l, [&]() {
				return !is_running_ || !qjobs_.empty();
			});

			if (!is_running_) break;
			while (jobs.size() < max_batch_size && !qjobs_.empty()) {
				jobs.emplace_back(std::move(qjobs_.front()));
				qjobs_.pop();
			}

			// batch process
			for (auto& job : jobs) {
				char buff[100];
				sprintf_s(buff, "%s ---processed[%d]", job.input.c_str(), batch_ids);
				job.pro->set_value(buff);
			}
			std::this_thread::sleep_for(std::chrono::milliseconds(1500));
			batch_ids++;
			jobs.clear();

		} // end unique_lock

		
	}
	printf("[%s] Infer worker done. \n", file.c_str());

}
  • while(true) 变成 while(is_running_)
  • 信号量等待退出条件变成如下,即当不再运行时,结束等待,并且退出推理线程。
cv_.wait(l, [&]() {
    
    
	return !is_running_ || !qjobs_.empty();
});

if (!is_running_) break;

5. 生产者上限控制

问题提出:生产频率太高,commit 频率太高,而消费频率太低。导致内存占用太大,程序无法长时间运行。缺少队列上限限制的机制。使用独占分配器解决:

  • tensor 复用

  • 队列上限限制

  • 向tensor_alloctor_ 申请一个 tensor

  • 预先分配固定数量的tensor, 比如10个

  • 如果申请的时候,有空闲的tensor没有被分配出去,则把这个空闲给申请者

  • 如果申请的时候,没有空闲的tensor,此时,让申请者等待。

  • 如果使用者使用完毕了,通知tensor_allocator_, 告诉他这个tensor不用了,可以分配给别人了。

  • 这样实现了tensor复用,并且控制了队列的上限。

内存复用类,申请固定数量的内存池,完成内存复用。当内存池中的内存块消耗完后,则不能申请新的内存,需要等待,若等待超时,则返回空。

#ifndef __MONOPOLY_ALLOCATOR_H__
#define __MONOPOLY_ALLOCATOR_H__

#include <condition_variable>
#include <vector>
#include <mutex>
#include <memory>

template<class _ItemType>
class MonopolyAllocator {
    
    
public:
    class MonopolyData {
    
    
    public:
        std::shared_ptr<_ItemType>& data() {
    
     return data_; }
        void release() {
    
     manager_->release_one(this); }

    private:
        MonopolyData(MonopolyAllocator* pmanager) {
    
     manager_ = pmanager; }

    private:
        friend class MonopolyAllocator;
        MonopolyAllocator* manager_ = nullptr;
        std::shared_ptr<_ItemType> data_;
        bool available_ = true;
    };
    typedef std::shared_ptr<MonopolyData> MonopolyDataPointer;

    MonopolyAllocator(int size) {
    
    
        capacity_ = size;
        num_available_ = size;
        datas_.resize(size);

        for (int i = 0; i < size; ++i)
            datas_[i] = std::shared_ptr<MonopolyData>(new MonopolyData(this));
    }

    virtual ~MonopolyAllocator() {
    
    
        run_ = false;
        cv_.notify_all();

        std::unique_lock<std::mutex> l(lock_);
        cv_exit_.wait(l, [&]() {
    
    
            return num_wait_thread_ == 0;
            });
    }

    MonopolyDataPointer query(int timeout = 10000) {
    
    

        std::unique_lock<std::mutex> l(lock_);
        if (!run_) return nullptr;

        if (num_available_ == 0) {
    
    
            num_wait_thread_++;

            auto state = cv_.wait_for(l, std::chrono::milliseconds(timeout), [&]() {
    
    
                return num_available_ > 0 || !run_;
                });

            num_wait_thread_--;
            cv_exit_.notify_one();

            // timeout, no available, exit program
            if (!state || num_available_ == 0 || !run_)
                return nullptr;
        }

        auto item = std::find_if(datas_.begin(), datas_.end(), [](MonopolyDataPointer& item) {
    
    return item->available_; });
        if (item == datas_.end())
            return nullptr;

        (*item)->available_ = false;
        num_available_--;
        return *item;
    }

    int num_available() {
    
    
        return num_available_;
    }

    int capacity() {
    
    
        return capacity_;
    }

private:
    void release_one(MonopolyData* prq) {
    
    
        std::unique_lock<std::mutex> l(lock_);
        if (!prq->available_) {
    
    
            prq->available_ = true;
            num_available_++;
            cv_.notify_one();
        }
    }

private:
    std::mutex lock_;
    std::condition_variable cv_;
    std::condition_variable cv_exit_;
    std::vector<MonopolyDataPointer> datas_;
    int capacity_ = 0;
    volatile int num_available_ = 0;
    volatile int num_wait_thread_ = 0;
    volatile bool run_ = true;
};

#endif // __MONOPOLY_ALLOCATOR_H__

完整代码参考

https://github.com/JilinLi4/trt_infer/tree/master/test/ThreadTest

猜你喜欢

转载自blog.csdn.net/qq_30340349/article/details/130837142