sgx_socket.c 5.5 KB


  1. /*
  2. * Copyright (C) 2019 Intel Corporation. All rights reserved.
  3. * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. */
  5. #include "platform_api_vmcore.h"
  6. #ifndef SGX_DISABLE_WASI
  /* Report that an ocall transport call did not succeed; __FUNCTION__ names
     the wrapper that issued it. */
  #define TRACE_OCALL_FAIL() os_printf("ocall %s failed!\n", __FUNCTION__)

  /*
   * Ocalls into the untrusted host used by the socket wrappers in this file.
   * The wrappers compare each call's return value against SGX_SUCCESS and
   * read the host function's own result through the first (p_ret) argument.
   */
  extern int ocall_socket(int *p_ret, int domain, int type, int protocol);

  extern int ocall_getsockopt(int *p_ret, int sockfd, int level, int optname,
                              void *val_buf, unsigned int val_buf_size,
                              void *len_buf);

  extern int ocall_sendmsg(ssize_t *p_ret, int sockfd, void *msg_buf,
                           unsigned int msg_buf_size, int flags);

  extern int ocall_recvmsg(ssize_t *p_ret, int sockfd, void *msg_buf,
                           unsigned int msg_buf_size, int flags);

  extern int ocall_shutdown(int *p_ret, int sockfd, int how);
  17. int socket(int domain, int type, int protocol)
  18. {
  19. int ret;
  20. if (ocall_socket(&ret, domain, type, protocol) != SGX_SUCCESS) {
  21. TRACE_OCALL_FAIL();
  22. return -1;
  23. }
  24. if (ret == -1)
  25. errno = get_errno();
  26. return ret;
  27. }
  28. int getsockopt(int sockfd, int level, int optname,
  29. void *optval, socklen_t *optlen)
  30. {
  31. int ret;
  32. unsigned int val_buf_size = *optlen;
  33. if (ocall_getsockopt(&ret, sockfd, level, optname, optval,
  34. val_buf_size, (void *)optlen) != SGX_SUCCESS) {
  35. TRACE_OCALL_FAIL();
  36. return -1;
  37. }
  38. if (ret == -1)
  39. errno = get_errno();
  40. return ret;
  41. }
  42. ssize_t sendmsg(int sockfd, const struct msghdr *msg, int flags)
  43. {
  44. ssize_t ret;
  45. int i;
  46. char *p;
  47. struct msghdr *msg1;
  48. uint64 total_size = sizeof(struct msghdr) + (uint64)msg->msg_namelen
  49. + (uint64)msg->msg_controllen;
  50. total_size += sizeof(struct iovec) * (msg->msg_iovlen);
  51. for (i = 0; i < msg->msg_iovlen; i++) {
  52. total_size += msg->msg_iov[i].iov_len;
  53. }
  54. if (total_size >= UINT32_MAX)
  55. return -1;
  56. msg1 = BH_MALLOC((uint32)total_size);
  57. if (msg1 == NULL)
  58. return -1;
  59. p = (char*)(uintptr_t)sizeof(struct msghdr);
  60. if (msg->msg_name != NULL) {
  61. msg1->msg_name = p;
  62. memcpy((uintptr_t)p + (char *)msg1, msg->msg_name,
  63. (size_t)msg->msg_namelen);
  64. p += msg->msg_namelen;
  65. }
  66. if (msg->msg_control != NULL) {
  67. msg1->msg_control = p;
  68. memcpy((uintptr_t)p + (char *)msg1, msg->msg_control,
  69. (size_t)msg->msg_control);
  70. p += msg->msg_controllen;
  71. }
  72. if (msg->msg_iov != NULL) {
  73. msg1->msg_iov = (struct iovec *)p;
  74. p += (uintptr_t)(sizeof(struct iovec) * (msg->msg_iovlen));
  75. for (i = 0; i < msg->msg_iovlen; i++) {
  76. msg1->msg_iov[i].iov_base = p;
  77. msg1->msg_iov[i].iov_len = msg->msg_iov[i].iov_len;
  78. memcpy((uintptr_t)p + (char *)msg1, msg->msg_iov[i].iov_base,
  79. (size_t)(msg->msg_iov[i].iov_len));
  80. p += msg->msg_iov[i].iov_len;
  81. }
  82. }
  83. if (ocall_sendmsg(&ret, sockfd, (void *)msg1, (uint32)total_size,
  84. flags) != SGX_SUCCESS) {
  85. TRACE_OCALL_FAIL();
  86. return -1;
  87. }
  88. if (ret == -1)
  89. errno = get_errno();
  90. return ret;
  91. }
  92. ssize_t recvmsg(int sockfd, struct msghdr *msg, int flags)
  93. {
  94. ssize_t ret;
  95. int i;
  96. char *p;
  97. struct msghdr *msg1;
  98. uint64 total_size = sizeof(struct msghdr) + (uint64)msg->msg_namelen
  99. + (uint64)msg->msg_controllen;
  100. total_size += sizeof(struct iovec) * (msg->msg_iovlen);
  101. for (i = 0; i < msg->msg_iovlen; i++) {
  102. total_size += msg->msg_iov[i].iov_len;
  103. }
  104. if (total_size >= UINT32_MAX)
  105. return -1;
  106. msg1 = BH_MALLOC((uint32)total_size);
  107. if (msg1 == NULL)
  108. return -1;
  109. memset(msg1, 0, total_size);
  110. p = (char*)(uintptr_t)sizeof(struct msghdr);
  111. if (msg->msg_name != NULL) {
  112. msg1->msg_name = p;
  113. p += msg->msg_namelen;
  114. }
  115. if (msg->msg_control != NULL) {
  116. msg1->msg_control = p;
  117. p += msg->msg_controllen;
  118. }
  119. if (msg->msg_iov != NULL) {
  120. msg1->msg_iov = (struct iovec *)p;
  121. p += (uintptr_t)(sizeof(struct iovec) * (msg->msg_iovlen));
  122. for (i = 0; i < msg->msg_iovlen; i++) {
  123. msg1->msg_iov[i].iov_base = p;
  124. msg1->msg_iov[i].iov_len = msg->msg_iov[i].iov_len;
  125. p += msg->msg_iov[i].iov_len;
  126. }
  127. }
  128. if (ocall_recvmsg(&ret, sockfd, (void *)msg1, (uint32)total_size,
  129. flags) != SGX_SUCCESS) {
  130. TRACE_OCALL_FAIL();
  131. return -1;
  132. }
  133. p = (char *)(uintptr_t)(sizeof(struct msghdr));
  134. if (msg1->msg_name != NULL) {
  135. memcpy(msg->msg_name, (uintptr_t)p + (char *)msg1,
  136. (size_t)msg1->msg_namelen);
  137. p += msg1->msg_namelen;
  138. }
  139. if (msg1->msg_control != NULL) {
  140. memcpy(msg->msg_control, (uintptr_t)p + (char *)msg1,
  141. (size_t)msg1->msg_control);
  142. p += msg->msg_controllen;
  143. }
  144. if (msg1->msg_iov != NULL) {
  145. p += (uintptr_t)(sizeof(struct iovec) * (msg1->msg_iovlen));
  146. for (i = 0; i < msg1->msg_iovlen; i++) {
  147. memcpy(msg->msg_iov[i].iov_base, (uintptr_t)p + (char *)msg1,
  148. (size_t)(msg1->msg_iov[i].iov_len));
  149. p += msg1->msg_iov[i].iov_len;
  150. }
  151. }
  152. if (ret == -1)
  153. errno = get_errno();
  154. return ret;
  155. }
  156. int shutdown(int sockfd, int how)
  157. {
  158. int ret;
  159. if (ocall_shutdown(&ret, sockfd, how) != SGX_SUCCESS) {
  160. TRACE_OCALL_FAIL();
  161. return -1;
  162. }
  163. if (ret == -1)
  164. errno = get_errno();
  165. return ret;
  166. }
  167. #endif