以下代码由奇妙之二进制和 ChatGPT 共同创作。
/**
* @file thread_safe_list.h
* @brief A thread-safe implementation of std::list.
*/
#ifndef THREAD_SAFE_LIST_H_
#define THREAD_SAFE_LIST_H_
#include <algorithm>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <list>
#include <mutex>
#include <optional>
/**
 * @brief A thread-safe wrapper around std::list.
 *
 * Every public member function holds the internal mutex for its whole
 * duration, so each individual call is atomic with respect to other calls.
 * A sequence of calls (for example `if (!empty()) pop_front();`) is not.
 * pop_front()/pop_back() can optionally wait, up to a timeout, for another
 * thread to insert an element.
 *
 * @tparam T The type of elements in the list.
 */
template<typename T>
class ThreadSafeList final {
public:
    /**
     * @brief Default constructor; creates an empty list.
     */
    ThreadSafeList() = default;

    /**
     * @brief Copy constructor.
     * @param other The list to copy from; its mutex is held while it is read.
     */
    ThreadSafeList(const ThreadSafeList& other) {
        // *this is still under construction and not visible to other threads,
        // so only the source needs to be locked.
        std::lock_guard<std::mutex> lock(other.mutex_);
        list_ = other.list_;
    }

    /**
     * @brief Move constructor.
     * @param other The list to move from; its mutex is held during the move.
     */
    ThreadSafeList(ThreadSafeList&& other) {
        std::lock_guard<std::mutex> lock(other.mutex_);
        list_ = std::move(other.list_);
    }

    /**
     * @brief Copy assignment operator.
     *
     * Both mutexes are acquired together by std::scoped_lock, which avoids
     * deadlock when two threads assign two lists to each other concurrently.
     * If the new contents are non-empty, threads blocked in
     * pop_front()/pop_back() on *this are woken.
     *
     * @param other The list to copy from.
     * @return Reference to *this.
     */
    ThreadSafeList& operator=(const ThreadSafeList& other) {
        if (this != &other) {
            std::scoped_lock lock(mutex_, other.mutex_);
            list_ = other.list_;
            if (!list_.empty()) {
                cv_.notify_all();
            }
        }
        return *this;
    }

    /**
     * @brief Move assignment operator.
     *
     * Locks both lists like the copy assignment. If the new contents are
     * non-empty, threads blocked in pop_front()/pop_back() on *this are woken.
     *
     * @param other The list to move from.
     * @return Reference to *this.
     */
    ThreadSafeList& operator=(ThreadSafeList&& other) {
        if (this != &other) {
            std::scoped_lock lock(mutex_, other.mutex_);
            list_ = std::move(other.list_);
            if (!list_.empty()) {
                cv_.notify_all();
            }
        }
        return *this;
    }

    /**
     * @brief Element-wise equality of two lists.
     * @param other The list to compare against (may be *this).
     * @return True if both lists hold equal elements in the same order.
     */
    bool operator==(const ThreadSafeList& other) const {
        if (this == &other) {
            // scoped_lock(mutex_, mutex_) would lock the same mutex twice;
            // comparing a list with itself needs only one lock.
            std::lock_guard<std::mutex> lock(mutex_);
            return list_ == other.list_;
        }
        std::scoped_lock lock(mutex_, other.mutex_);
        return list_ == other.list_;
    }

    /**
     * @brief Negation of operator==.
     * @param other The list to compare against.
     * @return True if the lists differ.
     */
    bool operator!=(const ThreadSafeList& other) const {
        return !(*this == other);
    }

    /**
     * @brief Constructs an element in place at the front and wakes one waiter.
     * @tparam Args The types of arguments to construct the element.
     * @param args The arguments to construct the element.
     */
    template<typename... Args>
    void emplace_front(Args&&... args) {
        std::lock_guard<std::mutex> lock(mutex_);
        list_.emplace_front(std::forward<Args>(args)...);
        // Notify on every insertion. Notifying only on the empty -> non-empty
        // transition loses wake-ups when several consumers are waiting.
        cv_.notify_one();
    }

    /**
     * @brief Constructs an element in place at the back and wakes one waiter.
     * @tparam Args The types of arguments to construct the element.
     * @param args The arguments to construct the element.
     */
    template<typename... Args>
    void emplace_back(Args&&... args) {
        std::lock_guard<std::mutex> lock(mutex_);
        list_.emplace_back(std::forward<Args>(args)...);
        cv_.notify_one();
    }

    /**
     * @brief Checks whether the list contains an element equal to value.
     *
     * Linear scan using T's operator==. Not noexcept, because that operator
     * is user-supplied and may throw.
     *
     * @param value The value to search for.
     * @return True if an equal element exists, false otherwise.
     */
    bool contains(const T& value) const {
        std::lock_guard<std::mutex> lock(mutex_);
        return std::find(list_.begin(), list_.end(), value) != list_.end();
    }

    /**
     * @brief Finds the first element e for which compare(e, value) is true.
     *
     * WARNING: the returned reference points into the list, but the mutex is
     * released before this function returns. It dangles if another thread
     * removes that element (pop_*, clear, unique, assignment). Copy the value
     * out promptly, or make sure no concurrent removal can happen.
     *
     * @tparam Compare Binary predicate type, called as compare(element, value).
     * @param value The value to search for.
     * @param compare The comparison function (defaults to std::equal_to<T>).
     * @return A reference to the found element, or std::nullopt if not found.
     */
    template<typename Compare = std::equal_to<T>>
    auto find(const T& value, Compare compare = Compare{}) const
        -> std::optional<std::reference_wrapper<const T>> {
        std::lock_guard<std::mutex> lock(mutex_);
        const auto it = std::find_if(list_.begin(), list_.end(),
                                     [&value, &compare](const T& element) -> bool {
                                         return compare(element, value);
                                     });
        if (it == list_.end()) {
            return std::nullopt;
        }
        return std::cref(*it);
    }

    /**
     * @brief Removes all elements.
     */
    void clear() noexcept {
        std::lock_guard<std::mutex> lock(mutex_);
        list_.clear();
    }

    /**
     * @brief Gets the number of elements.
     * @return The size of the list at the moment of the call.
     */
    std::size_t size() const noexcept {
        std::lock_guard<std::mutex> lock(mutex_);
        return list_.size();
    }

    /**
     * @brief Checks if the list is empty.
     * @return True if the list is empty at the moment of the call.
     */
    bool empty() const noexcept {
        std::lock_guard<std::mutex> lock(mutex_);
        return list_.empty();
    }

    /**
     * @brief Copies an element to the back of the list and wakes one waiter.
     * @param value The value to push.
     */
    void push_back(const T& value) {
        std::lock_guard<std::mutex> lock(mutex_);
        list_.push_back(value);
        cv_.notify_one();
    }

    /**
     * @brief Moves an element to the back of the list and wakes one waiter.
     * @param value The value to push.
     */
    void push_back(T&& value) {
        std::lock_guard<std::mutex> lock(mutex_);
        list_.push_back(std::move(value));
        cv_.notify_one();
    }

    /**
     * @brief Removes and returns the front element.
     * @param timeout If positive, the maximum time to wait for an element to
     *        become available; if zero (the default), do not wait.
     * @return The popped element, or std::nullopt if the list is (still)
     *         empty once the timeout has elapsed.
     */
    std::optional<T> pop_front(const std::chrono::milliseconds timeout = std::chrono::milliseconds(0)) {
        std::unique_lock<std::mutex> lock(mutex_);
        if (!wait_for_element(lock, timeout)) {
            return std::nullopt;
        }
        std::optional<T> value(std::move(list_.front()));
        list_.pop_front();
        return value;
    }

    /**
     * @brief Removes and returns the back element.
     * @param timeout If positive, the maximum time to wait for an element to
     *        become available; if zero (the default), do not wait.
     * @return The popped element, or std::nullopt if the list is (still)
     *         empty once the timeout has elapsed.
     */
    std::optional<T> pop_back(const std::chrono::milliseconds timeout = std::chrono::milliseconds(0)) {
        std::unique_lock<std::mutex> lock(mutex_);
        if (!wait_for_element(lock, timeout)) {
            return std::nullopt;
        }
        std::optional<T> value(std::move(list_.back()));
        list_.pop_back();
        return value;
    }

    /**
     * @brief Calls callback on each element, front to back, under the lock.
     *
     * The mutex is held for the whole traversal and is not recursive, so the
     * callback must not call any member function of this same list
     * (that would deadlock).
     *
     * @param callback The function to apply to each element.
     */
    void traverse(const std::function<void(const T&)>& callback) const {
        std::lock_guard<std::mutex> lock(mutex_);
        for (const T& element : list_) {
            callback(element);
        }
    }

    /**
     * @brief Sorts the list (with operator<) and then removes duplicates.
     *
     * Unlike std::list::unique, this also removes non-adjacent duplicates,
     * because sorting makes equal elements adjacent first.
     */
    void unique() {
        std::lock_guard<std::mutex> lock(mutex_);
        list_.sort();
        list_.unique();
    }

    /**
     * @brief Sorts the list (with operator<) and then removes consecutive
     *        elements that pred reports as equivalent.
     *
     * Only elements that end up adjacent after sorting by operator< can be
     * merged, so pred should be consistent with that ordering.
     *
     * @tparam BinaryPredicate Callable as pred(const T&, const T&) -> bool.
     * @param pred Returns true if two adjacent elements are duplicates.
     */
    template<typename BinaryPredicate>
    void unique(BinaryPredicate pred) {
        std::lock_guard<std::mutex> lock(mutex_);
        list_.sort();
        list_.unique(pred);
    }

    /**
     * @brief Sorts the list with the given strict weak ordering.
     * @tparam Compare Comparison type (defaults to std::less<T>).
     * @param compare The ordering to sort by.
     */
    template<typename Compare = std::less<T>>
    void sort(Compare compare = Compare{}) {
        std::lock_guard<std::mutex> lock(mutex_);
        list_.sort(compare);
    }

private:
    /**
     * @brief Waits (only if timeout is positive) until the list is non-empty.
     * @param lock A lock that currently owns mutex_.
     * @param timeout Maximum wait; zero or negative means "do not wait".
     * @return True if at least one element is available.
     */
    bool wait_for_element(std::unique_lock<std::mutex>& lock, const std::chrono::milliseconds timeout) {
        if (timeout > std::chrono::milliseconds(0)) {
            // The predicate form re-checks after spurious wake-ups.
            return cv_.wait_for(lock, timeout, [this] { return !list_.empty(); });
        }
        return !list_.empty();
    }

    std::list<T> list_;          ///< The underlying std::list.
    mutable std::mutex mutex_;   ///< Guards list_ (mutable so const members can lock).
    std::condition_variable cv_; ///< Signalled when elements are inserted.
};
#endif // THREAD_SAFE_LIST_H_
接口预览:
以下是 `ThreadSafeList` 类的接口列表:
- `ThreadSafeList()`: 默认构造函数。
- `ThreadSafeList(const ThreadSafeList& other)`: 拷贝构造函数。
- `ThreadSafeList(ThreadSafeList&& other)`: 移动构造函数。
- `ThreadSafeList& operator=(const ThreadSafeList& other)`: 拷贝赋值运算符。
- `ThreadSafeList<T>& operator=(ThreadSafeList<T>&& other)`: 移动赋值运算符。
- `template<typename... Args> void emplace_front(Args&&... args)`: 在列表头部插入一个元素。
- `template<typename... Args> void emplace_back(Args&&... args)`: 在列表尾部插入一个元素。
- `bool contains(const T& value) const noexcept`: 判断列表中是否包含某个元素。
- `template<typename Compare = std::equal_to<T>> auto find(const T& value, Compare compare = Compare{}) const -> std::optional<std::reference_wrapper<const T>>`: 查找列表中与给定元素相等的元素。注意：返回的引用在函数返回、锁释放之后才被使用，调用方需保证该元素不会被其他线程同时删除。
- `void clear() noexcept`: 清空列表。
- `size_t size() const noexcept`: 返回列表的元素数量。
- `bool empty() const noexcept`: 判断列表是否为空。
- `void push_back(const T& value)`: 在列表尾部插入一个元素。
- `void push_back(T&& value)`: 在列表尾部插入一个元素。
- `std::optional<T> pop_front(const std::chrono::milliseconds timeout = std::chrono::milliseconds(0))`: 从列表头部弹出一个元素。
- `std::optional<T> pop_back(const std::chrono::milliseconds timeout = std::chrono::milliseconds(0))`: 从列表尾部弹出一个元素。
- `void traverse(const std::function<void(const T&)>& callback) const`: 遍历列表并对每个元素执行给定的回调函数。
- `void unique()`: 先用 operator< 排序，再移除列表中的重复元素。
- `template<typename UnaryPredicate> void unique(UnaryPredicate pred)`: 先用 operator< 排序，再用二元谓词 `pred` 判断相邻元素是否重复，并移除重复项。
- `template<typename Compare = std::less<T>> void sort(Compare compare = Compare{})`: 对列表排序。
gtest单元测试:
由于底层容器是 std::list，基本接口的单线程测试用例较多且问题不大，这里就不贴了，主要针对多线程场景进行测试。
#include <thread>
#include "gtest/gtest.h"
#include "thread_safe_list.h"
// Single producer / single consumer: elements must come out of pop_front()
// in exactly the order push_back() inserted them.
TEST(ThreadSafeList, MultiThreadPushPopFront) {
    constexpr int kCount = 10000;
    ThreadSafeList<int> list;
    std::thread producer([&list]() {
        for (int i = 0; i < kCount; ++i) {
            list.push_back(i);
        }
    });
    std::thread consumer([&list]() {
        // Pop exactly kCount elements instead of stopping at the first timeout.
        // A generous per-element timeout keeps the test stable on a loaded
        // machine, and there is no fixed 500 ms wait at the end.
        for (int expected = 0; expected < kCount; ++expected) {
            const auto value = list.pop_front(std::chrono::seconds(5));
            ASSERT_TRUE(value.has_value()) << "timed out waiting for element " << expected;
            EXPECT_EQ(*value, expected);
        }
    });
    producer.join();
    consumer.join();
    EXPECT_TRUE(list.empty());
}
// Single producer / single consumer popping from the back: pop order depends
// on thread interleaving, but the multiset of popped values must equal the
// multiset of pushed values.
TEST(ThreadSafeList, MultiThreadPushPopBack) {
    constexpr int kCount = 10000;
    ThreadSafeList<int> list;
    ThreadSafeList<int> list_push;
    ThreadSafeList<int> list_pop;
    // Capture list_push too: the original lambda captured only &list and
    // did not compile.
    std::thread producer([&list, &list_push]() {
        for (int i = 0; i < kCount; ++i) {
            list.push_back(i);
            list_push.push_back(i);
        }
    });
    std::thread consumer([&list, &list_pop]() {
        for (int popped = 0; popped < kCount; ++popped) {
            const auto value = list.pop_back(std::chrono::seconds(5));
            ASSERT_TRUE(value.has_value()) << "timed out after " << popped << " elements";
            list_pop.push_back(*value);
        }
    });
    producer.join();
    consumer.join();
    // Compare only after both threads are done, so list_push is complete.
    list_pop.sort();
    EXPECT_EQ(list_pop, list_push);
    EXPECT_TRUE(list.empty());
}
// Test entry point: let gtest strip its own flags, then run every
// registered test and return its aggregate status as the exit code.
int main(int argc, char** argv) {
    ::testing::InitGoogleTest(&argc, argv);
    const int status = RUN_ALL_TESTS();
    return status;
}
测试报告:
$ ./a.out
[==========] Running 2 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 2 tests from ThreadSafeList
[ RUN ] ThreadSafeList.MultiThreadPushPopFront
[ OK ] ThreadSafeList.MultiThreadPushPopFront (38 ms)
[ RUN ] ThreadSafeList.MultiThreadPushPopBack
[ OK ] ThreadSafeList.MultiThreadPushPopBack (19 ms)
[----------] 2 tests from ThreadSafeList (58 ms total)
[----------] Global test environment tear-down
[==========] 2 tests from 1 test suite ran. (59 ms total)
[ PASSED ] 2 tests.
优化的细节:
- 某些场景使用std::scoped_lock避免死锁。
- 出于线程安全考虑，不对外提供迭代器以及依赖迭代器的函数（如 begin、end、insert、erase）；find 改为返回用 std::optional 包装的引用，而不是迭代器。与 std::list 一样，不提供 operator[] 重载。
- 插入提供右值版本
- 插入元素时通过条件变量唤醒在 pop_front/pop_back 上等待的线程；删除操作不会发出通知。注意：如果只在列表由空变为非空时才通知，当多个线程同时等待时，部分等待者可能错过唤醒，只能等到超时才返回。
- 在可能返回空对象的场景返回std::optional
待优化项:
- 使用 std::shared_mutex 代替 std::mutex,支持多线程读取操作,提升并发性能,但是这个需要考虑使用场景,如果读取频繁,写入较少,则可以考虑优化,不然没有必要,可能还会适得其反。