sgx_socket.c 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. /*
  2. * Copyright (C) 2019 Intel Corporation. All rights reserved.
  3. * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. */
  5. #include "platform_api_vmcore.h"
  /* Report a failed ocall transition, naming the enclosing function. */
  #define TRACE_OCALL_FAIL()                                 \
      do {                                                   \
          os_printf("ocall %s failed!\n", __FUNCTION__);     \
      } while (0)

  /*
   * Ocall bridges used by the wrappers in this file. Each returns a status
   * that the callers compare against SGX_SUCCESS, and writes the result of
   * the operation itself through its first parameter (the callers treat -1
   * there as failure and then consult get_errno()).
   */
  int ocall_socket(int *p_ret, int domain, int type, int protocol);

  int ocall_getsockopt(int *p_ret, int sockfd, int level, int optname,
                       void *val_buf, unsigned int val_buf_size,
                       void *len_buf);

  /* msg_buf / msg_buf_size describe a single packed struct msghdr buffer. */
  int ocall_sendmsg(ssize_t *p_ret, int sockfd, void *msg_buf,
                    unsigned int msg_buf_size, int flags);

  int ocall_recvmsg(ssize_t *p_ret, int sockfd, void *msg_buf,
                    unsigned int msg_buf_size, int flags);

  int ocall_shutdown(int *p_ret, int sockfd, int how);
  16. int socket(int domain, int type, int protocol)
  17. {
  18. int ret;
  19. if (ocall_socket(&ret, domain, type, protocol) != SGX_SUCCESS) {
  20. TRACE_OCALL_FAIL();
  21. return -1;
  22. }
  23. if (ret == -1)
  24. errno = get_errno();
  25. return ret;
  26. }
  27. int getsockopt(int sockfd, int level, int optname,
  28. void *optval, socklen_t *optlen)
  29. {
  30. int ret;
  31. unsigned int val_buf_size = *optlen;
  32. if (ocall_getsockopt(&ret, sockfd, level, optname, optval,
  33. val_buf_size, (void *)optlen) != SGX_SUCCESS) {
  34. TRACE_OCALL_FAIL();
  35. return -1;
  36. }
  37. if (ret == -1)
  38. errno = get_errno();
  39. return ret;
  40. }
  41. ssize_t sendmsg(int sockfd, const struct msghdr *msg, int flags)
  42. {
  43. ssize_t ret;
  44. int i;
  45. char *p;
  46. struct msghdr *msg1;
  47. uint64 total_size = sizeof(struct msghdr) + (uint64)msg->msg_namelen
  48. + (uint64)msg->msg_controllen;
  49. total_size += sizeof(struct iovec) * (msg->msg_iovlen);
  50. for (i = 0; i < msg->msg_iovlen; i++) {
  51. total_size += msg->msg_iov[i].iov_len;
  52. }
  53. if (total_size >= UINT32_MAX)
  54. return -1;
  55. msg1 = BH_MALLOC((uint32)total_size);
  56. if (msg1 == NULL)
  57. return -1;
  58. p = (char*)(uintptr_t)sizeof(struct msghdr);
  59. if (msg->msg_name != NULL) {
  60. msg1->msg_name = p;
  61. memcpy((uintptr_t)p + (char *)msg1, msg->msg_name,
  62. (size_t)msg->msg_namelen);
  63. p += msg->msg_namelen;
  64. }
  65. if (msg->msg_control != NULL) {
  66. msg1->msg_control = p;
  67. memcpy((uintptr_t)p + (char *)msg1, msg->msg_control,
  68. (size_t)msg->msg_control);
  69. p += msg->msg_controllen;
  70. }
  71. if (msg->msg_iov != NULL) {
  72. msg1->msg_iov = (struct iovec *)p;
  73. p += (uintptr_t)(sizeof(struct iovec) * (msg->msg_iovlen));
  74. for (i = 0; i < msg->msg_iovlen; i++) {
  75. msg1->msg_iov[i].iov_base = p;
  76. msg1->msg_iov[i].iov_len = msg->msg_iov[i].iov_len;
  77. memcpy((uintptr_t)p + (char *)msg1, msg->msg_iov[i].iov_base,
  78. (size_t)(msg->msg_iov[i].iov_len));
  79. p += msg->msg_iov[i].iov_len;
  80. }
  81. }
  82. if (ocall_sendmsg(&ret, sockfd, (void *)msg1, (uint32)total_size,
  83. flags) != SGX_SUCCESS) {
  84. TRACE_OCALL_FAIL();
  85. return -1;
  86. }
  87. if (ret == -1)
  88. errno = get_errno();
  89. return ret;
  90. }
  91. ssize_t recvmsg(int sockfd, struct msghdr *msg, int flags)
  92. {
  93. ssize_t ret;
  94. int i;
  95. char *p;
  96. struct msghdr *msg1;
  97. uint64 total_size = sizeof(struct msghdr) + (uint64)msg->msg_namelen
  98. + (uint64)msg->msg_controllen;
  99. total_size += sizeof(struct iovec) * (msg->msg_iovlen);
  100. for (i = 0; i < msg->msg_iovlen; i++) {
  101. total_size += msg->msg_iov[i].iov_len;
  102. }
  103. if (total_size >= UINT32_MAX)
  104. return -1;
  105. msg1 = BH_MALLOC((uint32)total_size);
  106. if (msg1 == NULL)
  107. return -1;
  108. memset(msg1, 0, total_size);
  109. p = (char*)(uintptr_t)sizeof(struct msghdr);
  110. if (msg->msg_name != NULL) {
  111. msg1->msg_name = p;
  112. p += msg->msg_namelen;
  113. }
  114. if (msg->msg_control != NULL) {
  115. msg1->msg_control = p;
  116. p += msg->msg_controllen;
  117. }
  118. if (msg->msg_iov != NULL) {
  119. msg1->msg_iov = (struct iovec *)p;
  120. p += (uintptr_t)(sizeof(struct iovec) * (msg->msg_iovlen));
  121. for (i = 0; i < msg->msg_iovlen; i++) {
  122. msg1->msg_iov[i].iov_base = p;
  123. msg1->msg_iov[i].iov_len = msg->msg_iov[i].iov_len;
  124. p += msg->msg_iov[i].iov_len;
  125. }
  126. }
  127. if (ocall_recvmsg(&ret, sockfd, (void *)msg1, (uint32)total_size,
  128. flags) != SGX_SUCCESS) {
  129. TRACE_OCALL_FAIL();
  130. return -1;
  131. }
  132. p = (char *)(uintptr_t)(sizeof(struct msghdr));
  133. if (msg1->msg_name != NULL) {
  134. memcpy(msg->msg_name, (uintptr_t)p + (char *)msg1,
  135. (size_t)msg1->msg_namelen);
  136. p += msg1->msg_namelen;
  137. }
  138. if (msg1->msg_control != NULL) {
  139. memcpy(msg->msg_control, (uintptr_t)p + (char *)msg1,
  140. (size_t)msg1->msg_control);
  141. p += msg->msg_controllen;
  142. }
  143. if (msg1->msg_iov != NULL) {
  144. p += (uintptr_t)(sizeof(struct iovec) * (msg1->msg_iovlen));
  145. for (i = 0; i < msg1->msg_iovlen; i++) {
  146. memcpy(msg->msg_iov[i].iov_base, (uintptr_t)p + (char *)msg1,
  147. (size_t)(msg1->msg_iov[i].iov_len));
  148. p += msg1->msg_iov[i].iov_len;
  149. }
  150. }
  151. if (ret == -1)
  152. errno = get_errno();
  153. return ret;
  154. }
  155. int shutdown(int sockfd, int how)
  156. {
  157. int ret;
  158. if (ocall_shutdown(&ret, sockfd, how) != SGX_SUCCESS) {
  159. TRACE_OCALL_FAIL();
  160. return -1;
  161. }
  162. if (ret == -1)
  163. errno = get_errno();
  164. return ret;
  165. }