C++ Socket 简单封装

以下代码部分来自《网络多人游戏架构与编程》，

其它的都是我瞎写的。

备忘。

这是一个简单的 Socket 封装，没有涉及高级用法（比如 IO 完成端口等）。

#pragma once

#include <algorithm>
#include <atomic>
#include <cerrno>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#ifdef _WIN32
    #include <WinSock2.h>
    #include <Ws2tcpip.h>
    #pragma comment(lib,"ws2_32.lib")

#else
    //TODO:
    //
    // POSIX socket headers. SOCKET / INVALID_SOCKET / SOCKET_ERROR / NO_ERROR / closesocket
    // shims are still missing for non-Windows builds.
    #include <arpa/inet.h>
    #include <fcntl.h>
    #include <netdb.h>
    #include <netinet/in.h>
    #include <netinet/tcp.h>
    #include <sys/select.h>
    #include <sys/socket.h>
    #include <unistd.h>
#endif
 19 
 20 
 21 #ifdef _WIN32
 22     #define SocketLastError WSAGetLastError()
 23 #else
 24     #define SocketLastError errno
 25 #endif // _WIN32
 26 
 27 
 28 namespace Lunacia
 29 {
 30     enum SocketFamily
 31     {
 32         INET = AF_INET,
 33         INET6 = AF_INET6
 34     };
 35 
 36     enum SocketProtocol
 37     {
 38         SP_TCP = IPPROTO_TCP,
 39         SP_TCP_NODELAY = TCP_NODELAY,
 40         SP_UDP = IPPROTO_UDP
 41     };
 42 
 43     template<class SockType, bool isDelay = true>
 44     struct getproto
 45     {
 46         enum {
 47             proto = 
 48                 std::is_same<SockType, TCPSocket>::value && isDelay ? SocketProtocol::SP_TCP :
 49                 std::is_same<SockType, UDPSocket>::value ? SocketProtocol::SP_UDP :
 50                 std::is_same<SockType, TCPSocket>::value && !isDelay ? SocketProtocol::SP_TCP_NODELAY :
 51                 -1,
 52 
 53             dataform = 
 54                 std::is_same<SockType, TCPSocket>::value ? SOCK_STREAM:
 55                 std::is_same<SockType, UDPSocket>::value ? SOCK_DGRAM :
 56                 -1
 57         };
 58     };
 59 
 60     //
 61     class SocketAddress final
 62     {
 63     public:
 64         SocketAddress()
 65         {
 66         }
 67 
 68     public:
 69         SocketAddress(unsigned int inAddress, unsigned int inPort, SocketFamily inFamily = INET)
 70         {
 71             if (inFamily != INET
 72                 && inFamily != INET6)
 73             {
 74                 //TODO: inFamily Error.
 75                 //
 76                 return;
 77             }
 78             GetAsSockAddrIn()->sin_family = inFamily;
 79             GetAsSockAddrIn()->sin_addr.S_un.S_addr = htonl(inAddress);
 80             GetAsSockAddrIn()->sin_port = htons(inPort);
 81         }
 82 
 83         SocketAddress(const sockaddr& inSockAddr)
 84         {
 85             memcpy(&_sockaddr, &inSockAddr, sizeof(sockaddr));
 86         }
 87 
 88         size_t Size() const
 89         {
 90             return sizeof(sockaddr);
 91         }
 92 
 93         sockaddr& Get()
 94         {
 95             return _sockaddr;
 96         }
 97 
 98         const sockaddr& Get() const
 99         {
100             return _sockaddr;
101         }
102 
103     private:
104         sockaddr _sockaddr;
105         inline sockaddr_in* GetAsSockAddrIn()
106         {
107             return reinterpret_cast<sockaddr_in*>(&_sockaddr);
108         }
109     };
110     typedef std::shared_ptr<SocketAddress> PtrSocketAddress;
111 
112     class SocketAddressFactory final
113     {
114     public:
115         static PtrSocketAddress CreateIPFromString(const std::string& inString, SocketFamily inFamily = INET)
116         {
117             auto pos = inString.find_last_of(':');
118 
119             std::string host, service;
120             if (pos != std::string::npos)
121             {
122                 host = inString.substr(0, pos);
123                 service = inString.substr(pos + 1);
124             }
125             else
126             {
127                 host = inString;
128                 service = "0";
129             }
130 
131             addrinfo hint;
132             memset(&hint, 0, sizeof(hint));
133             hint.ai_family = inFamily;
134 
135             addrinfo* result = nullptr;
136             int err = getaddrinfo(host.c_str(), service.c_str(), &hint, &result);
137             if (err != 0)
138             {
139                 if (result != nullptr) freeaddrinfo(result);
140                 return nullptr;
141             }
142 
143             while (!result->ai_addr && result->ai_next)
144             {
145                 result = result->ai_next;
146             }
147 
148             if (!result->ai_addr)
149             {
150                 freeaddrinfo(result);
151                 return nullptr;
152             }
153 
154             auto toRet = std::make_shared<SocketAddress>(*result->ai_addr);
155 
156             freeaddrinfo(result);
157 
158             return toRet;
159         }
160     };
161 
162     //Socket Base Class.
163     class LCSocket
164     {
165     public:
166         LCSocket():
167             _sock(INVALID_SOCKET),
168             _lastError(0)
169         {    
170         }
171 
172         LCSocket(SOCKET sock)
173         {
174             Close();
175             _sock = sock;
176         }
177 
178         LCSocket(const SocketAddress & inSockaddr)
179         {
180             SetAddr(inSockaddr);
181         }
182 
183         virtual ~LCSocket()
184         {
185             Close();
186         }
187 
188     public:
189         virtual int Send(const void* inData, int inLength) = 0;
190         virtual int Receive(void* inBuffer, int inLength) = 0;
191 
192         int Bind()
193         {
194             int err = bind(_sock, &_sockaddr.Get(), _sockaddr.Size());
195             if (err != 0)
196             {
197                 SetLastError(SocketLastError);
198                 LCSocket::ReportMessage("[Failed] LCSocket::Bind Failed! " + std::to_string(GetLastError()));
199             }
200             return NO_ERROR;
201         }
202 
203         void SetAddr(const SocketAddress& inSockaddr)
204         {
205             memcpy(&_sockaddr, &inSockaddr, inSockaddr.Size());
206         }
207 
208         void Close() noexcept
209         {
210             if (_sock == INVALID_SOCKET) return;
211                     
212             shutdown(_sock, 2/*Shutdown both send and receive operations. Liunx: SHUT_RDWR; Windows: SD_BOTH*/);
213             closesocket(_sock);
214             _sock = INVALID_SOCKET;
215         }
216         
217         int SetBlockMode(bool isBlock)
218         {
219 #if _WIN32
220             unsigned long blockState = isBlock ? 0 : 1;
221             int result = ioctlsocket(_sock, FIONBIO, &blockState);
222 
223 #else
224             int flags = fcntl(_sock, F_GETFL, 0);
225             flags = isBlock ? (flags & ~O_NONBLOCK) : (flags | O_NONBLOCK);
226 
227             int result = fcntl(_sock, F_SETFL, flags);
228 
229 #endif
230 
231             if (result == SOCKET_ERROR)
232             {
233                 SetLastError(SocketLastError);
234                 LCSocket::ReportMessage("[Failed] LCSocket::SetBlockMode Failed! " + std::to_string(GetLastError()));
235             }
236 
237             return NO_ERROR;
238         }
239 
240         template<class SockType, 
241             typename std::enable_if<std::is_base_of<LCSocket, SockType>::value, SockType>::type* = nullptr>
242         static std::shared_ptr<SockType> CreateSocket(SocketFamily inFamily = INET)
243         {
244             SOCKET sock = socket(inFamily, getproto<SockType>::dataform, getproto<SockType>::proto);
245             if (sock == INVALID_SOCKET)
246             {
247                 LCSocket::ReportMessage("[Failed] LCSocket::CreateSocket Failed! ");
248                 return nullptr;
249             }
250             return std::shared_ptr<SockType>(new SockType(sock));
251         }
252 
253         static void ReportMessage(const std::string& msg)
254         {
255             std::cout << msg << std::endl;
256         }
257 
258         inline void SetLastError(unsigned long err)
259         {
260             _lastError = err;
261         }
262 
263         inline unsigned long GetLastError() const
264         {
265             return _lastError;
266         }
267 
268         SOCKET Get()
269         {
270             return _sock;
271         }
272 
273     protected:
274         SOCKET _sock;
275         unsigned long _lastError;
276         SocketAddress _sockaddr;
277     };
278 
279     //UDP Socket.
280     class UDPSocket
281         :public LCSocket
282     {
283     public:
284         UDPSocket() 
285             :LCSocket() 
286         {
287         }
288 
289         UDPSocket(const SocketAddress& inSockaddr)
290             :LCSocket(inSockaddr)
291         {
292             
293         }
294 
295         UDPSocket(SOCKET sock)
296             :LCSocket(sock)
297         {
298         }
299 
300         ~UDPSocket()
301         {
302             Close();
303         }
304 
305     public:
306         virtual int Send(const void* inData, int inLength)
307         {    
308             int sendBytes = sendto(
309                 _sock,
310                 static_cast<const char*>(inData),
311                 inLength,
312                 0,
313                 &_sockaddr.Get(),
314                 _sockaddr.Size()
315             );
316             
317             if (0 < sendBytes)    return sendBytes;
318             
319             SetLastError(SocketLastError);
320             LCSocket::ReportMessage("[Failed] UDPSocket::Send Failed!");
321             return 0;
322         }
323 
324         virtual int Receive(void* inBuffer, int inLength)
325         {
326             SocketAddress outFrom;
327             return Receive(inBuffer, inLength, outFrom);
328         }
329 
330         int Receive(void* inBuffer, int inLength, SocketAddress& outFrom)
331         {        
332             int addrSize = outFrom.Size();
333             int readBytes = recvfrom(
334                 _sock,
335                 static_cast<char*>(inBuffer),
336                 inLength,
337                 0,
338                 &outFrom.Get(),
339                 &addrSize
340             );
341                 
342             if (readBytes > 0)    return readBytes;
343                         
344             SetLastError(SocketLastError);
345             LCSocket::ReportMessage("[Failed] UDPSocket::Receive Failed! " + std::to_string(GetLastError()));
346             return 0;
347         }
348 
349     };
350     typedef std::shared_ptr<UDPSocket> PtrUDPSocket;
351 
352 
353     //TCP Socket
354     class TCPSocket
355         : public LCSocket
356     {
357     public:
358         TCPSocket()
359             :LCSocket()
360         {
361         }
362 
363         TCPSocket(const SocketAddress& inSockaddr)
364             :LCSocket(inSockaddr)
365         {
366         }
367 
368         TCPSocket(SOCKET sock):
369             LCSocket(sock)
370         {
371         }
372 
373         ~TCPSocket()
374         {
375             Close();
376         }
377 
378     public:
379         //Client use it.
380         int Connect()
381         {
382             int err = connect(_sock, &_sockaddr.Get(), _sockaddr.Size());
383             if (err < 0)
384             {
385                 SetLastError(SocketLastError);
386                 LCSocket::ReportMessage("[Failed] TCPSocket::Connect Failed! " + std::to_string(GetLastError()));
387             }
388             return NO_ERROR;
389         }
390 
391         int Listen(int inBackLog = 64)
392         {
393             int err = listen(_sock, inBackLog);
394             if (err < 0)
395             {
396                 SetLastError(SocketLastError);
397                 LCSocket::ReportMessage("[Failed] TCPSocket::Listen Failed! " + std::to_string(GetLastError()));
398             }
399             return NO_ERROR;
400         }
401 
402         std::shared_ptr<TCPSocket> Accept(SocketAddress& outFromAddress)
403         {
404             int size = outFromAddress.Size();
405             SOCKET newSock = accept(_sock, &outFromAddress.Get(), &size);
406 
407             if (newSock == INVALID_SOCKET)
408             {
409                 SetLastError(SocketLastError);
410                 LCSocket::ReportMessage("[Failed] TCPSocket::Accept Failed! " + std::to_string(GetLastError()));
411             }
412             return std::shared_ptr<TCPSocket>(new TCPSocket(_sock));
413         }
414 
415         virtual int Send(const void* inData, int inLength)
416         {
417             int sendBytes = send(_sock, static_cast<const char*>(inData), inLength, 0);
418             if (sendBytes < 0)
419             {
420                 SetLastError(SocketLastError);
421                 LCSocket::ReportMessage("[Failed] TCPSocket::Send Failed! " + std::to_string(GetLastError()));
422             }
423             return sendBytes;
424         }
425 
426         virtual int Receive(void* inBuffer, int inLength)
427         {
428             int recvBytes = recv(_sock, static_cast<char*>(inBuffer), inLength, 0);
429             if (0 > recvBytes)
430             {
431                 SetLastError(SocketLastError);
432                 LCSocket::ReportMessage("[Failed] TCPSocket::Receive Failed! " + std::to_string(GetLastError()));
433             }
434             return recvBytes;
435         }
436     };
437     typedef std::shared_ptr<TCPSocket> PtrTCPSocket;
438 
439 };
440 
/// @brief Process-wide registry of sockets of one wrapper type, plus select() helpers.
///
/// @tparam SockType Socket wrapper exposing Close() and Get() (the native handle), e.g.
///                  Lunacia::TCPSocket or Lunacia::UDPSocket.
///
/// Only Instance() is synchronized. Add/Delete/Get/Clear do no locking, so callers sharing
/// the pool across threads must serialize access themselves.
template<class SockType>
class SocketPool final
{
private:
    SocketPool()
    {
        _socks.reserve(100);
    }

    ~SocketPool()
    {
        Clear();
    }

public:
    SocketPool(const SocketPool&) = delete;
    SocketPool& operator=(const SocketPool&) = delete;

    /// @brief Return the pool for SockType, creating it on first use.
    ///
    /// Initialization of a function-local static happens exactly once, even under concurrent
    /// first calls (C++11). This replaces double-checked locking on a non-atomic pointer.
    /// The pool is heap-allocated and intentionally never destroyed, as before.
    static SocketPool<SockType>* Instance()
    {
        static SocketPool<SockType>* const s_instance = new SocketPool<SockType>();
        return s_instance;
    }

    /// @brief Remove every socket from the pool.
    /// @param isCloseSocket true (default): Close() each socket before removing it.
    ///                      false: only release the pool's references.
    void Clear(bool isCloseSocket = true)
    {
        if (isCloseSocket)
        {
            for (const std::shared_ptr<SockType>& each : _socks)
            {
                each->Close();
            }
        }
        _socks.clear();
    }

    /// @brief Append pSock.
    /// @return Its index, or -1 if pSock is null.
    int Add(std::shared_ptr<SockType> pSock)
    {
        if (nullptr == pSock)
        {
            return -1;
        }
        _socks.push_back(std::move(pSock));
        return static_cast<int>(_socks.size()) - 1;
    }

    /// @brief Remove (without closing) the socket at index; later indices shift down by one.
    /// @return true if index was valid and an element was removed.
    bool Delete(int index)
    {
        if (!IsValidIndex(index))
        {
            return false;
        }
        _socks.erase(_socks.begin() + index);
        return true;
    }

    /// @brief Remove (without closing) the first occurrence of pSock.
    /// @return The index it occupied, or -1 if it was not in the pool.
    int Delete(std::shared_ptr<SockType> pSock)
    {
        auto it = std::find(_socks.begin(), _socks.end(), pSock);
        if (it == _socks.end())
        {
            return -1;
        }
        const int index = static_cast<int>(it - _socks.begin());
        _socks.erase(it);
        return index;
    }

    /// @brief Socket at index, or nullptr when index is out of range.
    std::shared_ptr<SockType> Get(int index)
    {
        if (IsValidIndex(index))
        {
            return _socks[index];
        }
        return nullptr;
    }

    std::shared_ptr<SockType> operator[](int index)
    {
        return Get(index);
    }

    /// @brief Read-only view of all pooled sockets.
    const std::vector<std::shared_ptr<SockType>>* Get() const
    {
        return &_socks;
    }

    /// @brief Load the handles of inSocks into outSet.
    /// @return &outSet, or nullptr when inSocks is null (so select() skips that set).
    /// NOTE: an fd_set holds at most FD_SETSIZE descriptors.
    static fd_set* FillSetFromVec(
        fd_set& outSet,
        const std::vector<std::shared_ptr<SockType>>* inSocks
    )
    {
        if (nullptr == inSocks)
        {
            return nullptr;
        }

        FD_ZERO(&outSet);
        for (const std::shared_ptr<SockType>& sockEach : *inSocks)
        {
            FD_SET(sockEach->Get(), &outSet);
        }
        return &outSet;
    }

    /// @brief Replace *outSocks with the members of *inSocks whose handle is set in inSet.
    /// outSocks may point to the same vector as inSocks (in-place filtering). Null pointers
    /// make this a no-op.
    static void FillVecFromSet(
        std::vector<std::shared_ptr<SockType>>* outSocks,
        const std::vector<std::shared_ptr<SockType>>* inSocks,
        const fd_set& inSet
    )
    {
        if (inSocks == nullptr || outSocks == nullptr)
        {
            return;
        }

        // Collect into a temporary first so that clearing the output cannot empty the input.
        std::vector<std::shared_ptr<SockType>> ready;
        for (const std::shared_ptr<SockType>& each : *inSocks)
        {
            if (FD_ISSET(each->Get(), &inSet))
            {
                ready.push_back(each);
            }
        }
        outSocks->swap(ready);
    }

    /// @brief Block in select() until some socket is ready.
    ///
    /// Each in* vector may be null to skip that category. When select() returns > 0, each
    /// non-null out* vector is replaced with the ready subset of the matching in* vector.
    /// On error the out* vectors are left untouched.
    /// @return The value returned by select().
    static int Select(
        const std::vector<std::shared_ptr<SockType>>* inReadSet,
        std::vector<std::shared_ptr<SockType>>* outReadSet,

        const std::vector<std::shared_ptr<SockType>>* inWriteSet,
        std::vector<std::shared_ptr<SockType>>* outWriteSet,

        const std::vector<std::shared_ptr<SockType>>* inExceptSet,
        std::vector<std::shared_ptr<SockType>>* outExceptSet
    )
    {
        fd_set readSet, writeSet, exceptSet;

        fd_set* pRead   = FillSetFromVec(readSet, inReadSet);
        fd_set* pWrite  = FillSetFromVec(writeSet, inWriteSet);
        fd_set* pExcept = FillSetFromVec(exceptSet, inExceptSet);

        const int nfds = HighestDescriptorPlusOne(inReadSet, inWriteSet, inExceptSet);
        int ret = select(nfds, pRead, pWrite, pExcept, nullptr);

        if (ret > 0)
        {
            FillVecFromSet(outReadSet, inReadSet, readSet);
            FillVecFromSet(outWriteSet, inWriteSet, writeSet);
            FillVecFromSet(outExceptSet, inExceptSet, exceptSet);
        }

        return ret;
    }

    int Count() const
    {
        return static_cast<int>(_socks.size());
    }

private:
    // True when index addresses an existing element.
    bool IsValidIndex(int index) const
    {
        return index >= 0 && static_cast<std::size_t>(index) < _socks.size();
    }

    // nfds argument for select(). POSIX needs the highest descriptor + 1; Winsock ignores it,
    // so 0 is kept there as before.
    static int HighestDescriptorPlusOne(
        const std::vector<std::shared_ptr<SockType>>* a,
        const std::vector<std::shared_ptr<SockType>>* b,
        const std::vector<std::shared_ptr<SockType>>* c
    )
    {
#ifdef _WIN32
        (void)a; (void)b; (void)c;
        return 0;
#else
        int nfds = 0;
        for (const std::vector<std::shared_ptr<SockType>>* socks : { a, b, c })
        {
            if (socks == nullptr) continue;
            for (const std::shared_ptr<SockType>& each : *socks)
            {
                const int fd = static_cast<int>(each->Get());
                if (fd + 1 > nfds) nfds = fd + 1;
            }
        }
        return nfds;
#endif
    }

    std::vector<std::shared_ptr<SockType>> _socks;
};

猜你喜欢

转载自www.cnblogs.com/rkexy/p/10861285.html