Sfoglia il codice sorgente

add socket.setblocking method

dreamcmi 3 anni fa
parent
commit
8a21ba98d1

+ 19 - 4
package/socket/_socket.c

@@ -1,5 +1,5 @@
-#include "_socket_socket.h"
 #include "PikaPlatform_socket.h"
+#include "_socket_socket.h"
 
 #if !PIKASCRIPT_VERSION_REQUIRE_MINIMUN(1, 12, 0)
 #error "This library requires PikaScript version 1.12.0 or higher"
@@ -17,6 +17,7 @@ void _socket_socket__init(PikaObj* self) {
         return;
     }
     obj_setInt(self, "sockfd", sockfd);
+    obj_setInt(self, "blocking", 1);
 }
 
 void _socket_socket__close(PikaObj* self) {
@@ -71,9 +72,11 @@ Arg* _socket_socket__recv(PikaObj* self, int num) {
     data_recv = arg_getBytes(res);
     ret = __platform_recv(sockfd, data_recv, num, 0);
     if (ret < 0) {
-        obj_setErrorCode(self, PIKA_RES_ERR_RUNTIME_ERROR);
-        __platform_printf("recv error\n");
-        return NULL;
+        if (obj_getInt(self, "blocking")) {
+            obj_setErrorCode(self, PIKA_RES_ERR_RUNTIME_ERROR);
+            __platform_printf("recv error\n");
+            return NULL;
+        }
     }
     return res;
 }
@@ -91,6 +94,14 @@ void _socket_socket__connect(PikaObj* self, char* host, int port) {
     server_addr.sin_addr.s_addr = inet_addr(host);
     __platform_connect(sockfd, (struct sockaddr*)&server_addr,
                        sizeof(server_addr));
+    if (obj_getInt(self, "blocking") == 0) {
+        int flags = fcntl(sockfd, F_GETFL);
+        if (fcntl(sockfd, F_SETFL, flags | O_NONBLOCK) == -1) {
+            obj_setErrorCode(self, PIKA_RES_ERR_RUNTIME_ERROR);
+            __platform_printf("Unable to set socket non blocking\n");
+            return;
+        }
+    }
 }
 
 void _socket_socket__bind(PikaObj* self, char* host, int port) {
@@ -114,3 +125,7 @@ char* _socket__gethostname(PikaObj* self) {
     __platform_gethostname(hostname_buff, 128);
     return obj_cacheStr(self, hostname);
 }
+
+void _socket_socket__setblocking(PikaObj* self, int sta) {
+    obj_setInt(self, "blocking", sta);
+}

+ 1 - 0
package/socket/_socket.pyi

@@ -10,6 +10,7 @@ class socket:
     def _connect(host: str, port: int): ...
     def _recv(num: int) -> bytes: ...
     def _init(): ...
+    def _setblocking(sta: int): ...
 
 
 def _gethostname() -> str: ...

+ 3 - 0
package/socket/socket.py

@@ -50,5 +50,8 @@ class socket(_socket.socket):
     def recv(self, num):
         return self._recv(num)
 
+    def setblocking(self, sta): 
+        return self._setblocking(sta)
+
 def gethostname():
     return _socket._gethostname()

+ 1 - 0
port/linux/package/pikascript/_socket.pyi

@@ -10,6 +10,7 @@ class socket:
     def _connect(host: str, port: int): ...
     def _recv(num: int) -> bytes: ...
     def _init(): ...
+    def _setblocking(sta: int): ...
 
 
 def _gethostname() -> str: ...

+ 19 - 4
port/linux/package/pikascript/pikascript-lib/socket/_socket.c

@@ -1,5 +1,5 @@
-#include "_socket_socket.h"
 #include "PikaPlatform_socket.h"
+#include "_socket_socket.h"
 
 #if !PIKASCRIPT_VERSION_REQUIRE_MINIMUN(1, 12, 0)
 #error "This library requires PikaScript version 1.12.0 or higher"
@@ -17,6 +17,7 @@ void _socket_socket__init(PikaObj* self) {
         return;
     }
     obj_setInt(self, "sockfd", sockfd);
+    obj_setInt(self, "blocking", 1);
 }
 
 void _socket_socket__close(PikaObj* self) {
@@ -71,9 +72,11 @@ Arg* _socket_socket__recv(PikaObj* self, int num) {
     data_recv = arg_getBytes(res);
     ret = __platform_recv(sockfd, data_recv, num, 0);
     if (ret < 0) {
-        obj_setErrorCode(self, PIKA_RES_ERR_RUNTIME_ERROR);
-        __platform_printf("recv error\n");
-        return NULL;
+        if (obj_getInt(self, "blocking")) {
+            obj_setErrorCode(self, PIKA_RES_ERR_RUNTIME_ERROR);
+            __platform_printf("recv error\n");
+            return NULL;
+        }
     }
     return res;
 }
@@ -91,6 +94,14 @@ void _socket_socket__connect(PikaObj* self, char* host, int port) {
     server_addr.sin_addr.s_addr = inet_addr(host);
     __platform_connect(sockfd, (struct sockaddr*)&server_addr,
                        sizeof(server_addr));
+    if (obj_getInt(self, "blocking") == 0) {
+        int flags = fcntl(sockfd, F_GETFL);
+        if (fcntl(sockfd, F_SETFL, flags | O_NONBLOCK) == -1) {
+            obj_setErrorCode(self, PIKA_RES_ERR_RUNTIME_ERROR);
+            __platform_printf("Unable to set socket non blocking\n");
+            return;
+        }
+    }
 }
 
 void _socket_socket__bind(PikaObj* self, char* host, int port) {
@@ -114,3 +125,7 @@ char* _socket__gethostname(PikaObj* self) {
     __platform_gethostname(hostname_buff, 128);
     return obj_cacheStr(self, hostname);
 }
+
+void _socket_socket__setblocking(PikaObj* self, int sta) {
+    obj_setInt(self, "blocking", sta);
+}

+ 3 - 0
port/linux/package/pikascript/socket.py

@@ -50,5 +50,8 @@ class socket(_socket.socket):
     def recv(self, num):
         return self._recv(num)
 
+    def setblocking(self, sta): 
+        return self._setblocking(sta)
+
 def gethostname():
     return _socket._gethostname()