转载: muduo库chat server的shared_ptr和TLS实现分析

muduo书也买过。代码也在看,不过实践确实少。

库内部还是挺不错的。值得借鉴的地方较多。

转载一篇不错的解析

muduo 库的 chat 最基本的是单线程模型,然后有多线程模型,但是多线程同步需要加锁,锁争用会降低服务器性能,明显的代码就是 chat server 的 onStringMessage()函数:

void onStringMessage(const TcpConnectionPtr&,  
                       const string& message,  
                       Timestamp)  
  {  
    MutexLockGuard lock(mutex_);  
    for (ConnectionList::iterator it = connections_.begin();  
        it != connections_.end();  
        ++it)  
    {  
      codec_.send(get_pointer(*it), message);  
    }
}

这个函数遍历连接列表,然后对每个连接挨个转发消息。临界区为整个函数,有相当大的优化空间。

优化方法有两种:

  • 使用 shared_ptr 做 copy-on-write
  • 使用TLS(线程局部存储)

shared_ptr 实现 copy-on-write

先分析如何借助 shared_ptr 实现 copy-on-write

  • shared_ptr 是引用计数智能指针,如果当前只有一个观察者(原始创建者持有),那么引用计数为 1,可以用 shared_ptr::unique() 来判断。
  • 对于 write 端,如果发现引用计数为 1,这是可以安全地修改对象,不必担心有人在读它。
  • 对于 read 端,在读之前把引用计数加 1,读完之后减 1,这样可以保证在读的期间其引用计数大于 1,可以阻止并发写
  • 比较难的是,对于 write 端,如果发现引用计数大于 1,该如何处理? 既然更新数据,肯定要加锁,如果这时候其他线程在读,那么不能在原来的数据上修改,得创建一个副本,在副本上修改,修改完了再替换。如果没有用户在读,那么可以直接修改

使用 shared_ptr 实现 copy-on-write 的目的在于在适用于 readers-writer 场合可以降低锁竞争,提高服务器性能。

使用shared_ptr改进代码如下:

class ChatServer : boost::noncopyable  
{  
 public:  
  ChatServer(EventLoop* loop,  
             const InetAddress& listenAddr)  
  : server_(loop, listenAddr, "ChatServer"),  
    codec_(boost::bind(&ChatServer::onStringMessage, this, _1, _2, _3)),  
    connections_(new ConnectionList)  //new出来,shared_ptr引用计数为1  
  {  
    server_.setConnectionCallback(  
        boost::bind(&ChatServer::onConnection, this, _1));//写操作  
    server_.setMessageCallback(  
        boost::bind(&LengthHeaderCodec::onMessage, &codec_, _1, _2, _3));  //读操作  
  }  
  
  void setThreadNum(int numThreads)  
  {  
    server_.setThreadNum(numThreads);  
  }  
  
  void start()  
  {  
    server_.start();  
  }  
  
 private:  
  void onConnection(const TcpConnectionPtr& conn)  
  {  
    LOG_INFO << conn->localAddress().toIpPort() << " -> "  
        << conn->peerAddress().toIpPort() << " is "  
        << (conn->connected() ? "UP" : "DOWN");  
  
    MutexLockGuard lock(mutex_);  
    if (!connections_.unique())  //说明引用计数大于1,存在其他线程正在读  
    {  
      //new ConnectionList(*connections_) 这段代码拷贝了一份 ConnectionList  
      //新connections_的引用计数为1,原来的connections_的引用计数减1,因为reset了  
      connections_.reset(new ConnectionList(*connections_));  
    }  
    assert(connections_.unique());  
  
    //在副本上修改,不会影响读者,所以读者在遍历列表的时候,不需要mutex保护  
    if (conn->connected())  
    {  
      connections_->insert(conn);  
    }  
    else  
    {  
      connections_->erase(conn);  
    }  
  }  
  
  typedef std::set<TcpConnectionPtr> ConnectionList;  
  typedef boost::shared_ptr<ConnectionList> ConnectionListPtr;  
  
  void onStringMessage(const TcpConnectionPtr&,  
                       const string& message,  
                       Timestamp)  
  {  
    //引用计数加1,mutex保护的临界区大大缩短  
    ConnectionListPtr connections = getConnectionList();  //临界区仅为getConnectionList内部  
    //可能大家会有疑问,不受mutex保护,写者更改了连接列表怎么办?  
    //实际上,写者是在另一个副本上修改,所以无需担心。  
    for (ConnectionList::iterator it = connections->begin();  
        it != connections->end();  
        ++it)  
    {  
      codec_.send(get_pointer(*it), message);  
    }  
    //当connections这个栈上变量销毁的时候,引用计数减1  
    //如果connections在本函数前面获得智能指针后引用计数为2(一个connections和一个connextions_),写者会采取reset使引用计数减1,  
    //再加上在本函数结束时引用计数减1,所以旧的connections_会销毁,写者reset的新的connections_成为"正宗"。  
    //也就是说 assert(!connections.unique()),这个断言在此处不一定成立。  
  }  
  
  //mutex只保护这一段临界区  
  ConnectionListPtr getConnectionList()  
  {  
    MutexLockGuard lock(mutex_);  
    return connections_;  
  }  
  
  TcpServer server_;  
  LengthHeaderCodec codec_;  
  MutexLock mutex_;  
  ConnectionListPtr connections_;  //连接的集合  
};  

TLS(线程局部存储)

shared_ptr 通过对读和写的改进来减小临界区的长度,但是在核心函数 onStringMessage() 中采用单线程遍历连接列表,单个线程挨个向客户端转发消息,这样我们觉得效率还不够。更好的办法是采用TLS。我们把每个 ConnectionsList 做成TLS,每个线程都有自己的 ConnectionList,这样客户端 1 可能位于线程 1 的 ConnectionsList,客户端 2 可能位于线程 2 的 ConnectionsList,等等。这样每个线程拥有的资源都不同,那么每个线程就可以放心的把自己要做的 send() 工作扔给 EventLoop 的 doPendingFunctors(),所以每个线程的临界区就会缩小为“转移”这一个动作,临界区小了,效率自然提升。

代码如下:

class ChatServer : boost::noncopyable  
{  
 public:  
  ChatServer(EventLoop* loop,  
             const InetAddress& listenAddr)  
  : server_(loop, listenAddr, "ChatServer"),  
    codec_(boost::bind(&ChatServer::onStringMessage, this, _1, _2, _3))  
  {  
    server_.setConnectionCallback(  
        boost::bind(&ChatServer::onConnection, this, _1));  
    server_.setMessageCallback(  
        boost::bind(&LengthHeaderCodec::onMessage, &codec_, _1, _2, _3));  
  }  
  
  void setThreadNum(int numThreads)  
  {  
    server_.setThreadNum(numThreads);  
  }  
  
  void start()  
  {  
    //设置每个线程启动前的回调函数,该函数中会为每个线程生成线程局部ConnectionsList实例  
    server_.setThreadInitCallback(boost::bind(&ChatServer::threadInit, this, _1));  
    server_.start();  
  }  
  
 private:  
  void onConnection(const TcpConnectionPtr& conn)  
  {  
    LOG_INFO << conn->localAddress().toIpPort() << " -> "  
             << conn->peerAddress().toIpPort() << " is "  
             << (conn->connected() ? "UP" : "DOWN");  
    //不需要保护,用的是每个线程局部实例ConnectionsList  
    if (conn->connected())  
    {  
      LocalConnections::instance().insert(conn);  
    }  
    else  
    {  
      LocalConnections::instance().erase(conn);  
    }  
  }  
  
  void onStringMessage(const TcpConnectionPtr&,  
                       const string& message,  
                       Timestamp)  
  {  
     //distribuMessage函数负责转发消息,下面会把它分给I/O线程,本函数mutex锁定区域只负责将f分配给I/O线程,不转发消息,减小了临界区长度  
    EventLoop::Functor f = boost::bind(&ChatServer::distributeMessage, this, message);  
    LOG_DEBUG;  
  
    MutexLockGuard lock(mutex_);   
    //被锁定的临界区并没有转发消息,  
    //转发消息给所有客户端,高效转发(多线程转发),通过各个客户端所在的I/O线程  
    for (std::set<EventLoop*>::iterator it = loops_.begin();  
        it != loops_.end();  
        ++it)  
    {  
      // 1.让对应的I/O线程来执行distributeMessage  
      // 2.distributeMessage放到I/O队列中执行,因此,这里的mutex_锁竞争大大减小  
      // 3. distributeMessage不受mutex保护,因为它是TLS  
      (*it)->queueInLoop(f);  
    }  
    LOG_DEBUG;  
  }  
  
  typedef std::set<TcpConnectionPtr> ConnectionList;  
  
  void distributeMessage(const string& message)  
  {  
    LOG_DEBUG << "begin";  
    //connections_是TLS变量,所以不需要保护  
    for (ConnectionList::iterator it = LocalConnections::instance().begin();  
        it != LocalConnections::instance().end();  
        ++it)  
    {  
      codec_.send(get_pointer(*it), message);  //发送消息  
    }  
    LOG_DEBUG << "end";  
  }  
  
  //线程调用之前调用的回调函数  
  void threadInit(EventLoop* loop)  
  {  
    assert(LocalConnections::pointer() == NULL);  
  
    //在此生成线程局部单例对象  
    LocalConnections::instance();  
    assert(LocalConnections::pointer() != NULL);  
    MutexLockGuard lock(mutex_);  
    loops_.insert(loop);  //保存loop到loops_列表中  
  }  
  
  TcpServer server_;  
  LengthHeaderCodec codec_;  
  typedef ThreadLocalSingleton<ConnectionList> LocalConnections;//线程局部单例变量  
  
  MutexLock mutex_;  
  std::set<EventLoop*> loops_;  
};  

猜你喜欢

转载自blog.csdn.net/andylau00j/article/details/79596930