1. thread_pool.h
#pragma once
#include <iostream>
#include <vector>
#include <queue>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <stdexcept>
class ThreadPool
{
public:
ThreadPool(size_t);
template <class F, class... Args>
auto enqueue(F &&f, Args &&...args)
-> std::future<typename std::result_of<F(Args...)>::type>;
~ThreadPool();
private:
std::vector<std::thread> workers;
std::queue<std::function<void()>> tasks;
std::mutex queue_mutex;
std::condition_variable condition;
bool stop;
};
ThreadPool::ThreadPool(size_t threads)
: stop(false)
{
for (size_t i = 0; i < threads; ++i)
workers.emplace_back(
[this]
{
for (;;)
{
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(lock,
[this]
{
return this->stop || !this->tasks.empty(); });
if (this->stop && this->tasks.empty())
return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
}
});
}
template <class F, class... Args>
auto ThreadPool::enqueue(F &&f, Args &&...args)
-> std::future<typename std::result_of<F(Args...)>::type>
{
using return_type = typename std::result_of<F(Args...)>::type;
auto task = std::make_shared<std::packaged_task<return_type()>>(
std::bind(std::forward<F>(f), std::forward<Args>(args)...));
std::future<return_type> res = task->get_future();
{
std::unique_lock<std::mutex> lock(queue_mutex);
if (stop)
throw std::runtime_error("enqueue on stopped ThreadPool");
tasks.emplace([task]()
{
(*task)(); });
}
condition.notify_one();
return res;
}
ThreadPool::~ThreadPool()
{
{
std::unique_lock<std::mutex> lock(queue_mutex);
stop = true;
}
condition.notify_all();
for (std::thread &worker : workers)
worker.join();
}
- 定义了一个名为ThreadPool的类,它包含以下成员:
workers
:一个std::vector
,用于存储线程池中的工作线程
tasks
:一个std::queue
,用于存储待执行的任务
queue_mutex
:一个互斥锁,用于同步任务队列的访问
condition
:一个条件变量,用于在添加新任务时唤醒工作线程
stop
:一个布尔值,表示线程池是否应停止接受新任务并等待所有线程完成后终止。
ThreadPool
类的构造函数接受一个size_t
类型的参数,表示线程池中工作线程的数量。在构造函数中,创建指定数量的工作线程,并在这些线程中执行一个匿名函数,该匿名函数用于从任务队列中获取任务并执行。
enqueue
成员函数模板用于向线程池添加新任务。它接受一个可调用对象f和其参数args,并将任务添加到任务队列中。enqueue
函数返回一个std::future
对象,表示任务的异步结果。
ThreadPool
类的析构函数设置stop
标志为true
,通知所有工作线程停止接受新任务,并在所有任务完成后终止。然后,调用condition.notify_all()
唤醒所有等待的工作线程,最后使用join()
等待所有工作线程完成。
2. main.cpp
#include <iostream>
#include <vector>
#include <random>
#include <chrono>
#include "thread_pool.h"
void multiply(const std::vector<std::vector<int>> &A, const std::vector<std::vector<int>> &B, std::vector<std::vector<int>> &C, size_t row)
{
size_t num_columns = B[0].size();
size_t num_inner = A[0].size();
for (size_t col = 0; col < num_columns; ++col)
{
C[row][col] = 0;
for (size_t inner = 0; inner < num_inner; ++inner)
{
C[row][col] += A[row][inner] * B[inner][col];
}
}
}
int main()
{
const size_t matrix_size = 1000;
std::vector<std::vector<int>> A(matrix_size, std::vector<int>(matrix_size));
std::vector<std::vector<int>> B(matrix_size, std::vector<int>(matrix_size));
std::vector<std::vector<int>> C(matrix_size, std::vector<int>(matrix_size));
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(1, 10);
for (size_t i = 0; i < matrix_size; ++i)
{
for (size_t j = 0; j < matrix_size; ++j)
{
A[i][j] = dis(gen);
B[i][j] = dis(gen);
}
}
ThreadPool pool(20);
auto start = std::chrono::high_resolution_clock::now();
std::vector<std::future<void>> results;
for (size_t i = 0; i < matrix_size; ++i)
{
results.emplace_back(pool.enqueue(multiply, std::cref(A), std::cref(B), std::ref(C), i));
}
for (auto &&result : results)
{
result.get();
}
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> elapsed = end - start;
std::cout << "Time elapsed: " << elapsed.count() << " seconds" << std::endl;
return 0;
}
- 定义了一个名为
multiply
的函数,用于计算矩阵相乘。这个函数接受两个输入矩阵A
和B
,一个输出矩阵C
以及一个行索引。在这个示例中,每个线程将负责计算矩阵C
中的一行。
main
函数中,首先定义了矩阵的大小(matrix_size),并创建了大小为matrix_size
的二维矩阵A、B和C。
- 使用随机数生成器填充矩阵A和B的元素。
- 创建一个包含4个工作线程的线程池
pool
。
- 记录开始时间,然后将每行矩阵相乘的任务添加到线程池中。这里使用了
std::cref
和std::ref
来传递矩阵的引用,以避免不必要的拷贝。
- 使用
results
向量存储每个任务返回的std::future
对象。
- 遍历
results
向量,并调用每个std::future
对象的get()
方法,以确保所有任务都已完成。
- 记录结束时间,计算并输出所用时间。