1 /*
2  * Kiss - A refined core library for D programming language.
3  *
4  * Copyright (C) 2015-2018  Shanghai Putao Technology Co., Ltd
5  *
6  * Developer: HuntLabs.cn
7  *
8  * Licensed under the Apache-2.0 License.
9  *
10  */
11 
12 module kiss.event.socket.iocp;
13 
14 // dfmt off
15 version (Windows) : 
16 
17 pragma(lib, "Ws2_32");
18 // dfmt on
19 
20 import kiss.container.ByteBuffer;
21 import kiss.core;
22 import kiss.event.socket.common;
23 import kiss.event.core;
24 import kiss.util.thread;
25 
26 import core.sys.windows.windows;
27 import core.sys.windows.winsock2;
28 import core.sys.windows.mswsock;
29 
30 import std.format;
31 import std.conv;
32 import std.socket;
33 import std.exception;
34 import kiss.logger;
35 
36 import std.process;
37 
38 // import core.thread;
39 
/**
TCP Server

Listening socket driven by IOCP. It keeps one overlapped `AcceptEx` call
outstanding; when it completes, `onAccept` hands the accepted socket to an
`AcceptHandler` and posts the next `AcceptEx`.
*/
abstract class AbstractListener : AbstractSocketChannel // , IAcceptor
{
    /// Bytes reserved for each of the local/remote addresses in the AcceptEx
    /// output buffer. AcceptEx requires 16 bytes more than the largest address
    /// of the transport; sizing for sockaddr_in6 covers both IPv4 and IPv6
    /// (sockaddr_in.sizeof + 16 is too small for an IPv6 listener).
    private enum uint AcceptAddressLength = cast(uint)(sockaddr_in6.sizeof + 16);

    /// Socket option value 0x700B (SO_UPDATE_ACCEPT_CONTEXT).
    private enum int SoUpdateAcceptContext = 0x700B;

    this(Selector loop, AddressFamily family = AddressFamily.INET, size_t bufferSize = 4 * 1024)
    {
        super(loop, WatcherType.Accept);
        setFlag(WatchFlag.Read, true);
        // AcceptEx writes both the local and the remote address into this
        // buffer, so it must be able to hold two address slots.
        immutable size_t minSize = 2 * AcceptAddressLength;
        _buffer = new ubyte[bufferSize < minSize ? minSize : bufferSize];
        this.socket = new TcpSocket(family);
    }

    mixin CheckIocpError;

    /// Creates a fresh client socket and posts an overlapped AcceptEx for it.
    protected void doAccept()
    {
        _iocp.watcher = this;
        _iocp.operation = IocpOperation.accept;
        _clientSocket = new Socket(_family, SocketType.STREAM, ProtocolType.TCP);
        DWORD dwBytesReceived = 0;

        version (KissDebugMode)
        {
            tracef("client socket:accept=%s  inner socket=%s", this.handle,
                    _clientSocket.handle());
            trace("AcceptEx is :  ", AcceptEx);
        }

        // dwReceiveDataLength = 0: complete as soon as a connection arrives
        // instead of waiting for the first block of data.
        int nRet = AcceptEx(this.handle, cast(SOCKET) _clientSocket.handle,
                _buffer.ptr, 0, AcceptAddressLength, AcceptAddressLength,
                &dwBytesReceived, &_iocp.overlapped);

        version (KissDebugMode)
            trace("do AcceptEx : the return is : ", nRet);
        // AcceptEx returns FALSE (0) on failure; checkErro does not record
        // ERROR_IO_PENDING, which is the normal "accept queued" outcome.
        checkErro(nRet);
    }

    /**
    Handles a completed AcceptEx: updates the accept context of the new
    socket, passes it to `handler` (if any) and, while still registered,
    posts the next accept.
    Returns: always `true`.
    */
    protected bool onAccept(scope AcceptHandler handler)
    {
        version (KissDebugMode)
            trace("new connection coming...");
        this.clearError();
        SOCKET slisten = cast(SOCKET) this.handle;
        SOCKET slink = cast(SOCKET) this._clientSocket.handle;
        version (KissDebugMode)
            tracef("slisten=%s, slink=%s", slisten, slink);

        // Per the AcceptEx documentation, the accepted socket must be given the
        // listening socket via SO_UPDATE_ACCEPT_CONTEXT before it behaves like
        // a socket returned by accept(). Report (but tolerate) a failure.
        if (setsockopt(slink, SocketOptionLevel.SOCKET, SoUpdateAcceptContext,
                cast(void*)&slisten, slisten.sizeof) == SOCKET_ERROR)
        {
            warningf("SO_UPDATE_ACCEPT_CONTEXT failed: code=%d", GetLastError());
        }

        if (handler !is null)
            handler(this._clientSocket);

        version (KissDebugMode)
            trace("accept next connection...");
        if (this.isRegistered)
            this.doAccept();
        return true;
    }

    override void onClose()
    {
        // Nothing to release yet.
        // TODO: created by Administrator @ 2018-3-27 15:51:52
    }

    private IocpContext _iocp;
    private WSABUF _dataWriteBuffer;
    private ubyte[] _buffer;
    private Socket _clientSocket;
}

alias AcceptorBase = AbstractListener;
112 
/**
TCP Client

IOCP-driven TCP stream. Reads are posted with `WSARecv` into `_readBuffer`.
Writes are taken from `_writeQueue` one buffer at a time and posted with
`WSASend`; the selector calls `onWriteDone` after data has been sent.
*/
abstract class AbstractStream : AbstractSocketChannel, Stream
{
    DataReceivedHandler dataReceivedHandler;
    DataWrittenHandler sentHandler;

    this(Selector loop, AddressFamily family = AddressFamily.INET, size_t bufferSize = 4096 * 2)
    {
        super(loop, WatcherType.TCP);
        setFlag(WatchFlag.Read, true);
        setFlag(WatchFlag.Write, true);

        version (KissDebugMode)
            trace("Buffer size for read: ", bufferSize);
        _readBuffer = new ubyte[bufferSize];
        this.socket = new TcpSocket(family);
    }

    mixin CheckIocpError;

    override void onRead()
    {
        version (KissDebugMode)
            trace("ready to read");
        _inRead = false;
        super.onRead();
    }

    override void onWrite()
    {
        _inWrite = false;
        super.onWrite();
    }

    /// Posts an overlapped WSARecv that fills `_readBuffer`.
    protected void beginRead()
    {
        _inRead = true;
        _dataReadBuffer.len = cast(uint) _readBuffer.length;
        _dataReadBuffer.buf = cast(char*) _readBuffer.ptr;
        _iocpread.watcher = this;
        _iocpread.operation = IocpOperation.read;
        DWORD dwReceived = 0;
        DWORD dwFlags = 0;

        version (KissDebugMode)
            tracef("start receiving handle=%d ", this.socket.handle);

        int nRet = WSARecv(cast(SOCKET) this.socket.handle, &_dataReadBuffer, 1u, &dwReceived, &dwFlags,
                &_iocpread.overlapped, cast(LPWSAOVERLAPPED_COMPLETION_ROUTINE) null);

        // WSARecv signals failure with SOCKET_ERROR.
        checkErro(nRet, SOCKET_ERROR);
    }

    /// Starts an overlapped ConnectEx to `addr`.
    /// NOTE(review): ConnectEx requires an already-bound socket; presumably the
    /// caller binds it — confirm.
    protected void doConnect(Address addr)
    {
        _iocpwrite.watcher = this;
        _iocpwrite.operation = IocpOperation.connect;
        int nRet = ConnectEx(cast(SOCKET) this.socket.handle(),
                cast(SOCKADDR*) addr.name(), addr.nameLen(), null, 0, null,
                &_iocpwrite.overlapped);
        // ConnectEx returns a BOOL, so its failure value is FALSE (0), not
        // ERROR_IO_PENDING.
        checkErro(nRet);
    }

    /**
    Posts an overlapped WSASend of `_dataWriteBuffer`.
    A failure other than ERROR_IO_PENDING is recorded as an error and the
    stream is closed, because the send was never started.
    Returns: the number of bytes reported sent synchronously (0 when pending).
    */
    private uint doWrite()
    {
        _inWrite = true;
        DWORD dwFlags = 0;
        DWORD dwSent = 0;
        _iocpwrite.watcher = this;
        _iocpwrite.operation = IocpOperation.write;
        version (KissDebugMode)
        {
            size_t bufferLength = sendDataBuffer.length;
            trace("writing...handle=", this.socket.handle());
            trace("buffer content length: ", bufferLength);
            if (bufferLength > 64)
                tracef("%(%02X %) ...", sendDataBuffer[0 .. 64]);
            else
                tracef("%(%02X %)", sendDataBuffer[0 .. $]);
        }

        int nRet = WSASend(cast(SOCKET) this.socket.handle(), &_dataWriteBuffer, 1, &dwSent,
                dwFlags, &_iocpwrite.overlapped, cast(LPWSAOVERLAPPED_COMPLETION_ROUTINE) null);
        // Read the error code immediately, before logging can overwrite it.
        DWORD dwLastError = nRet == SOCKET_ERROR ? GetLastError() : 0;

        version (KissDebugMode)
        {
            if (dwSent != _dataWriteBuffer.len)
                warningf("dwSent=%d, BufferLength=%d", dwSent, _dataWriteBuffer.len);
        }

        // SOCKET_ERROR + ERROR_IO_PENDING means "queued". Any other failure
        // means the send was not started, so no completion will ever arrive
        // and the writer would otherwise stay busy forever. Only a real
        // SOCKET_ERROR is inspected, so a stale last-error value cannot cause
        // a false positive.
        if (nRet == SOCKET_ERROR && dwLastError != ERROR_IO_PENDING)
        {
            this._error = true;
            this._erroString = format("WSASend failed with error: code=%s", dwLastError);
        }

        if (this.isError)
        {
            errorf("Socket error on write: fd=%d, message=%s", this.handle, this.erroString);
            this.close();
        }

        return dwSent;
    }

    /// Dispatches a completed read: delivers data and re-arms, or reports a
    /// disconnect when zero bytes were read.
    protected void doRead()
    {
        this.clearError();
        version (KissDebugMode)
            tracef("data reading...%d nbytes", this.readLen);

        if (readLen > 0)
        {
            if (dataReceivedHandler !is null)
                dataReceivedHandler(this._readBuffer[0 .. readLen]);
            version (KissDebugMode)
                tracef("done with data reading...%d nbytes", this.readLen);

            // continue reading
            this.beginRead();
        }
        else if (readLen == 0)
        {
            version (KissDebugMode)
            {
                if (_remoteAddress !is null)
                    warningf("connection broken: %s", _remoteAddress.toString());
            }
            onDisconnected();
        }
        else
        {
            // A negative length is flagged as an error in every build mode;
            // debug builds additionally log it.
            version (KissDebugMode)
                warningf("undefined behavior on thread %d", getTid());
            this._error = true;
            this._erroString = "undefined behavior on thread";
        }
    }

    /**
    Sends `data` directly, bypassing the write queue.
    Returns: bytes sent synchronously, or 0 when a write is already in progress.
    */
    // TODO: created by Administrator @ 2018-4-18 10:15:20
    // Send a big block of data
    protected size_t tryWrite(in ubyte[] data)
    {
        if (_isWritting)
        {
            warning("Busy in writting on thread: ");
            return 0;
        }
        version (KissDebugMode)
            trace("start to write");
        _isWritting = true;

        clearError();
        setWriteBuffer(data);
        size_t nBytes = doWrite();

        return nBytes;
    }

    /// Starts sending the buffer at the front of `_writeQueue`, if idle.
    protected void tryWrite()
    {
        if (_isWritting)
        {
            version (KissDebugMode)
                warning("Busy in writting on thread: ");
            return;
        }

        if (_writeQueue.empty)
            return;

        version (KissDebugMode)
            trace("start to write");
        _isWritting = true;

        clearError();

        writeBuffer = _writeQueue.front();
        const(ubyte)[] data = writeBuffer.sendData();
        setWriteBuffer(data);
        size_t nBytes = doWrite();

        if (nBytes < data.length)
        {
            // Keep a private copy so the remainder can be resent from
            // onWriteDone even if the queued buffer changes.
            version (KissDebugMode)
                warningf("remaining data: %d / %d ", data.length - nBytes, data.length);
            sendDataBuffer = data.dup;
        }
    }

    private bool _isWritting = false;

    /// Points the WSABUF used by WSASend at `data`.
    private void setWriteBuffer(in ubyte[] data)
    {
        version (KissDebugMode)
            trace("buffer content length: ", data.length);

        sendDataBuffer = data; // TODO: need more tests
        _dataWriteBuffer.buf = cast(char*) sendDataBuffer.ptr;
        _dataWriteBuffer.len = cast(uint) sendDataBuffer.length;
    }

    /**
     * Called by selector after data sent
     * Note: It's only for IOCP selector.
     */
    void onWriteDone(size_t nBytes)
    {
        version (KissDebugMode)
            tracef("finishing data writting %d nbytes) ", nBytes);
        if (isWriteCancelling)
        {
            _isWritting = false;
            isWriteCancelling = false;
            _writeQueue.clear(); // clean the data buffer
            return;
        }

        if (writeBuffer.popSize(nBytes))
        {
            if (_writeQueue.deQueue() is null)
                warning("_writeQueue is empty!");

            writeBuffer.doFinish();
            _isWritting = false;

            version (KissDebugMode)
                tracef("done with data writting %d nbytes) ", nBytes);

            tryWrite();
        }
        else
        {
            // Partial send: post the remaining bytes of the current buffer.
            tracef("remaining nbytes: %d", sendDataBuffer.length - nBytes);
            setWriteBuffer(sendDataBuffer[nBytes .. $]);
            doWrite();
        }
    }

    /// Requests that the in-flight write be dropped when it completes.
    void cancelWrite()
    {
        isWriteCancelling = true;
    }

    /// Marks the stream closed and notifies `disconnectionHandler`.
    protected void onDisconnected()
    {
        _isConnected = false;
        _isClosed = true;
        if (disconnectionHandler !is null)
            disconnectionHandler();
    }

    bool _isConnected; //if server side always true.
    SimpleEventHandler disconnectionHandler;

    protected WriteBufferQueue _writeQueue;
    protected bool isWriteCancelling = false;
    // Mutable: the kernel writes received bytes into it via WSARecv.
    private ubyte[] _readBuffer;
    private const(ubyte)[] sendDataBuffer;
    private StreamWriteBuffer writeBuffer;

    private IocpContext _iocpread;
    private IocpContext _iocpwrite;

    private WSABUF _dataReadBuffer;
    private WSABUF _dataWriteBuffer;

    private bool _inWrite;
    private bool _inRead;
}
400 
/**
UDP Socket

IOCP-driven datagram socket. Each receive is posted with `WSARecvFrom`; the
sender address is captured in `remoteAddr` and converted to a
`std.socket.Address` by `buildAddress`.
*/
abstract class AbstractDatagramSocket : AbstractSocketChannel, IDatagramSocket
{
    /// Constructs a UDP socket of the given address family (IPv4 by default),
    /// with a default bind address on an arbitrary port.
    this(Selector loop, AddressFamily family = AddressFamily.INET)
    {
        super(loop, WatcherType.UDP);
        setFlag(WatchFlag.Read, true);
        setFlag(WatchFlag.ETMode, false);

        this.socket = new UdpSocket(family);
        _readBuffer = new UdpDataObject();
        _readBuffer.data = new ubyte[4096 * 2];

        if (family == AddressFamily.INET)
            _bindAddress = new InternetAddress(InternetAddress.PORT_ANY);
        else if (family == AddressFamily.INET6)
            _bindAddress = new Internet6Address(Internet6Address.PORT_ANY);
        else
            _bindAddress = new UnknownAddress();
    }

    /// Binds to `addr`; does nothing if already bound.
    final void bind(Address addr)
    {
        if (_binded)
            return;
        _bindAddress = addr;
        socket.bind(_bindAddress);
        _binded = true;
    }

    final bool isBind()
    {
        return _binded;
    }

    Address bindAddr()
    {
        return _bindAddress;
    }

    /// Binds to the current bind address if not bound yet.
    override void start()
    {
        if (!_binded)
        {
            socket.bind(_bindAddress);
            _binded = true;
        }
    }

    private UdpDataObject _readBuffer;
    protected bool _binded = false;
    protected Address _bindAddress;

    version (Windows)
    {
        mixin CheckIocpError;

        /// Posts an overlapped WSARecvFrom into `_readBuffer.data`.
        void doRead()
        {
            version (KissDebugMode)
                trace("Receiving......");

            _dataReadBuffer.len = cast(uint) _readBuffer.data.length;
            _dataReadBuffer.buf = cast(char*) _readBuffer.data.ptr;
            _iocpread.watcher = this;
            _iocpread.operation = IocpOperation.read;
            // lpFromlen is in/out: pass the full capacity of remoteAddr; the
            // actual sender address length is written back on completion.
            remoteAddrLen = cast(int) remoteAddr.sizeof;

            DWORD dwReceived = 0;
            DWORD dwFlags = 0;

            int nRet = WSARecvFrom(cast(SOCKET) this.handle, &_dataReadBuffer,
                    cast(uint) 1, &dwReceived, &dwFlags, cast(SOCKADDR*)&remoteAddr, &remoteAddrLen,
                    &_iocpread.overlapped, cast(LPWSAOVERLAPPED_COMPLETION_ROUTINE) null);
            checkErro(nRet, SOCKET_ERROR);
        }

        /**
        Converts the sender address of the last datagram into an `Address`.
        The family stored in the address decides the type; the former length
        test (`== 32`) never matched an IPv4 address (16 bytes).
        */
        Address buildAddress()
        {
            if (remoteAddr.ss_family == AF_INET)
            {
                sockaddr_in* addr4 = cast(sockaddr_in*)(&remoteAddr);
                return new InternetAddress(*addr4);
            }
            sockaddr_in6* addr6 = cast(sockaddr_in6*)(&remoteAddr);
            return new Internet6Address(*addr6);
        }

        /**
        Delivers the last received datagram to `read` (or `null` when nothing
        was read) and re-arms the receive while still registered.
        Returns: always `false`.
        */
        bool tryRead(scope ReadCallBack read)
        {
            this.clearError();
            if (this.readLen == 0)
            {
                read(null);
                return false;
            }

            ubyte[] fullBuffer = this._readBuffer.data;
            this._readBuffer.data = fullBuffer[0 .. this.readLen];
            this._readBuffer.addr = this.buildAddress();
            {
                // Restore the full-capacity buffer even if the callback throws;
                // doRead below relies on the full length.
                scope (exit)
                    this._readBuffer.data = fullBuffer;
                read(this._readBuffer);
            }
            if (this.isRegistered)
                this.doRead();
            return false;
        }

        IocpContext _iocpread;
        WSABUF _dataReadBuffer;

        /// Sender address of the last datagram. A plain `sockaddr` (16 bytes)
        /// cannot hold a `sockaddr_in6` (28 bytes), so storage is used.
        sockaddr_storage remoteAddr;
        int remoteAddrLen;
    }

}
529 
/**
Mixin providing `checkErro`, which records a failed Winsock/IOCP call in the
host class's `_error` / `_erroString` fields.
*/
mixin template CheckIocpError()
{
    /**
    Records an error if `ret` equals the API's failure value and the last
    error is anything other than ERROR_IO_PENDING (the normal "operation
    queued" result of an overlapped call).

    Params:
        ret  = value returned by the API call.
        erro = value the API returns on failure: 0 (FALSE) for BOOL functions
               such as AcceptEx/ConnectEx, SOCKET_ERROR for WSARecv and
               WSARecvFrom. Passing ERROR_IO_PENDING is accepted for older
               call sites and treated like 0.
    */
    void checkErro(int ret, int erro = 0)
    {
        // Must be read first, before any other call can overwrite it.
        DWORD dwLastError = GetLastError();

        immutable int failureValue = (erro == ERROR_IO_PENDING) ? 0 : erro;
        // Success (ret is not the failure value): any last-error value is
        // stale and must be ignored.
        if (ret != failureValue || dwLastError == 0)
            return;

        version (KissDebugMode)
            tracef("erro=%d, dwLastError=%d", erro, dwLastError);

        if (ERROR_IO_PENDING != dwLastError)
        {
            this._error = true;
            this._erroString = format("IOCP operation failed with error: code=%s", dwLastError);
        }
    }
}
550 
/// Kind of overlapped operation an `IocpContext` was submitted for.
enum IocpOperation
{
    accept = 0,
    connect = 1,
    read = 2,
    write = 3,
    event = 4,
    close = 5,
}
560 
/**
Per-operation IOCP state: the OVERLAPPED handed to the Winsock call, the kind
of operation, and the channel that issued it.
NOTE(review): `overlapped` is the first member, so an OVERLAPPED* and the
enclosing IocpContext* share an address — presumably the selector relies on
this; do not reorder the fields.
*/
struct IocpContext
{
    OVERLAPPED overlapped;      /// passed to AcceptEx/ConnectEx/WSARecv/WSASend
    IocpOperation operation;    /// operation kind set before the call
    AbstractChannel watcher;    /// issuing channel; null until assigned
}
567 
// Winsock's overlapped structure is the Win32 OVERLAPPED.
alias WSAOVERLAPPED = OVERLAPPED;
alias LPWSAOVERLAPPED = WSAOVERLAPPED*;

// Extension function pointers, filled in by the module constructor.
__gshared LPFN_ACCEPTEX AcceptEx;
__gshared LPFN_CONNECTEX ConnectEx;

// Further extension functions, not currently resolved:
// __gshared LPFN_DISCONNECTEX DisconnectEx;
// __gshared LPFN_GETACCEPTEXSOCKADDRS GetAcceptexSockAddrs;
// __gshared LPFN_TRANSMITFILE TransmitFile;
// __gshared LPFN_TRANSMITPACKETS TransmitPackets;
// __gshared LPFN_WSARECVMSG WSARecvMsg;
// __gshared LPFN_WSASENDMSG WSASendMsg;
579 
/**
Module constructor: starts Winsock 2.2 and resolves the AcceptEx and ConnectEx
extension function pointers through WSAIoctl on a temporary TCP socket.
Throws: on WSAStartup/socket failure, or if a pointer cannot be resolved.
*/
shared static this()
{
    WSADATA wsaData;
    immutable int iResult = WSAStartup(MAKEWORD(2, 2), &wsaData);
    enforce(iResult == 0, format("WSAStartup failed with error: %d", iResult));

    // Temporary socket used only to query the extension functions. The
    // GET_FUNC_POINTER mixins refer to `ListenSocket` and `guid` by name.
    SOCKET ListenSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    enforce(ListenSocket != INVALID_SOCKET,
            format("socket() failed with error: %d", GetLastError()));
    scope (exit)
        closesocket(ListenSocket);

    GUID guid;
    mixin(GET_FUNC_POINTER("WSAID_ACCEPTEX", "AcceptEx"));
    mixin(GET_FUNC_POINTER("WSAID_CONNECTEX", "ConnectEx"));
    /* mixin(GET_FUNC_POINTER("WSAID_DISCONNECTEX", "DisconnectEx"));
     mixin(GET_FUNC_POINTER("WSAID_GETACCEPTEXSOCKADDRS", "GetAcceptexSockAddrs"));
     mixin(GET_FUNC_POINTER("WSAID_TRANSMITFILE", "TransmitFile"));
     mixin(GET_FUNC_POINTER("WSAID_TRANSMITPACKETS", "TransmitPackets"));
     mixin(GET_FUNC_POINTER("WSAID_WSARECVMSG", "WSARecvMsg"));*/
}

/// Module destructor: releases Winsock.
shared static ~this()
{
    WSACleanup();
}
602 
private
{
    /**
    Queries a Winsock extension function pointer identified by `guid` via
    WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER) on `sock`, storing it in `pfn`.
    Returns: `true` on success; on failure logs the last error and returns `false`.
    */
    bool GetFunctionPointer(FuncPointer)(SOCKET sock, ref FuncPointer pfn, ref GUID guid)
    {
        DWORD bytesReturned = 0;
        immutable int rc = WSAIoctl(sock, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, guid.sizeof,
                &pfn, pfn.sizeof, &bytesReturned, null, null);
        if (rc != SOCKET_ERROR)
            return true;

        error("Get function failed with error:", GetLastError());
        return false;
    }

    /**
    Builds the statement string mixed into the module constructor: it assigns
    `GuidValue` to the local `guid`, then resolves the pointer named `pft` on
    the local `ListenSocket`, raising via errnoEnforce on failure.
    */
    string GET_FUNC_POINTER(string GuidValue, string pft)
    {
        return " guid = " ~ GuidValue ~ ";"
            ~ "if( !GetFunctionPointer( ListenSocket, " ~ pft
            ~ ", guid ) ) { errnoEnforce(false,\"get function error!\"); } ";
    }
}
626 
// Direction/flag bits of Winsock ioctl codes.
enum : DWORD
{
    IOCPARAM_MASK = 0x7f,
    IOC_VOID = 0x20000000,
    IOC_OUT = 0x40000000,
    IOC_IN = 0x80000000,
    IOC_INOUT = IOC_IN | IOC_OUT
}

// Ioctl class values, combined with the direction bits above.
enum
{
    IOC_UNIX = 0x00000000,
    IOC_WS2 = 0x08000000,
    IOC_PROTOCOL = 0x10000000,
    IOC_VENDOR = 0x18000000
}

// Ioctl code builders: direction bits | class | code.
enum _WSAIO(int x, int y) = IOC_VOID | x | y;
enum _WSAIOR(int x, int y) = IOC_OUT | x | y;
enum _WSAIOW(int x, int y) = IOC_IN | x | y;
enum _WSAIORW(int x, int y) = IOC_INOUT | x | y;

enum SIO_ASSOCIATE_HANDLE               = _WSAIOW!(IOC_WS2, 1);
enum SIO_ENABLE_CIRCULAR_QUEUEING       = _WSAIO!(IOC_WS2, 2);
enum SIO_FIND_ROUTE                     = _WSAIOR!(IOC_WS2, 3);
enum SIO_FLUSH                          = _WSAIO!(IOC_WS2, 4);
enum SIO_GET_BROADCAST_ADDRESS          = _WSAIOR!(IOC_WS2, 5);
enum SIO_GET_EXTENSION_FUNCTION_POINTER = _WSAIORW!(IOC_WS2, 6);
enum SIO_GET_QOS                        = _WSAIORW!(IOC_WS2, 7);
enum SIO_GET_GROUP_QOS                  = _WSAIORW!(IOC_WS2, 8);
enum SIO_MULTIPOINT_LOOPBACK            = _WSAIOW!(IOC_WS2, 9);
enum SIO_MULTICAST_SCOPE                = _WSAIOW!(IOC_WS2, 10);
enum SIO_SET_QOS                        = _WSAIOW!(IOC_WS2, 11);
enum SIO_SET_GROUP_QOS                  = _WSAIOW!(IOC_WS2, 12);
enum SIO_TRANSLATE_HANDLE               = _WSAIORW!(IOC_WS2, 13);
enum SIO_ROUTING_INTERFACE_QUERY        = _WSAIORW!(IOC_WS2, 20);
enum SIO_ROUTING_INTERFACE_CHANGE       = _WSAIOW!(IOC_WS2, 21);
enum SIO_ADDRESS_LIST_QUERY             = _WSAIOR!(IOC_WS2, 22);
enum SIO_ADDRESS_LIST_CHANGE            = _WSAIO!(IOC_WS2, 23);
enum SIO_QUERY_TARGET_PNP_HANDLE        = _WSAIOR!(IOC_WS2, 24);
enum SIO_NSP_NOTIFY_CHANGE              = _WSAIOW!(IOC_WS2, 25);
680 
// Winsock 2 overlapped I/O prototypes used by this module. Every declaration
// from here on has Windows linkage and is nothrow.
extern (Windows):
nothrow:

int WSARecv(SOCKET s, LPWSABUF lpBuffers, DWORD dwBufferCount,
        LPDWORD lpNumberOfBytesRecvd, LPDWORD lpFlags,
        LPWSAOVERLAPPED lpOverlapped,
        LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine);

int WSARecvDisconnect(SOCKET s, LPWSABUF lpInboundDisconnectData);

int WSARecvFrom(SOCKET s, LPWSABUF lpBuffers, DWORD dwBufferCount,
        LPDWORD lpNumberOfBytesRecvd, LPDWORD lpFlags,
        SOCKADDR* lpFrom, LPINT lpFromlen,
        LPWSAOVERLAPPED lpOverlapped,
        LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine);

int WSASend(SOCKET s, LPWSABUF lpBuffers, DWORD dwBufferCount,
        LPDWORD lpNumberOfBytesSent, DWORD dwFlags,
        LPWSAOVERLAPPED lpOverlapped,
        LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine);

int WSASendDisconnect(SOCKET s, LPWSABUF lpOutboundDisconnectData);

int WSASendTo(SOCKET s, LPWSABUF lpBuffers, DWORD dwBufferCount,
        LPDWORD lpNumberOfBytesSent, DWORD dwFlags,
        const(SOCKADDR)* lpTo, int iTolen,
        LPWSAOVERLAPPED lpOverlapped,
        LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine);