C++语言中LRU算法的一种实现

1.什么是LRU

以下内容来自百度百科:

      LRU是Least Recently Used的缩写,即最近最少使用,是一种常用的页面置换算法,选择最近最久未使用的页面予以淘汰。该算法赋予每个页面一个访问字段,用来记录一个页面自上次被访问以来所经历的时间 t,当须淘汰一个页面时,选择现有页面中其 t 值最大的,即最近最少使用的页面予以淘汰。

     在实际项目开发中,为了提高程序运行速度,以空间换时间,常用的数据经常需要缓存,而LRU常被用作一种缓存更新策略。本文就介绍C++语言下LRU的一种实现,希望对大家有帮助和启发。

2.模板 + 无序字典 + 双向链表实现LRU

数据缓存可能会存储各种各样的数据,如果为每种数据类型都写一个缓存算法,那代码量有点大,而且对后期维护也是一种很大的成本,在C++语言中模板编程可以很好地解决这个问题。缓存数据可多可少,支持动态变化,可以使用std::list结构存储缓存数据,缓存数据要支持快速查找,因此可以使用无序字典来快速访问std::list,字典key为缓存数据key,value为std::list的迭代器。

3.LRU类定义

lrucache.h

#ifndef UTIL_LRU_CACHE_H
#define UTIL_LRU_CACHE_H

#include <cstdint>
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>

/*
//std::list方法splice使用举例
list<string> list1, list2;
list1.push_back("a");
list1.push_back("b");
list1.push_back("c"); // list1: "a", "b", "c"
list<string>::iterator iterator1 = list1.begin();
iterator1++; // points to "b"
list2.push_back("x");
list2.push_back("y");
list2.push_back("z"); // list2: "x", "y", "z"
list<string>::iterator iterator2 = list2.begin();
iterator2++; // points to "y"
list1.splice(iterator1, list2, iterator2, list2.end());
// list1: "a", "y", "z", "b", "c"
// list2: "x"
*/

namespace common {
namespace util {

/// Thread-safe LRU (least-recently-used) cache bounded by element count.
///
/// Entries are kept in a std::list ordered from most recently used (front) to
/// least recently used (back). A std::unordered_map maps each key to the list
/// iterator of its entry, so lookup, promotion (splice) and eviction do not
/// scan the list. Every public member function takes m_mutex.
///
/// \tparam Key   key type; must be hashable and equality-comparable so it can
///               be used as a std::unordered_map key
/// \tparam Value payload type; entries hold std::shared_ptr<Value>
template <typename Key, typename Value>
class LRUCache
{
public:
    using key_type          = Key;
    using value_type        = Value;
    using value_ptr_type    = std::shared_ptr<Value>;
    using list_value        = std::pair<Key, value_ptr_type>;
    using iterator          = typename std::list<list_value>::iterator;
    using size_type         = std::uint64_t;
    using rate_type         = double;

    /// Cache statistics counters.
    struct Stats {
        size_type m_get_cnt; // total number of get() calls
        size_type m_hit_cnt; // number of get() calls that found the key
        size_type m_set_cnt; // total number of push() calls
    };

public:
    /// Callback invoked with the key and value of an entry evicted by push().
    /// It is invoked while the cache mutex is held, so it must not call back
    /// into the same cache (std::mutex is not recursive).
    using DropHandler =
        std::function<void (const key_type&, const value_ptr_type&)>;

private:
    mutable std::mutex                      m_mutex;
    std::list<list_value>                   m_list;       // front = most recently used
    std::unordered_map<key_type, iterator>  m_hash_table; // key -> entry in m_list

    size_type                       m_max_size; // maximum number of entries kept

    DropHandler                     m_drop_handler; // called on eviction when set

    Stats                           m_stats;

public:
    LRUCache() = delete;

    ///
    /// Constructor
    /// \param [max_size] maximum number of entries the cache keeps
    ///
    explicit LRUCache(size_type max_size);

public:
    ///
    /// push
    /// \brief Inserts a key/value pair, or replaces the value of an existing key.
    /// \param [in]: key, value
    /// \warning The entry becomes the most recently used one. If the element
    ///          count then exceeds the limit, the least recently used entry is
    ///          evicted (and passed to the drop handler when one is set).
    ///
    void push(const key_type& key, const value_ptr_type& value);

    ///
    /// get
    /// \brief Looks up the value stored for key.
    /// \param [in]: key, [out]: value (written only when the key is found)
    /// \return bool [true]: the key is cached [false]: the key is not cached
    /// \warning A successful lookup moves the entry to the front of the list
    ///          (most recently used position).
    ///
    bool get(const key_type& key, value_ptr_type& value);

    ///
    /// exists
    /// \brief Tells whether an entry for key is cached.
    /// \param [in]: key
    /// \return bool [true]: the key is cached [false]: the key is not cached
    /// \warning Does not change the entry order or the statistics.
    ///
    bool exists(const key_type& key) const
    {
        std::lock_guard<std::mutex> lock (m_mutex);
        return m_hash_table.find(key) != m_hash_table.end();
    }

    /// check drop handler
    /// \brief Tells whether an eviction callback is currently set.
    bool check_drop_handler() const
    {
        std::lock_guard<std::mutex> lock (m_mutex);
        return static_cast<bool>(m_drop_handler);
    }

    /// get drop handler
    /// \brief Returns a copy of the current eviction callback (may be empty).
    DropHandler get_drop_handler() const
    {
        std::lock_guard<std::mutex> lock (m_mutex);
        return m_drop_handler;
    }

    /// set drop handler
    /// \brief Replaces the eviction callback; an empty handler disables it.
    void set_drop_handler(const DropHandler& handler)
    {
        std::lock_guard<std::mutex> lock (m_mutex);
        m_drop_handler = handler;
    }

    /// get_max_capacity
    /// \brief Returns the maximum element count multiplied by
    ///        sizeof(value_type). Keys and container overhead are not counted.
    size_type get_max_capacity() const
    {
        std::lock_guard<std::mutex> lock (m_mutex);
        return m_max_size * sizeof(value_type);
    }

    /// get_current_capacity
    /// \brief Returns the current element count multiplied by
    ///        sizeof(value_type). Keys and container overhead are not counted.
    size_type get_current_capacity() const
    {
        std::lock_guard<std::mutex> lock (m_mutex);
        return _get_cache_size_in_lock() * sizeof(value_type);
    }

    /// get_current_count
    /// \brief Returns the number of entries currently cached.
    size_type get_current_count() const
    {
        std::lock_guard<std::mutex> lock (m_mutex);
        return _get_cache_size_in_lock();
    }

    ///
    /// get_hit_rate
    /// \brief Returns hits / get() calls since construction or the last
    ///        reset_stats().
    /// \return rate_type(double); 0 when get() has not been called yet
    ///         (avoids the 0/0 = NaN result of a plain division)
    ///
    rate_type get_hit_rate() const
    {
        std::lock_guard<std::mutex> lock (m_mutex);
        if (m_stats.m_get_cnt == 0) {
            return 0;
        }
        return static_cast<rate_type>(m_stats.m_hit_cnt) /
               static_cast<rate_type>(m_stats.m_get_cnt);
    }

    /// get_stats
    /// \brief Copies the current statistics counters into stats.
    void get_stats(Stats& stats) const
    {
        std::lock_guard<std::mutex> lock (m_mutex);
        stats = m_stats;
    }

    ///
    /// reset_stats
    /// \brief Zeroes all statistics counters; cached entries are untouched.
    ///
    void reset_stats()
    {
        std::lock_guard<std::mutex> lock (m_mutex);
        m_stats.m_hit_cnt = 0;
        m_stats.m_get_cnt = 0;
        m_stats.m_set_cnt = 0;
    }

private:
    ///
    /// [internal] Returns the number of entries currently cached.
    /// \return size_type
    /// \warning Must be called with m_mutex held.
    ///
    size_type _get_cache_size_in_lock() const
    {
        // The hash table and the list always hold the same entries, so the
        // hash table's size is the element count.
        return m_hash_table.size();
    }

    ///
    /// [internal] Evicts the least recently used entry (back of the list).
    /// \warning Must be called with m_mutex held. The drop handler, when set,
    ///          runs synchronously before this function returns.
    ///
    void _discard_one_element_in_lock()
    {
        if (m_list.empty()) {
            return;
        }

        // Declared before anything is modified so that a failure to construct
        // it cannot leave the containers half-updated.
        std::list<list_value> evicted;

        auto victim = m_list.end();
        --victim;
        m_hash_table.erase(victim->first);
        // Unlink the node from the cache before running user code: if the
        // handler throws, list and hash table are still consistent and the
        // detached node is released when `evicted` goes out of scope.
        evicted.splice(evicted.begin(), m_list, victim);

        if (m_drop_handler) {
            const list_value& entry = evicted.front();
            m_drop_handler(entry.first, entry.second);
        }
    }
};

/// Builds an empty cache that keeps at most max_size entries, with all
/// statistics counters set to zero.
template <typename Key, typename Value>
LRUCache<Key, Value>::LRUCache(size_type max_size)
    : m_max_size(max_size), m_stats()
{
}

/// Inserts or updates the entry for key and makes it the most recently used
/// one; evicts the least recently used entry when the limit is exceeded.
/// An exception thrown by the drop handler propagates to the caller; the new
/// entry stays inserted and the evicted entry stays removed.
template <typename Key, typename Value>
void LRUCache<Key, Value>::push(const key_type& key, const value_ptr_type& value)
{
    std::lock_guard<std::mutex> lock (m_mutex);

    ++m_stats.m_set_cnt;

    const auto found = m_hash_table.find(key);
    if (found != m_hash_table.end()) {
        // Known key: replace the value and promote the node to the front.
        found->second->second = value;
        m_list.splice(m_list.begin(), m_list, found->second);
    } else {
        m_list.emplace_front(key, value);
        try {
            m_hash_table.emplace(key, m_list.begin());
        } catch (...) {
            // Keep list and hash table in step if indexing the node fails.
            m_list.pop_front();
            throw;
        }
    }

    // Evict only when the element count exceeds the configured limit.
    if (_get_cache_size_in_lock() > m_max_size) {
        _discard_one_element_in_lock();
    }
}

/// Looks up key; on a hit stores the cached pointer in value, promotes the
/// entry to most recently used and returns true. On a miss value is left
/// untouched and false is returned. Both outcomes count as a get() call.
template <typename Key, typename Value>
bool LRUCache<Key, Value>::get(const key_type& key, value_ptr_type& value)
{
    std::lock_guard<std::mutex> lock (m_mutex);

    ++m_stats.m_get_cnt;

    const auto found = m_hash_table.find(key);
    if (found == m_hash_table.end()) {
        return false;
    }

    ++m_stats.m_hit_cnt;
    // splice relinks the node without invalidating the stored iterator.
    m_list.splice(m_list.begin(), m_list, found->second);

    value = found->second->second;
    return true;
}

} 
} 

#endif

4.LRU类使用测试

#include <iostream>
#include <time.h>
#include <sys/time.h>
#include "lrucache.h"

using namespace std;
using namespace common::util;

/// Returns the current time reported by gettimeofday(), expressed as a single
/// microsecond count (seconds * 1'000'000 + microseconds).
static inline uint64_t get_microsecond()
{
    struct timeval now;
    gettimeofday(&now, nullptr);
    const unsigned long long whole_seconds_us = now.tv_sec * 1000000LLU;
    return whole_seconds_us + now.tv_usec;
}

// 查找判断
void test1() {
     int n = 4;
    LRUCache<int, int> lru0(30);

    for (int i = 0; i < n; ++i) {
        lru0.push(i, std::make_shared<int>(i + n));
    }

    cout << "===============================" << endl;

    for (int i = 0; i < n; ++i) {
        auto tmp = std::make_shared<int>(-1);
        lru0.get(i, tmp);
        std::cout << ((i + n) == *tmp) << std::endl;
    }

    cout << "===============================" << endl;

    for (int i = 0; i < n; ++i) {
        auto tmp = std::make_shared<int>(-1);
        lru0.get(i, tmp);
        std::cout << ((i + n) == *tmp) << std::endl;
    }

    cout << "===============================" << endl;
    cout << (0 == lru0.exists(n)) << endl;
    auto ret = std::make_shared<int>(-1);
    cout << (0 == lru0.get(n, ret)) << endl;
}

// Eviction by element count: fills a cache of capacity n, touches every key,
// then pushes n more keys and checks after each push which keys were evicted
// and which are still cached. Each printed line is 1 when the check holds.
void test2() {
    int n = 4;
    LRUCache<int, int> lru1(static_cast<uint64_t>(n));

    for (int i = 0; i < n; ++i) {
        lru1.push(i, std::make_shared<int>(i + n));
    }

    // Reading keys 0..n-1 in order leaves key 0 as the least recently used.
    for (int i = 0; i < n; ++i) {
        auto tmp = std::make_shared<int>(-1);
        cout << (1 == lru1.get(i, tmp)) << endl;
        cout << (i + n == *tmp) << endl;
    }

    for (int i = n; i < n * 2; ++i) {
        lru1.push(i, std::make_shared<int>(i + n));
        cout << (1 == lru1.exists(i)) << endl;
        int j = 0;
        // Keys 0..i-n have been pushed out by now and must be absent.
        for (; j <= i - n; ++j) {
            cout << (0 == lru1.exists(j)) << endl;
        }
        // The n most recently pushed keys (i-n+1..i) must still be cached.
        for (; j <= i; ++j) {
            cout << (1 == lru1.exists(j)) << endl;
        }
    }
}

void test3() {
     const int cap = 3000000;

    LRUCache<int, int> lru4(static_cast<uint64_t>(cap));

    uint64_t totalTime = 0;
    for (int i = 0; i < cap; ++i) {
        uint64_t begin = get_microsecond();
        lru4.push(i, std::make_shared<int>(i));
        uint64_t end = get_microsecond();

        totalTime += end - begin;
    }

    std::cout << "average latency with no-discard push:"
              << 1.0 * totalTime / cap << " us" << std::endl;

    totalTime = 0;
    for (int i = cap; i < cap * 2; ++i) {
        uint64_t begin = get_microsecond();
        lru4.push(i, std::make_shared<int>(i));
        uint64_t end = get_microsecond();

        totalTime += end - begin;
    }

    std::cout << "average latency with discard push:"
              << 1.0 * totalTime / cap << " us" << std::endl;

    totalTime = 0;
    for (int i = cap; i < cap * 2; ++i) {
        auto tmp = std::make_shared<int>(-1);
        uint64_t begin = get_microsecond();
        lru4.get(i, tmp);
        uint64_t end = get_microsecond();

        totalTime += end - begin;
    }

    std::cout << "average latency with get exists elements:"
              << 1.0 * totalTime / cap << " us" << std::endl;

    totalTime = 0;
    for (int i = 0; i < cap; ++i) {
        auto tmp = std::make_shared<int>(-1);
        uint64_t begin = get_microsecond();
        lru4.get(i, tmp);
        uint64_t end = get_microsecond();

        totalTime += end - begin;
    }

    std::cout << "average latency with get not exists elements:"
              << 1.0 * totalTime / cap << " us" << std::endl;
}

// Entry point: runs only the latency benchmark (test3).
int main() {
    test3();
}

编译运行:

g++ main.cpp 

./a.out

运行结果如下:

猜你喜欢

转载自blog.csdn.net/hsy12342611/article/details/129917427
今日推荐