1 /*
2  * Kiss - A refined core library for D programming language.
3  *
4  * Copyright (C) 2015-2018  Shanghai Putao Technology Co., Ltd
5  *
6  * Developer: HuntLabs.cn
7  *
8  * Licensed under the Apache-2.0 License.
9  *
10  */
11 
12 module kiss.event.socket.iocp;
13 
14 // dfmt off
15 version (Windows) : 
16 
17 pragma(lib, "Ws2_32");
18 // dfmt on
19 
20 import kiss.container.ByteBuffer;
21 import kiss.core;
22 import kiss.event.socket.common;
23 import kiss.event.core;
24 import kiss.util.thread;
25 
26 import core.sys.windows.windows;
27 import core.sys.windows.winsock2;
28 import core.sys.windows.mswsock;
29 
30 import std.format;
31 import std.conv;
32 import std.socket;
33 import std.exception;
34 import kiss.logger;
35 
36 import std.process;
37 
38 // import core.thread;
39 
/**
TCP server (listening side).

Creates the listening `TcpSocket` and accepts connections by posting overlapped
`AcceptEx` requests. Each accepted client socket is passed to the handler given
to `onAccept`.
*/
abstract class AbstractListener : AbstractSocketChannel // , IAcceptor
{
    /**
    Creates the listening socket and the scratch buffer handed to `AcceptEx`.

    Params:
        loop = selector forwarded to the base channel
        family = address family of the listening socket
        bufferSize = requested scratch buffer size in bytes; raised to
                     `2 * acceptAddressLength` when smaller, because `AcceptEx`
                     stores the local and the remote address in that buffer
    */
    this(Selector loop, AddressFamily family = AddressFamily.INET, size_t bufferSize = 4 * 1024)
    {
        super(loop, WatcherType.Accept);
        setFlag(WatchFlag.Read, true);

        enum size_t minBufferSize = 2 * acceptAddressLength;
        _buffer = new ubyte[bufferSize < minBufferSize ? minBufferSize : bufferSize];
        this.socket = new TcpSocket(family);
    }

    mixin CheckIocpError;

    /**
    Posts one overlapped `AcceptEx` request on the listening socket.

    A fresh client socket is created for the request; the outcome of the call
    is recorded through `checkErro`.
    */
    protected void doAccept()
    {
        _iocp.watcher = this;
        _iocp.operation = IocpOperation.accept;
        _clientSocket = new Socket(_family, SocketType.STREAM, ProtocolType.TCP);
        DWORD dwBytesReceived = 0;

        version (KissDebugMode)
            tracef("client socket:accept=%s  inner socket=%s", this.handle,
                    _clientSocket.handle());
        version (KissDebugMode)
            trace("AcceptEx is :  ", AcceptEx);

        // No payload is received with the accept (receive length 0); only the
        // two address slots are reserved, each large enough for IPv4 and IPv6.
        int nRet = AcceptEx(this.handle, cast(SOCKET) _clientSocket.handle,
                _buffer.ptr, 0, acceptAddressLength, acceptAddressLength,
                &dwBytesReceived, &_iocp.overlapped);

        // Inspect the result before logging, so nothing runs between the
        // AcceptEx call and the last-error lookup done by checkErro.
        checkErro(nRet);

        version (KissDebugMode)
            trace("do AcceptEx : the return is : ", nRet);
    }

    /**
    Completes an accept: updates the accepted socket's context, hands the
    socket to `handler` and, while still registered, posts the next accept.

    Params:
        handler = callback receiving the accepted socket; may be null

    Returns: always `true`.
    */
    protected bool onAccept(scope AcceptHandler handler)
    {
        version (KissDebugMode)
            trace("new connection coming...");
        this.clearError();
        SOCKET slisten = cast(SOCKET) this.handle;
        SOCKET slink = cast(SOCKET) this._clientSocket.handle;

        version (KissDebugMode)
            tracef("slisten=%s, slink=%s", slisten, slink);

        // SO_UPDATE_ACCEPT_CONTEXT makes the accepted socket inherit the
        // listening socket's properties; a failure is reported, not ignored.
        immutable int rc = setsockopt(slink, SocketOptionLevel.SOCKET,
                soUpdateAcceptContext, cast(void*)&slisten, cast(int) slisten.sizeof);
        if (rc == SOCKET_ERROR)
        {
            immutable DWORD code = GetLastError();
            warningf("SO_UPDATE_ACCEPT_CONTEXT failed: code=%d", code);
        }

        if (handler !is null)
            handler(this._clientSocket);

        version (KissDebugMode)
            trace("accept next connection...");
        if (this.isRegistered)
            this.doAccept();
        return true;
    }

    /// Close hook; intentionally does nothing at the moment.
    override void onClose()
    {
        // assert(false, "");
        // TODO: created by Administrator @ 2018-3-27 15:51:52
    }

    /// Bytes reserved per address slot in the AcceptEx buffer: the largest
    /// supported socket address (IPv6) plus the 16 bytes AcceptEx requires.
    private enum DWORD acceptAddressLength = cast(DWORD)(sockaddr_in6.sizeof + 16);

    /// Winsock option value of SO_UPDATE_ACCEPT_CONTEXT.
    private enum int soUpdateAcceptContext = 0x700B;

    private IocpContext _iocp;
    private WSABUF _dataWriteBuffer; // not used by the code in this class
    private ubyte[] _buffer;
    private Socket _clientSocket;
}

alias AcceptorBase = AbstractListener;
112 
/**
TCP stream (connected side) driven by overlapped I/O.

Reads are posted with `WSARecv`, writes with `WSASend` and outgoing connects
with `ConnectEx`. Received data is delivered through `dataReceivedHandler`;
queued write buffers are drained one at a time through `tryWrite` and
`onWriteDone`.
*/
abstract class AbstractStream : AbstractSocketChannel, Stream
{
    /// Called by `doRead` with the bytes that were just received.
    DataReceivedHandler dataReceivedHandler;
    /// Not referenced by the code in this class.
    DataWrittenHandler sentHandler;

    /**
    Creates the TCP socket and the receive buffer.

    Params:
        loop = selector forwarded to the base channel
        family = address family of the socket
        bufferSize = size in bytes of the receive buffer
    */
    this(Selector loop, AddressFamily family = AddressFamily.INET, size_t bufferSize = 4096 * 2)
    {
        super(loop, WatcherType.TCP);
        setFlag(WatchFlag.Read, true);
        setFlag(WatchFlag.Write, true);

        version (KissDebugMode)
            trace("Buffer size for read: ", bufferSize);
        _readBuffer = new ubyte[bufferSize];
        this.socket = new TcpSocket(family);
    }

    mixin CheckIocpError;

    /// Clears the read-in-flight marker, then defers to the base class.
    override void onRead()
    {
        version (KissDebugMode)
            trace("ready to read");
        _inRead = false;
        super.onRead();
    }

    /// Clears the write-in-flight marker, then defers to the base class.
    override void onWrite()
    {
        _inWrite = false;
        super.onWrite();
    }

    /// Posts an overlapped `WSARecv` covering the whole receive buffer.
    protected void beginRead()
    {
        _inRead = true;
        _dataReadBuffer.len = cast(uint) _readBuffer.length;
        _dataReadBuffer.buf = cast(char*) _readBuffer.ptr;
        _iocpread.watcher = this;
        _iocpread.operation = IocpOperation.read;
        DWORD dwReceived = 0;
        DWORD dwFlags = 0;

        version (KissDebugMode)
            tracef("receiving on thread(%d), handle=%d ", getTid(), this.socket.handle);

        int nRet = WSARecv(cast(SOCKET) this.socket.handle, &_dataReadBuffer, 1u, &dwReceived, &dwFlags,
                &_iocpread.overlapped, cast(LPWSAOVERLAPPED_COMPLETION_ROUTINE) null);

        checkErro(nRet, SOCKET_ERROR);
    }

    /**
    Posts an overlapped `ConnectEx` towards `addr`.

    Params:
        addr = remote address to connect to
    */
    protected void doConnect(Address addr)
    {
        _iocpwrite.watcher = this;
        _iocpwrite.operation = IocpOperation.connect;
        int nRet = ConnectEx(cast(SOCKET) this.socket.handle(),
                cast(SOCKADDR*) addr.name(), addr.nameLen(), null, 0, null,
                &_iocpwrite.overlapped);
        checkErro(nRet, ERROR_IO_PENDING);
    }

    /**
    Posts an overlapped `WSASend` for the range described by `_dataWriteBuffer`.

    An immediate failure other than "I/O pending" marks the channel as failed
    and closes it; in that case no completion is expected for the request.

    Returns: the byte count reported synchronously by `WSASend`.
    */
    private uint doWrite()
    {
        _inWrite = true;
        DWORD dwFlags = 0;
        DWORD dwSent = 0;
        _iocpwrite.watcher = this;
        _iocpwrite.operation = IocpOperation.write;
        version (KissDebugMode)
            trace("writing...handle=", this.socket.handle());
        int nRet = WSASend(cast(SOCKET) this.socket.handle(), &_dataWriteBuffer, 1, &dwSent,
                dwFlags, &_iocpwrite.overlapped, cast(LPWSAOVERLAPPED_COMPLETION_ROUTINE) null);

        // Fetch the error code immediately, before any logging runs.
        immutable DWORD lastError = (nRet == SOCKET_ERROR) ? GetLastError() : 0;

        version (KissDebugMode)
        {
            if (dwSent != _dataWriteBuffer.len)
                warningf("dwSent=%d, BufferLength=%d", dwSent, _dataWriteBuffer.len);
        }

        // SOCKET_ERROR together with ERROR_IO_PENDING only means the request
        // was queued; every other code is a genuine failure.
        if (nRet == SOCKET_ERROR && lastError != ERROR_IO_PENDING)
        {
            this._error = true;
            this._erroString = format("WSASend failed with error: code=%s", lastError);
        }

        if (this.isError)
        {
            errorf("Socket error on write: fd=%d, message=%s", this.handle, this.erroString);
            this.close();
        }

        return dwSent;
    }

    /**
    Handles a finished read using the channel's `readLen`.

    A positive length delivers the data and posts the next read; zero is
    treated as a disconnect; anything else marks the channel as failed.
    */
    protected void doRead()
    {
        this.clearError();
        version (KissDebugMode)
            tracef("data reading...%d nbytes", this.readLen);

        if (readLen > 0)
        {
            // import std.stdio;
            // writefln("length=%d, data: %(%02X %)", readLen, _readBuffer[0 .. readLen]);

            if (dataReceivedHandler !is null)
                dataReceivedHandler(this._readBuffer[0 .. readLen]);
            version (KissDebugMode)
                tracef("done with data reading...%d nbytes", this.readLen);

            // continue reading
            this.beginRead();
        }
        else if (readLen == 0)
        {
            version (KissDebugMode)
            {
                if (_remoteAddress !is null)
                    warningf("connection broken: %s", _remoteAddress.toString());
            }
            onDisconnected();
            if (_isClosed)
                this.socket.close(); // release the sources
            else
                this.close();
        }
        else
        {
            version (KissDebugMode)
                warningf("undefined behavior on thread %d", getTid());

            // Record the failure in every build, not only in non-debug ones.
            this._error = true;
            this._erroString = "undefined behavior on thread";
        }
    }

    // private ThreadID lastThreadID;

    /**
    Starts sending `data` directly, unless a write is already in progress.

    Params:
        data = bytes to send

    Returns: the byte count reported synchronously by the send call, or 0
             when a write is already in progress.
    */
    // TODO: created by Administrator @ 2018-4-18 10:15:20
    // Send a big block of data
    protected size_t tryWrite(in ubyte[] data)
    {
        if (_isWritting)
        {
            warning("Busy in writting on thread: ", thisThreadID());
            return 0;
        }
        version (KissDebugMode)
            trace("start to write");
        _isWritting = true;

        clearError();
        setWriteBuffer(data);
        size_t nBytes = doWrite();

        return nBytes;
    }

    /// Starts sending the buffer at the front of the write queue, unless a
    /// write is already in progress or the queue is empty.
    protected void tryWrite()
    {
        if (_isWritting)
        {
            version (KissDebugMode)
                warning("Busy in writting on thread: ", thisThreadID());
            return;
        }

        if (_writeQueue.empty)
            return;

        version (KissDebugMode)
            trace("start to write");
        _isWritting = true;

        clearError();

        writeBuffer = _writeQueue.front();
        setWriteBuffer(writeBuffer.sendData());
        doWrite();
    }

    private bool _isWritting = false;

    /// Points the outgoing WSABUF at `data` and remembers the slice.
    private void setWriteBuffer(in ubyte[] data)
    {
        version (KissDebugMode)
            trace("buffer content length: ", data.length);
        // trace(cast(string) data);

        sendDataBuffer = data; //data[writeLen .. $]; // TODO: need more tests
        _dataWriteBuffer.buf = cast(char*) sendDataBuffer.ptr;
        _dataWriteBuffer.len = cast(uint) sendDataBuffer.length;
    }

    /**
     * Called by selector after data sent
     * Note: It's only for IOCP selector:
     *
     * Params:
     *     nBytes = number of bytes the finished send transferred
    */
    void onWriteDone(size_t nBytes)
    {
        version (KissDebugMode)
            tracef("finishing data writing, thread: %d,  nbytes: %d) ", thisThreadID(), nBytes);
        if (isWriteCancelling)
        {
            _isWritting = false;
            isWriteCancelling = false;
            _writeQueue.clear(); // clean the data buffer
            return;
        }

        if (writeBuffer.popSize(nBytes))
        {
            if (_writeQueue.deQueue() is null)
                warning("_writeQueue is empty!");

            writeBuffer.doFinish();
            _isWritting = false;

            version (KissDebugMode)
                tracef("done with data writing, thread: %d,  nbytes: %d) ", thisThreadID(), nBytes);

            tryWrite();
        }
        else // if (sendDataBuffer.length > nBytes)
        {
            version (KissDebugMode)
                tracef("remaining nbytes: %d", sendDataBuffer.length - nBytes);
            setWriteBuffer(sendDataBuffer[nBytes .. $]); // send remaining
            doWrite();
        }
    }

    /// Requests that the write queue be dropped at the next `onWriteDone`.
    void cancelWrite()
    {
        isWriteCancelling = true;
    }

    /// Marks the stream as disconnected and closed, then notifies
    /// `disconnectionHandler` when one is set.
    protected void onDisconnected()
    {
        _isConnected = false;
        _isClosed = true;
        if (disconnectionHandler !is null)
            disconnectionHandler();
    }

    bool _isConnected; //if server side always true.
    SimpleEventHandler disconnectionHandler;

    protected WriteBufferQueue _writeQueue;
    protected bool isWriteCancelling = false;
    private ubyte[] _readBuffer; // mutable: WSARecv writes into it
    private const(ubyte)[] sendDataBuffer;
    private StreamWriteBuffer writeBuffer;

    private IocpContext _iocpread;
    private IocpContext _iocpwrite;

    private WSABUF _dataReadBuffer;
    private WSABUF _dataWriteBuffer;

    private bool _inWrite;
    private bool _inRead;
}
382 
/**
UDP socket driven by overlapped `WSARecvFrom` requests.
*/
abstract class AbstractDatagramSocket : AbstractSocketChannel, IDatagramSocket
{
    /**
    Creates the UDP socket, its receive buffer and a default bind address
    (any port) matching `family`.

    Params:
        loop = selector forwarded to the base channel
        family = address family of the socket
    */
    this(Selector loop, AddressFamily family = AddressFamily.INET)
    {
        super(loop, WatcherType.UDP);
        setFlag(WatchFlag.Read, true);
        setFlag(WatchFlag.ETMode, false);

        this.socket = new UdpSocket(family);
        _readBuffer = new UdpDataObject();
        _readBuffer.data = new ubyte[4096 * 2];

        if (family == AddressFamily.INET)
            _bindAddress = new InternetAddress(InternetAddress.PORT_ANY);
        else if (family == AddressFamily.INET6)
            _bindAddress = new Internet6Address(Internet6Address.PORT_ANY);
        else
            _bindAddress = new UnknownAddress();
    }

    /// Binds the socket to `addr`; does nothing when already bound.
    final void bind(Address addr)
    {
        if (_binded)
            return;
        _bindAddress = addr;
        socket.bind(_bindAddress);
        _binded = true;
    }

    /// Returns: whether the socket has been bound.
    final bool isBind()
    {
        return _binded;
    }

    /// Returns: the address used (or to be used) for binding.
    Address bindAddr()
    {
        return _bindAddress;
    }

    /// Binds to the current bind address if that has not happened yet.
    override void start()
    {
        if (!_binded)
        {
            socket.bind(_bindAddress);
            _binded = true;
        }
    }

    // abstract void doRead();

    private UdpDataObject _readBuffer;
    protected bool _binded = false;
    protected Address _bindAddress;

    version (Windows)
    {
        mixin CheckIocpError;

        /// Posts an overlapped `WSARecvFrom` covering the whole receive buffer.
        void doRead()
        {
            version (KissDebugMode)
                trace("Receiving......");

            _dataReadBuffer.len = cast(uint) _readBuffer.data.length;
            _dataReadBuffer.buf = cast(char*) _readBuffer.data.ptr;
            _iocpread.watcher = this;
            _iocpread.operation = IocpOperation.read;
            // Offer the full storage so that an IPv4 or an IPv6 sender
            // address fits; WSARecvFrom stores the actual length back.
            remoteAddrLen = cast(int) remoteAddrStorage.sizeof;

            DWORD dwReceived = 0;
            DWORD dwFlags = 0;

            int nRet = WSARecvFrom(cast(SOCKET) this.handle, &_dataReadBuffer,
                    cast(uint) 1, &dwReceived, &dwFlags, cast(SOCKADDR*)&remoteAddr, &remoteAddrLen,
                    &_iocpread.overlapped, cast(LPWSAOVERLAPPED_COMPLETION_ROUTINE) null);
            checkErro(nRet, SOCKET_ERROR);
        }

        /**
        Builds an `Address` object from the sender address stored by the last
        receive.

        Returns: an `InternetAddress` when the stored family is `AF_INET`,
                 otherwise an `Internet6Address`.
        */
        Address buildAddress()
        {
            // Decide by the address family actually written by the receive,
            // not by a length value.
            if (remoteAddr.sa_family == AF_INET)
            {
                sockaddr_in* addr = cast(sockaddr_in*)(&remoteAddr);
                return new InternetAddress(*addr);
            }

            return new Internet6Address(remoteAddrStorage);
        }

        /**
        Delivers the datagram of a finished receive to `read`.

        With a zero `readLen` the callback receives `null`. Otherwise it
        receives the data object trimmed to the received length together with
        the sender address, and the next receive is posted while the channel
        is still registered.

        Returns: always `false`.
        */
        bool tryRead(scope ReadCallBack read)
        {
            this.clearError();
            if (this.readLen == 0)
            {
                read(null);
            }
            else
            {
                ubyte[] data = this._readBuffer.data;
                this._readBuffer.data = data[0 .. this.readLen];
                this._readBuffer.addr = this.buildAddress();
                // Restores the full buffer even if the callback throws.
                scope (exit)
                    this._readBuffer.data = data;
                read(this._readBuffer);
                // Restore before re-arming: doRead needs the full buffer and
                // the scope(exit) above only runs after it.
                this._readBuffer.data = data;
                if (this.isRegistered)
                    this.doRead();
            }
            return false;
        }

        IocpContext _iocpread;
        WSABUF _dataReadBuffer;

        // `remoteAddr` keeps its name and type; the overlapping member makes
        // the storage large enough for an IPv6 sender address.
        union
        {
            sockaddr remoteAddr;
            sockaddr_in6 remoteAddrStorage;
        }

        int remoteAddrLen;
    }

}
511 
/**
Mixin providing `checkErro`, the shared result check for the overlapped
Winsock calls issued by the channel classes in this module.

The host class must provide the `_error` and `_erroString` members.
*/
mixin template CheckIocpError()
{
    /**
    Records a failure of an overlapped call on the channel.

    Nothing is recorded when the call succeeded, or when it failed with
    `ERROR_IO_PENDING` (the request was queued and completes later).

    Params:
        ret  = value returned by the Winsock call
        erro = failure convention of the call: pass `SOCKET_ERROR` for calls
               that return `SOCKET_ERROR` on failure and 0 on success
               (`WSARecv`, `WSARecvFrom`, ...); any other value, including the
               default, selects the BOOL convention of `AcceptEx`/`ConnectEx`,
               where 0 (FALSE) means failure
    */
    void checkErro(int ret, int erro = 0)
    {
        immutable bool failed = (erro == SOCKET_ERROR) ? (ret == SOCKET_ERROR) : (ret == 0);
        if (!failed)
            return; // success: the last-error value would be stale

        immutable DWORD dwLastError = GetLastError();

        version (KissDebugMode)
            tracef("erro=%d, dwLastError=%d", erro, dwLastError);

        if (dwLastError == 0 || dwLastError == ERROR_IO_PENDING)
            return;

        this._error = true;
        this._erroString = format("IOCP operation failed with error: code=%s", dwLastError);
    }
}
532 
/// Kind of overlapped request an `IocpContext` was posted for.
enum IocpOperation
{
    accept = 0, /// set by AbstractListener.doAccept before calling AcceptEx
    connect = 1, /// set by AbstractStream.doConnect before calling ConnectEx
    read = 2, /// set before WSARecv / WSARecvFrom
    write = 3, /// set by AbstractStream.doWrite before calling WSASend
    event = 4, /// not assigned anywhere in this part of the module
    close = 5 /// not assigned anywhere in this part of the module
}
542 
/**
Per-request bookkeeping handed to the overlapped Winsock calls.

NOTE(review): `overlapped` is the first member, presumably so that the
completion side can recover the context from the OVERLAPPED pointer - confirm
against the selector implementation before reordering fields.
*/
struct IocpContext
{
    /// OVERLAPPED block whose address is passed to AcceptEx/ConnectEx/WSARecv/WSASend.
    OVERLAPPED overlapped;
    /// Kind of request this context was posted for.
    IocpOperation operation;
    /// Channel that posted the request; null until assigned.
    AbstractChannel watcher;
}
549 
/// Winsock spelling of the OVERLAPPED structure.
alias WSAOVERLAPPED = OVERLAPPED;
/// Pointer form of `WSAOVERLAPPED`.
alias LPWSAOVERLAPPED = WSAOVERLAPPED*;

/// `AcceptEx` extension function pointer, loaded by the module constructor.
__gshared LPFN_ACCEPTEX AcceptEx;
/// `ConnectEx` extension function pointer, loaded by the module constructor.
__gshared LPFN_CONNECTEX ConnectEx;
555 /*__gshared LPFN_DISCONNECTEX DisconnectEx;
556 __gshared LPFN_GETACCEPTEXSOCKADDRS GetAcceptexSockAddrs;
557 __gshared LPFN_TRANSMITFILE TransmitFile;
558 __gshared LPFN_TRANSMITPACKETS TransmitPackets;
559 __gshared LPFN_WSARECVMSG WSARecvMsg;
560 __gshared LPFN_WSASENDMSG WSASendMsg;*/
561 
/**
Module constructor: starts Winsock 2.2 and loads the `AcceptEx` and
`ConnectEx` extension function pointers through a temporary TCP socket.

Throws: an exception when Winsock cannot be started, the temporary socket
        cannot be created, or an extension function cannot be loaded.
*/
shared static this()
{
    WSADATA wsaData;
    immutable int iResult = WSAStartup(MAKEWORD(2, 2), &wsaData);
    // WSAStartup reports its error code through the return value.
    enforce(iResult == 0, format("WSAStartup failed with error: code=%s", iResult));

    // The names ListenSocket and guid are referenced by the code that
    // GET_FUNC_POINTER generates, so they must not be renamed.
    SOCKET ListenSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    enforce(ListenSocket != INVALID_SOCKET,
            format("socket() failed with error: code=%s", GetLastError()));
    scope (exit)
        closesocket(ListenSocket);
    GUID guid;
    mixin(GET_FUNC_POINTER("WSAID_ACCEPTEX", "AcceptEx"));
    mixin(GET_FUNC_POINTER("WSAID_CONNECTEX", "ConnectEx"));
    /* mixin(GET_FUNC_POINTER("WSAID_DISCONNECTEX", "DisconnectEx"));
     mixin(GET_FUNC_POINTER("WSAID_GETACCEPTEXSOCKADDRS", "GetAcceptexSockAddrs"));
     mixin(GET_FUNC_POINTER("WSAID_TRANSMITFILE", "TransmitFile"));
     mixin(GET_FUNC_POINTER("WSAID_TRANSMITPACKETS", "TransmitPackets"));
     mixin(GET_FUNC_POINTER("WSAID_WSARECVMSG", "WSARecvMsg"));*/
}
579 
/// Module destructor: calls `WSACleanup`, the counterpart of the
/// `WSAStartup` call made by the module constructor.
shared static ~this()
{
    WSACleanup();
}
584 
585 private
586 {
587     bool GetFunctionPointer(FuncPointer)(SOCKET sock, ref FuncPointer pfn, ref GUID guid)
588     {
589         DWORD dwBytesReturned = 0;
590         if (WSAIoctl(sock, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, guid.sizeof,
591                 &pfn, pfn.sizeof, &dwBytesReturned, null, null) == SOCKET_ERROR)
592         {
593             error("Get function failed with error:", GetLastError());
594             return false;
595         }
596 
597         return true;
598     }
599 
600     string GET_FUNC_POINTER(string GuidValue, string pft)
601     {
602         string str = " guid = " ~ GuidValue ~ ";";
603         str ~= "if( !GetFunctionPointer( ListenSocket, " ~ pft
604             ~ ", guid ) ) { errnoEnforce(false,\"get function error!\"); } ";
605         return str;
606     }
607 }
608 
/// Bit masks from which the Winsock ioctl codes below are composed.
enum : DWORD
{
    IOCPARAM_MASK = 0x7f,
    IOC_VOID = 0x2000_0000,
    IOC_OUT = 0x4000_0000,
    IOC_IN = 0x8000_0000,
    IOC_INOUT = IOC_IN | IOC_OUT
}

/// Values passed as the first argument of the `_WSAIO*` templates.
enum
{
    IOC_UNIX = 0x0000_0000,
    IOC_WS2 = 0x0800_0000,
    IOC_PROTOCOL = 0x1000_0000,
    IOC_VENDOR = 0x1800_0000
}

/// Combines `IOC_VOID` with `x` and `y`.
enum _WSAIO(int x, int y) = IOC_VOID | x | y;

/// Combines `IOC_OUT` with `x` and `y`.
enum _WSAIOR(int x, int y) = IOC_OUT | x | y;

/// Combines `IOC_IN` with `x` and `y`.
enum _WSAIOW(int x, int y) = IOC_IN | x | y;

/// Combines `IOC_INOUT` with `x` and `y`.
enum _WSAIORW(int x, int y) = IOC_INOUT | x | y;

/// Winsock 2 ioctl codes built from the templates above.
enum
{
    SIO_ASSOCIATE_HANDLE = _WSAIOW!(IOC_WS2, 1),
    SIO_ENABLE_CIRCULAR_QUEUEING = _WSAIO!(IOC_WS2, 2),
    SIO_FIND_ROUTE = _WSAIOR!(IOC_WS2, 3),
    SIO_FLUSH = _WSAIO!(IOC_WS2, 4),
    SIO_GET_BROADCAST_ADDRESS = _WSAIOR!(IOC_WS2, 5),
    SIO_GET_EXTENSION_FUNCTION_POINTER = _WSAIORW!(IOC_WS2, 6),
    SIO_GET_QOS = _WSAIORW!(IOC_WS2, 7),
    SIO_GET_GROUP_QOS = _WSAIORW!(IOC_WS2, 8),
    SIO_MULTIPOINT_LOOPBACK = _WSAIOW!(IOC_WS2, 9),
    SIO_MULTICAST_SCOPE = _WSAIOW!(IOC_WS2, 10),
    SIO_SET_QOS = _WSAIOW!(IOC_WS2, 11),
    SIO_SET_GROUP_QOS = _WSAIOW!(IOC_WS2, 12),
    SIO_TRANSLATE_HANDLE = _WSAIORW!(IOC_WS2, 13),
    SIO_ROUTING_INTERFACE_QUERY = _WSAIORW!(IOC_WS2, 20),
    SIO_ROUTING_INTERFACE_CHANGE = _WSAIOW!(IOC_WS2, 21),
    SIO_ADDRESS_LIST_QUERY = _WSAIOR!(IOC_WS2, 22),
    SIO_ADDRESS_LIST_CHANGE = _WSAIO!(IOC_WS2, 23),
    SIO_QUERY_TARGET_PNP_HANDLE = _WSAIOR!(IOC_WS2, 24),
    SIO_NSP_NOTIFY_CHANGE = _WSAIOW!(IOC_WS2, 25)
}
662 
// Prototypes of the Winsock 2 send/receive entry points used by this module.
// The colon form is kept so the attributes continue to apply to whatever
// declarations follow.
extern (Windows):
nothrow:
int WSARecv(SOCKET s, LPWSABUF lpBuffers, DWORD dwBufferCount,
        LPDWORD lpNumberOfBytesRecvd, LPDWORD lpFlags, LPWSAOVERLAPPED lpOverlapped,
        LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine);
int WSARecvDisconnect(SOCKET s, LPWSABUF lpInboundDisconnectData);
int WSARecvFrom(SOCKET s, LPWSABUF lpBuffers, DWORD dwBufferCount,
        LPDWORD lpNumberOfBytesRecvd, LPDWORD lpFlags, SOCKADDR* lpFrom, LPINT lpFromlen,
        LPWSAOVERLAPPED lpOverlapped, LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine);

int WSASend(SOCKET s, LPWSABUF lpBuffers, DWORD dwBufferCount,
        LPDWORD lpNumberOfBytesSent, DWORD dwFlags, LPWSAOVERLAPPED lpOverlapped,
        LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine);
int WSASendDisconnect(SOCKET s, LPWSABUF lpOutboundDisconnectData);
int WSASendTo(SOCKET s, LPWSABUF lpBuffers, DWORD dwBufferCount,
        LPDWORD lpNumberOfBytesSent, DWORD dwFlags, const(SOCKADDR)* lpTo, int iTolen,
        LPWSAOVERLAPPED lpOverlapped, LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine);