Parcourir la source

linux-sgx: Implement socket API getpeername, recvfrom and sendto (#1556)

Implement some of the popular socket APIs left unimplemented for SGX,
following the merge of dev/socket.
Jämes Ménétrey il y a 3 ans
Parent
commit
e2a3f0f387

+ 14 - 5
core/shared/platform/common/posix/posix_socket.c

@@ -261,9 +261,12 @@ os_socket_recv_from(bh_socket_t socket, void *buf, unsigned int len, int flags,
         return ret;
     }
 
-    if (src_addr) {
-        sockaddr_to_bh_sockaddr((struct sockaddr *)&sock_addr, socklen,
-                                src_addr);
+    if (src_addr && socklen > 0) {
+        if (sockaddr_to_bh_sockaddr((struct sockaddr *)&sock_addr, socklen,
+                                    src_addr)
+            == BHT_ERROR) {
+            return -1;
+        }
     }
 
     return ret;
@@ -391,8 +394,14 @@ os_socket_addr_resolve(const char *host, const char *service,
                 continue;
             }
 
-            sockaddr_to_bh_sockaddr(res->ai_addr, sizeof(struct sockaddr_in),
-                                    &addr_info[pos].sockaddr);
+            ret = sockaddr_to_bh_sockaddr(res->ai_addr,
+                                          sizeof(struct sockaddr_in),
+                                          &addr_info[pos].sockaddr);
+
+            if (ret == BHT_ERROR) {
+                freeaddrinfo(result);
+                return BHT_ERROR;
+            }
 
             addr_info[pos].is_tcp = res->ai_socktype == SOCK_STREAM;
         }

+ 116 - 6
core/shared/platform/linux-sgx/sgx_socket.c

@@ -31,6 +31,10 @@ int
 ocall_getsockname(int *p_ret, int sockfd, void *addr, uint32_t *addrlen,
                   uint32_t addr_size);
 
+int
+ocall_getpeername(int *p_ret, int sockfd, void *addr, uint32_t *addrlen,
+                  uint32_t addr_size);
+
 int
 ocall_getsockopt(int *p_ret, int sockfd, int level, int optname, void *val_buf,
                  unsigned int val_buf_size, void *len_buf);
@@ -41,6 +45,10 @@ ocall_listen(int *p_ret, int sockfd, int backlog);
 int
 ocall_recv(int *p_ret, int sockfd, void *buf, size_t len, int flags);
 
+int
+ocall_recvfrom(ssize_t *p_ret, int sockfd, void *buf, size_t len, int flags,
+               void *src_addr, uint32_t *addrlen, uint32_t addr_size);
+
 int
 ocall_recvmsg(ssize_t *p_ret, int sockfd, void *msg_buf,
               unsigned int msg_buf_size, int flags);
@@ -48,6 +56,10 @@ ocall_recvmsg(ssize_t *p_ret, int sockfd, void *msg_buf,
 int
 ocall_send(int *p_ret, int sockfd, const void *buf, size_t len, int flags);
 
+int
+ocall_sendto(ssize_t *p_ret, int sockfd, const void *buf, size_t len, int flags,
+             void *dest_addr, uint32_t addrlen);
+
 int
 ocall_sendmsg(ssize_t *p_ret, int sockfd, void *msg_buf,
               unsigned int msg_buf_size, int flags);
@@ -237,6 +249,46 @@ textual_addr_to_sockaddr(const char *textual, int port, struct sockaddr_in *out)
     return BHT_OK;
 }
 
+static int
+sockaddr_to_bh_sockaddr(const struct sockaddr *sockaddr, socklen_t socklen,
+                        bh_sockaddr_t *bh_sockaddr)
+{
+    switch (sockaddr->sa_family) {
+        case AF_INET:
+        {
+            struct sockaddr_in *addr = (struct sockaddr_in *)sockaddr;
+
+            assert(socklen >= sizeof(struct sockaddr_in));
+
+            bh_sockaddr->port = ntohs(addr->sin_port);
+            bh_sockaddr->addr_bufer.ipv4 = ntohl(addr->sin_addr.s_addr);
+            bh_sockaddr->is_ipv4 = true;
+            return BHT_OK;
+        }
+        default:
+            errno = EAFNOSUPPORT;
+            return BHT_ERROR;
+    }
+}
+
+static int
+bh_sockaddr_to_sockaddr(const bh_sockaddr_t *bh_sockaddr,
+                        struct sockaddr *sockaddr, socklen_t *socklen)
+{
+    if (bh_sockaddr->is_ipv4) {
+        struct sockaddr_in *addr = (struct sockaddr_in *)sockaddr;
+        addr->sin_port = htons(bh_sockaddr->port);
+        addr->sin_family = AF_INET;
+        addr->sin_addr.s_addr = htonl(bh_sockaddr->addr_bufer.ipv4);
+        *socklen = sizeof(*addr);
+        return BHT_OK;
+    }
+    else {
+        errno = EAFNOSUPPORT;
+        return BHT_ERROR;
+    }
+}
+
 int
 socket(int domain, int type, int protocol)
 {
@@ -651,6 +703,7 @@ os_socket_recv(bh_socket_t socket, void *buf, unsigned int len)
     int ret;
 
     if (ocall_recv(&ret, socket, buf, len, 0) != SGX_SUCCESS) {
+        TRACE_OCALL_FAIL();
         errno = ENOSYS;
         return -1;
     }
@@ -665,9 +718,32 @@ int
 os_socket_recv_from(bh_socket_t socket, void *buf, unsigned int len, int flags,
                     bh_sockaddr_t *src_addr)
 {
-    errno = ENOSYS;
+    struct sockaddr_in addr;
+    socklen_t addr_len = sizeof(addr);
+    ssize_t ret;
 
-    return BHT_ERROR;
+    if (ocall_recvfrom(&ret, socket, buf, len, flags, &addr, &addr_len,
+                       addr_len)
+        != SGX_SUCCESS) {
+        TRACE_OCALL_FAIL();
+        errno = ENOSYS;
+        return -1;
+    }
+
+    if (ret < 0) {
+        errno = get_errno();
+        return ret;
+    }
+
+    if (src_addr && addr_len > 0) {
+        if (sockaddr_to_bh_sockaddr((struct sockaddr *)&addr, addr_len,
+                                    src_addr)
+            == BHT_ERROR) {
+            return -1;
+        }
+    }
+
+    return ret;
 }
 
 int
@@ -676,6 +752,7 @@ os_socket_send(bh_socket_t socket, const void *buf, unsigned int len)
     int ret;
 
     if (ocall_send(&ret, socket, buf, len, 0) != SGX_SUCCESS) {
+        TRACE_OCALL_FAIL();
         errno = ENOSYS;
         return -1;
     }
@@ -690,9 +767,28 @@ int
 os_socket_send_to(bh_socket_t socket, const void *buf, unsigned int len,
                   int flags, const bh_sockaddr_t *dest_addr)
 {
-    errno = ENOSYS;
+    struct sockaddr_in addr;
+    socklen_t addr_len;
+    ssize_t ret;
 
-    return BHT_ERROR;
+    if (bh_sockaddr_to_sockaddr(dest_addr, (struct sockaddr *)&addr, &addr_len)
+        == BHT_ERROR) {
+        return -1;
+    }
+
+    if (ocall_sendto(&ret, socket, buf, len, flags, (struct sockaddr *)&addr,
+                     addr_len)
+        != SGX_SUCCESS) {
+        TRACE_OCALL_FAIL();
+        errno = ENOSYS;
+        return -1;
+    }
+
+    if (ret == -1) {
+        errno = get_errno();
+    }
+
+    return ret;
 }
 
 int
@@ -723,9 +819,23 @@ os_socket_addr_local(bh_socket_t socket, bh_sockaddr_t *sockaddr)
 int
 os_socket_addr_remote(bh_socket_t socket, bh_sockaddr_t *sockaddr)
 {
-    errno = ENOSYS;
+    struct sockaddr_in addr;
+    socklen_t addr_len = sizeof(addr);
+    int ret;
 
-    return BHT_ERROR;
+    if (ocall_getpeername(&ret, socket, (void *)&addr, &addr_len, addr_len)
+        != SGX_SUCCESS) {
+        TRACE_OCALL_FAIL();
+        return -1;
+    }
+
+    if (ret != BHT_OK) {
+        errno = get_errno();
+        return BHT_ERROR;
+    }
+
+    return sockaddr_to_bh_sockaddr((struct sockaddr *)&addr, addr_len,
+                                   sockaddr);
 }
 
 int

+ 7 - 0
core/shared/platform/linux-sgx/sgx_wamr.edl

@@ -124,17 +124,24 @@ enclave {
         int ocall_connect(int sockfd, [in, size=addrlen]void *addr, uint32_t addrlen);
         int ocall_getsockname(int sockfd, [out, size=addr_size]void *addr,
                               [in, out, size=4]uint32_t *addrlen, uint32_t addr_size);
+        int ocall_getpeername(int sockfd, [out, size=addr_size]void *addr,
+                              [in, out, size=4]uint32_t *addrlen, uint32_t addr_size);
         int ocall_getsockopt(int sockfd, int level, int optname,
                              [out, size=val_buf_size]void *val_buf,
                              unsigned int val_buf_size,
                              [in, out, size=4]void *len_buf);
         int ocall_listen(int sockfd, int backlog);
         int ocall_recv(int sockfd, [out, size=len]void *buf, size_t len, int flags);
+        ssize_t ocall_recvfrom(int sockfd, [out, size=len]void *buf, size_t len, int flags,
+                               [out, size=addr_size]void *src_addr,
+                               [in, out, size=4]uint32_t *addrlen, uint32_t addr_size);
         ssize_t ocall_recvmsg(int sockfd,
                               [in, out, size=msg_buf_size]void *msg_buf,
                               unsigned int msg_buf_size,
                               int flags);
         int ocall_send(int sockfd, [in, size=len]const void *buf, size_t len, int flags);
+        ssize_t ocall_sendto(int sockfd, [in, size=len]const void *buf, size_t len, int flags,
+                             [in, size=addrlen]void *dest_addr, uint32_t addrlen);
         ssize_t ocall_sendmsg(int sockfd,
                               [in, size=msg_buf_size]void *msg_buf,
                               unsigned int msg_buf_size,

+ 22 - 0
core/shared/platform/linux-sgx/untrusted/socket.c

@@ -95,6 +95,12 @@ ocall_getsockname(int sockfd, void *addr, uint32_t *addrlen, uint32_t addr_size)
     return getsockname(sockfd, (struct sockaddr *)addr, addrlen);
 }
 
+int
+ocall_getpeername(int sockfd, void *addr, uint32_t *addrlen, uint32_t addr_size)
+{
+    return getpeername(sockfd, (struct sockaddr *)addr, addrlen);
+}
+
 int
 ocall_listen(int sockfd, int backlog)
 {
@@ -113,12 +119,28 @@ ocall_recv(int sockfd, void *buf, size_t len, int flags)
     return recv(sockfd, buf, len, flags);
 }
 
+ssize_t
+ocall_recvfrom(int sockfd, void *buf, size_t len, int flags, void *src_addr,
+               uint32_t *addrlen, uint32_t addr_size)
+{
+    return recvfrom(sockfd, buf, len, flags, (struct sockaddr *)src_addr,
+                    addrlen);
+}
+
 int
 ocall_send(int sockfd, const void *buf, size_t len, int flags)
 {
     return send(sockfd, buf, len, flags);
 }
 
+ssize_t
+ocall_sendto(int sockfd, const void *buf, size_t len, int flags,
+             void *dest_addr, uint32_t addrlen)
+{
+    return sendto(sockfd, buf, len, flags, (struct sockaddr *)dest_addr,
+                  addrlen);
+}
+
 int
 ocall_connect(int sockfd, void *addr, uint32_t addrlen)
 {