HAL_TLS_mbedtls.c 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. /*
  2. * Copyright (C) 2012-2019 UCloud. All Rights Reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License").
  5. * You may not use this file except in compliance with the License.
  6. * A copy of the License is located at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * or in the "license" file accompanying this file. This file is distributed
  11. * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
  12. * express or implied. See the License for the specific language governing
  13. * permissions and limitations under the License.
  14. */
  15. #include <stdint.h>
  16. #include <string.h>
  17. #include <errno.h>
  18. #ifdef __cplusplus
  19. extern "C" {
  20. #endif
  21. #include "uiot_import.h"
  22. #include "uiot_defs.h"
  23. #include "HAL_Timer_Platform.h"
  24. #include "mbedtls/ssl.h"
  25. #include "mbedtls/entropy.h"
  26. #include "mbedtls/net_sockets.h"
  27. #include "mbedtls/ctr_drbg.h"
  28. #include "mbedtls/error.h"
/**
 * @brief Per-connection TLS state: the socket plus every mbedtls object
 *        needed for one SSL/TLS client connection.
 *
 * HAL_TLS_Connect() returns a pointer to this structure, cast to uintptr_t,
 * as the opaque connection handle.
 */
typedef struct {
    mbedtls_net_context socket_fd;      // socket file descriptor
    mbedtls_entropy_context entropy;    // entropy configuration (seeds ctr_drbg)
    mbedtls_ctr_drbg_context ctr_drbg;  // random number generator
    mbedtls_ssl_context ssl;            // basic SSL session data
    mbedtls_ssl_config ssl_conf;        // SSL/TLS configuration
    mbedtls_x509_crt ca_cert;           // CA certificate(s) used to verify the server
    mbedtls_x509_crt client_cert;       // client certificate; NOTE(review): initialised and freed but never loaded in this file
    mbedtls_pk_context private_key;     // client private key; NOTE(review): initialised and freed but never loaded in this file
} TLSDataParams;
  42. /**
  43. * @brief 释放mbedtls开辟的内存
  44. */
  45. static void _free_mbedtls(TLSDataParams *pParams) {
  46. mbedtls_net_free(&(pParams->socket_fd));
  47. mbedtls_x509_crt_free(&(pParams->client_cert));
  48. mbedtls_x509_crt_free(&(pParams->ca_cert));
  49. mbedtls_pk_free(&(pParams->private_key));
  50. mbedtls_ssl_free(&(pParams->ssl));
  51. mbedtls_ssl_config_free(&(pParams->ssl_conf));
  52. mbedtls_ctr_drbg_free(&(pParams->ctr_drbg));
  53. mbedtls_entropy_free(&(pParams->entropy));
  54. HAL_Free(pParams);
  55. }
  56. /**
  57. * @brief mbedtls库初始化
  58. *
  59. * 1. 执行mbedtls库相关初始化函数
  60. * 2. 随机数生成器
  61. * 3. 加载CA证书
  62. *
  63. * @param pDataParams TLS连接相关数据结构
  64. * @param pConnectParams TLS证书密钥相关
  65. * @return 返回SUCCESS, 表示成功
  66. */
  67. static int _mbedtls_client_init(TLSDataParams *pDataParams, const char *ca_crt, size_t ca_crt_len) {
  68. int ret;
  69. mbedtls_net_init(&(pDataParams->socket_fd));
  70. mbedtls_ssl_init(&(pDataParams->ssl));
  71. mbedtls_ssl_config_init(&(pDataParams->ssl_conf));
  72. mbedtls_ctr_drbg_init(&(pDataParams->ctr_drbg));
  73. mbedtls_x509_crt_init(&(pDataParams->ca_cert));
  74. mbedtls_x509_crt_init(&(pDataParams->client_cert));
  75. mbedtls_pk_init(&(pDataParams->private_key));
  76. LOG_DEBUG("Seeding the random number generator...");
  77. mbedtls_entropy_init(&(pDataParams->entropy));
  78. // 随机数, 增加custom参数, 目前为NULL
  79. if ((ret = mbedtls_ctr_drbg_seed(&(pDataParams->ctr_drbg), mbedtls_entropy_func,
  80. &(pDataParams->entropy), NULL, 0)) != 0) {
  81. LOG_ERROR("failed! mbedtls_ctr_drbg_seed returned -0x%x\n", -ret);
  82. return ERR_SSL_INIT_FAILED;
  83. }
  84. LOG_DEBUG("Loading the CA root certificate ...");
  85. if (ca_crt != NULL) {
  86. if ((ret = mbedtls_x509_crt_parse(&(pDataParams->ca_cert), (const unsigned char *) ca_crt,
  87. (ca_crt_len + 1 )))) {
  88. LOG_ERROR("failed! mbedtls_x509_crt_parse returned -0x%x while parsing root cert\n", -ret);
  89. return ERR_SSL_CERT_FAILED;
  90. }
  91. }
  92. return SUCCESS_RET;
  93. }
  94. /**
  95. * @brief 建立TCP连接
  96. *
  97. * @param socket_fd Socket描述符
  98. * @param host 服务器主机名
  99. * @param port 服务器端口地址
  100. * @return 返回SUCCESS, 表示成功
  101. */
  102. int _mbedtls_tcp_connect(mbedtls_net_context *socket_fd, const char *host, uint16_t port) {
  103. int ret = 0;
  104. char port_str[6];
  105. HAL_Snprintf(port_str, 6, "%d", port);
  106. if ((ret = mbedtls_net_connect(socket_fd, host, port_str, MBEDTLS_NET_PROTO_TCP)) != 0) {
  107. LOG_ERROR("failed! mbedtls_net_connect returned -0x%x\n", -ret);
  108. switch (ret) {
  109. case MBEDTLS_ERR_NET_SOCKET_FAILED:
  110. return ERR_TCP_SOCKET_FAILED;
  111. case MBEDTLS_ERR_NET_UNKNOWN_HOST:
  112. return ERR_TCP_UNKNOWN_HOST;
  113. default:
  114. return ERR_TCP_CONNECT_FAILED;
  115. }
  116. }
  117. #if 0
  118. if ((ret = mbedtls_net_set_block(socket_fd)) != 0) {
  119. LOG_ERROR("failed! net_set_(non)block() returned -0x%x\n", -ret);
  120. return ERR_TCP_CONNECT_FAILED;
  121. }
  122. #endif
  123. return SUCCESS_RET;
  124. }
  125. /**
  126. * @brief 在该函数中可对服务端证书进行自定义的校验
  127. *
  128. * 这种行为发生在握手过程中, 一般是校验连接服务器的主机名与服务器证书中的CN或SAN的域名信息是否一致
  129. * 不过, mbedtls库已经实现该功能, 可以参考函数 `mbedtls_x509_crt_verify_with_profile`
  130. *
  131. * @param hostname 连接服务器的主机名
  132. * @param crt x509格式的证书
  133. * @param depth
  134. * @param flags
  135. * @return
  136. */
  137. int _server_certificate_verify(void *hostname, mbedtls_x509_crt *crt, int depth, uint32_t *flags) {
  138. return *flags;
  139. }
  140. uintptr_t HAL_TLS_Connect(_IN_ const char *host, _IN_ uint16_t port, _IN_ uint16_t authmode, _IN_ const char *ca_crt,
  141. _IN_ size_t ca_crt_len) {
  142. int ret = 0;
  143. TLSDataParams *pDataParams = (TLSDataParams *) HAL_Malloc(sizeof(TLSDataParams));
  144. if ((ret = _mbedtls_client_init(pDataParams, ca_crt, ca_crt_len)) != SUCCESS_RET) {
  145. goto error;
  146. }
  147. LOG_INFO("Connecting to /%s/%d...", host, port);
  148. if ((ret = _mbedtls_tcp_connect(&(pDataParams->socket_fd), host, port)) != SUCCESS_RET) {
  149. goto error;
  150. }
  151. LOG_DEBUG("Setting up the SSL/TLS structure...");
  152. if ((ret = mbedtls_ssl_config_defaults(&(pDataParams->ssl_conf), MBEDTLS_SSL_IS_CLIENT,
  153. MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0) {
  154. LOG_ERROR("failed! mbedtls_ssl_config_defaults returned -0x%x\n", -ret);
  155. goto error;
  156. }
  157. mbedtls_ssl_conf_verify(&(pDataParams->ssl_conf), _server_certificate_verify, (void *) host);
  158. mbedtls_ssl_conf_authmode(&(pDataParams->ssl_conf), authmode);
  159. mbedtls_ssl_conf_rng(&(pDataParams->ssl_conf), mbedtls_ctr_drbg_random, &(pDataParams->ctr_drbg));
  160. mbedtls_ssl_conf_ca_chain(&(pDataParams->ssl_conf), &(pDataParams->ca_cert), NULL);
  161. mbedtls_ssl_conf_read_timeout(&(pDataParams->ssl_conf), 10000);
  162. if ((ret = mbedtls_ssl_setup(&(pDataParams->ssl), &(pDataParams->ssl_conf))) != 0) {
  163. LOG_ERROR("failed! mbedtls_ssl_setup returned -0x%x\n", -ret);
  164. goto error;
  165. }
  166. // Set the hostname to check against the received server certificate and sni
  167. if ((ret = mbedtls_ssl_set_hostname(&(pDataParams->ssl), host)) != 0) {
  168. LOG_ERROR("failed! mbedtls_ssl_set_hostname returned %d\n", ret);
  169. goto error;
  170. }
  171. LOG_DEBUG("SSL state connect : %d ", pDataParams->ssl.state);
  172. mbedtls_ssl_set_bio(&(pDataParams->ssl), &(pDataParams->socket_fd), mbedtls_net_send, mbedtls_net_recv,
  173. mbedtls_net_recv_timeout);
  174. LOG_DEBUG("SSL state connect : %d ", pDataParams->ssl.state);
  175. LOG_DEBUG("Performing the SSL/TLS handshake...");
  176. while ((ret = mbedtls_ssl_handshake(&(pDataParams->ssl))) != 0) {
  177. if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
  178. LOG_ERROR("failed! mbedtls_ssl_handshake returned -0x%x\n", -ret);
  179. if (ret == MBEDTLS_ERR_X509_CERT_VERIFY_FAILED) {
  180. LOG_ERROR("Unable to verify the server's certificate");
  181. }
  182. goto error;
  183. }
  184. }
  185. if ((ret = mbedtls_ssl_get_verify_result(&(pDataParams->ssl))) != 0) {
  186. LOG_ERROR("mbedtls_ssl_get_verify_result failed returned 0x%04x\n", -ret);
  187. goto error;
  188. }
  189. //mbedtls_ssl_conf_read_timeout(&(pDataParams->ssl_conf), 100);
  190. LOG_INFO("connected with /%s/%d...", host, port);
  191. return (uintptr_t) pDataParams;
  192. error:
  193. _free_mbedtls(pDataParams);
  194. return 0;
  195. }
  196. int32_t HAL_TLS_Disconnect(_IN_ uintptr_t handle) {
  197. if ((uintptr_t) NULL == handle) {
  198. LOG_DEBUG("handle is NULL");
  199. return FAILURE_RET;
  200. }
  201. TLSDataParams *pParams = (TLSDataParams *) handle;
  202. int ret = 0;
  203. do {
  204. ret = mbedtls_ssl_close_notify(&(pParams->ssl));
  205. } while (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE);
  206. _free_mbedtls(pParams);
  207. return SUCCESS_RET;
  208. }
  209. int32_t HAL_TLS_Write(_IN_ uintptr_t handle, _IN_ unsigned char *buf, _IN_ size_t len, _IN_ uint32_t timeout_ms) {
  210. Timer timer;
  211. HAL_Timer_Init(&timer);
  212. HAL_Timer_Countdown_ms(&timer, (unsigned int) timeout_ms);
  213. size_t written_so_far;
  214. bool errorFlag = false;
  215. int write_rc = 0;
  216. TLSDataParams *pParams = (TLSDataParams *) handle;
  217. for (written_so_far = 0; written_so_far < len && !HAL_Timer_Expired(&timer); written_so_far += write_rc) {
  218. while (!HAL_Timer_Expired(&timer) &&
  219. (write_rc = mbedtls_ssl_write(&(pParams->ssl), (unsigned char *)(buf + written_so_far), len - written_so_far)) <= 0) {
  220. if (write_rc != MBEDTLS_ERR_SSL_WANT_READ && write_rc != MBEDTLS_ERR_SSL_WANT_WRITE) {
  221. LOG_ERROR("failed! mbedtls_ssl_write returned -0x%x\n", -write_rc);
  222. errorFlag = true;
  223. break;
  224. }
  225. }
  226. if (errorFlag) {
  227. break;
  228. }
  229. }
  230. if (errorFlag) {
  231. return ERR_SSL_WRITE_FAILED;
  232. }
  233. return written_so_far;
  234. }
  235. int32_t HAL_TLS_Read(_IN_ uintptr_t handle, _OU_ unsigned char *buf, _IN_ size_t len, _IN_ uint32_t timeout_ms) {
  236. Timer timer;
  237. HAL_Timer_Init(&timer);
  238. HAL_Timer_Countdown_ms(&timer, timeout_ms);
  239. size_t read_len = 0;
  240. TLSDataParams *pParams = (TLSDataParams *) handle;
  241. while (read_len < len) {
  242. int read_rc = 0;
  243. read_rc = mbedtls_ssl_read(&(pParams->ssl), (unsigned char *)(buf + read_len), len - read_len);
  244. if (read_rc > 0) {
  245. read_len += read_rc;
  246. } else if (read_rc == 0 || (read_rc != MBEDTLS_ERR_SSL_WANT_WRITE
  247. && read_rc != MBEDTLS_ERR_SSL_WANT_READ && read_rc != MBEDTLS_ERR_SSL_TIMEOUT)) {
  248. LOG_ERROR("failed! mbedtls_ssl_read returned -0x%x\n", -read_rc);
  249. return ERR_SSL_READ_FAILED;
  250. }
  251. if (HAL_Timer_Expired(&timer)) {
  252. break;
  253. }
  254. }
  255. return read_len;
  256. }
  257. #ifdef __cplusplus
  258. }
  259. #endif