nettype_tls.c 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. /*
  2. * @Author: jiejie
  3. * @Github: https://github.com/jiejieTop
  4. * @Date: 2020-01-11 19:45:35
  5. * @LastEditTime: 2020-09-20 14:29:06
  6. * @Description: the code belongs to jiejie, please keep the author
  7. * information and source code according to the license.
  8. */
  9. #include "nettype_tls.h"
  10. #include "platform_memory.h"
  11. #include "platform_net_socket.h"
  12. #include "random.h"
  13. #ifndef MQTT_NETWORK_TYPE_NO_TLS
  14. #include "mbedtls/ctr_drbg.h"
  15. #include "mbedtls/debug.h"
  16. #include "mbedtls/entropy.h"
  17. #include "mbedtls/error.h"
  18. #include "mbedtls/net_sockets.h"
  19. #include "mbedtls/pk.h"
  20. #include "mbedtls/platform.h"
  21. #include "mbedtls/ssl.h"
  22. #include "mbedtls/x509_crt.h"
  23. #if defined(MBEDTLS_X509_CRT_PARSE_C)
  24. static int server_certificate_verify(void* hostname,
  25. mbedtls_x509_crt* crt,
  26. int depth,
  27. uint32_t* flags) {
  28. if (0 != *flags)
  29. MQTT_LOG_E(
  30. "%s:%d %s()... server_certificate_verify failed returned 0x%04x\n",
  31. __FILE__, __LINE__, __FUNCTION__, *flags);
  32. return *flags;
  33. }
  34. #endif
  35. static int nettype_tls_entropy_source(void* data,
  36. uint8_t* output,
  37. size_t len,
  38. size_t* out_len) {
  39. uint32_t seed;
  40. (void)data;
  41. seed = random_number();
  42. if (len > sizeof(seed)) {
  43. len = sizeof(seed);
  44. }
  45. memcpy(output, &seed, len);
  46. *out_len = len;
  47. return 0;
  48. }
  49. static int nettype_tls_init(network_t* n,
  50. nettype_tls_params_t* nettype_tls_params) {
  51. int rc = MQTT_SUCCESS_ERROR;
  52. mbedtls_platform_set_calloc_free(platform_memory_calloc,
  53. platform_memory_free);
  54. mbedtls_net_init(&(nettype_tls_params->socket_fd));
  55. mbedtls_ssl_init(&(nettype_tls_params->ssl));
  56. mbedtls_ssl_config_init(&(nettype_tls_params->ssl_conf));
  57. mbedtls_ctr_drbg_init(&(nettype_tls_params->ctr_drbg));
  58. #if defined(MBEDTLS_X509_CRT_PARSE_C)
  59. mbedtls_x509_crt_init(&(nettype_tls_params->ca_cert));
  60. mbedtls_x509_crt_init(&(nettype_tls_params->client_cert));
  61. mbedtls_pk_init(&(nettype_tls_params->private_key));
  62. #endif
  63. mbedtls_entropy_init(&(nettype_tls_params->entropy));
  64. mbedtls_entropy_add_source(
  65. &(nettype_tls_params->entropy), nettype_tls_entropy_source, NULL,
  66. MBEDTLS_ENTROPY_MAX_GATHER, MBEDTLS_ENTROPY_SOURCE_STRONG);
  67. if ((rc = mbedtls_ctr_drbg_seed(
  68. &(nettype_tls_params->ctr_drbg), mbedtls_entropy_func,
  69. &(nettype_tls_params->entropy), NULL, 0)) != 0) {
  70. MQTT_LOG_E("mbedtls_ctr_drbg_seed failed returned 0x%04x",
  71. (rc < 0) ? -rc : rc);
  72. RETURN_ERROR(rc);
  73. }
  74. if ((rc = mbedtls_ssl_config_defaults(
  75. &(nettype_tls_params->ssl_conf), MBEDTLS_SSL_IS_CLIENT,
  76. MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0) {
  77. MQTT_LOG_E("mbedtls_ssl_config_defaults failed returned 0x%04x",
  78. (rc < 0) ? -rc : rc);
  79. RETURN_ERROR(rc);
  80. }
  81. mbedtls_ssl_conf_rng(&(nettype_tls_params->ssl_conf),
  82. mbedtls_ctr_drbg_random,
  83. &(nettype_tls_params->ctr_drbg));
  84. #if defined(MBEDTLS_X509_CRT_PARSE_C)
  85. if (NULL != n->ca_crt) {
  86. n->ca_crt_len = strlen(n->ca_crt);
  87. if (0 != (rc = (mbedtls_x509_crt_parse(&(nettype_tls_params->ca_cert),
  88. (unsigned char*)n->ca_crt,
  89. (n->ca_crt_len + 1))))) {
  90. MQTT_LOG_E("%s:%d %s()... parse ca crt failed returned 0x%04x",
  91. __FILE__, __LINE__, __FUNCTION__, (rc < 0) ? -rc : rc);
  92. RETURN_ERROR(rc);
  93. }
  94. }
  95. mbedtls_ssl_conf_ca_chain(&(nettype_tls_params->ssl_conf),
  96. &(nettype_tls_params->ca_cert), NULL);
  97. if ((rc = mbedtls_ssl_conf_own_cert(&(nettype_tls_params->ssl_conf),
  98. &(nettype_tls_params->client_cert),
  99. &(nettype_tls_params->private_key))) !=
  100. 0) {
  101. MQTT_LOG_E(
  102. "%s:%d %s()... mbedtls_ssl_conf_own_cert failed returned 0x%04x",
  103. __FILE__, __LINE__, __FUNCTION__, (rc < 0) ? -rc : rc);
  104. RETURN_ERROR(rc);
  105. }
  106. mbedtls_ssl_conf_verify(&(nettype_tls_params->ssl_conf),
  107. server_certificate_verify, (void*)n->host);
  108. mbedtls_ssl_conf_authmode(&(nettype_tls_params->ssl_conf),
  109. MBEDTLS_SSL_VERIFY_REQUIRED);
  110. #endif
  111. mbedtls_ssl_conf_read_timeout(&(nettype_tls_params->ssl_conf),
  112. n->timeout_ms);
  113. if ((rc = mbedtls_ssl_setup(&(nettype_tls_params->ssl),
  114. &(nettype_tls_params->ssl_conf))) != 0) {
  115. MQTT_LOG_E("mbedtls_ssl_setup failed returned 0x%04x",
  116. (rc < 0) ? -rc : rc);
  117. RETURN_ERROR(rc);
  118. }
  119. #if defined(MBEDTLS_X509_CRT_PARSE_C)
  120. if ((rc = mbedtls_ssl_set_hostname(&(nettype_tls_params->ssl), n->host)) !=
  121. 0) {
  122. MQTT_LOG_E(
  123. "%s:%d %s()... mbedtls_ssl_set_hostname failed returned 0x%04x",
  124. __FILE__, __LINE__, __FUNCTION__, (rc < 0) ? -rc : rc);
  125. RETURN_ERROR(rc);
  126. }
  127. #endif
  128. mbedtls_ssl_set_bio(&(nettype_tls_params->ssl),
  129. &(nettype_tls_params->socket_fd), mbedtls_net_send,
  130. mbedtls_net_recv, mbedtls_net_recv_timeout);
  131. RETURN_ERROR(MQTT_SUCCESS_ERROR);
  132. }
  133. int nettype_tls_connect(network_t* n) {
  134. int rc;
  135. if (NULL == n)
  136. RETURN_ERROR(MQTT_NULL_VALUE_ERROR);
  137. nettype_tls_params_t* nettype_tls_params =
  138. (nettype_tls_params_t*)platform_memory_alloc(
  139. sizeof(nettype_tls_params_t));
  140. if (NULL == nettype_tls_params)
  141. RETURN_ERROR(MQTT_MEM_NOT_ENOUGH_ERROR);
  142. rc = nettype_tls_init(n, nettype_tls_params);
  143. if (MQTT_SUCCESS_ERROR != rc)
  144. goto exit;
  145. if (0 !=
  146. (rc = mbedtls_net_connect(&(nettype_tls_params->socket_fd), n->host,
  147. n->port, MBEDTLS_NET_PROTO_TCP)))
  148. goto exit;
  149. while ((rc = mbedtls_ssl_handshake(&(nettype_tls_params->ssl))) != 0) {
  150. if (rc != MBEDTLS_ERR_SSL_WANT_READ &&
  151. rc != MBEDTLS_ERR_SSL_WANT_WRITE) {
  152. MQTT_LOG_E("%s:%d %s()...mbedtls handshake failed returned 0x%04x",
  153. __FILE__, __LINE__, __FUNCTION__, (rc < 0) ? -rc : rc);
  154. #if defined(MBEDTLS_X509_CRT_PARSE_C)
  155. if (rc == MBEDTLS_ERR_X509_CERT_VERIFY_FAILED) {
  156. MQTT_LOG_E(
  157. "%s:%d %s()...unable to verify the server's certificate",
  158. __FILE__, __LINE__, __FUNCTION__);
  159. }
  160. #endif
  161. goto exit;
  162. }
  163. }
  164. if ((rc = mbedtls_ssl_get_verify_result(&(nettype_tls_params->ssl))) != 0) {
  165. MQTT_LOG_E("%s:%d %s()...mbedtls_ssl_get_verify_result returned 0x%04x",
  166. __FILE__, __LINE__, __FUNCTION__, (rc < 0) ? -rc : rc);
  167. goto exit;
  168. }
  169. n->nettype_tls_params = nettype_tls_params;
  170. RETURN_ERROR(MQTT_SUCCESS_ERROR)
  171. exit:
  172. platform_memory_free(nettype_tls_params);
  173. RETURN_ERROR(rc);
  174. }
  175. void nettype_tls_disconnect(network_t* n) {
  176. int rc = 0;
  177. if (NULL == n)
  178. return;
  179. nettype_tls_params_t* nettype_tls_params =
  180. (nettype_tls_params_t*)n->nettype_tls_params;
  181. do {
  182. rc = mbedtls_ssl_close_notify(&(nettype_tls_params->ssl));
  183. } while (rc == MBEDTLS_ERR_SSL_WANT_READ ||
  184. rc == MBEDTLS_ERR_SSL_WANT_WRITE);
  185. mbedtls_net_free(&(nettype_tls_params->socket_fd));
  186. #if defined(MBEDTLS_X509_CRT_PARSE_C)
  187. mbedtls_x509_crt_free(&(nettype_tls_params->client_cert));
  188. mbedtls_x509_crt_free(&(nettype_tls_params->ca_cert));
  189. mbedtls_pk_free(&(nettype_tls_params->private_key));
  190. #endif
  191. mbedtls_ssl_free(&(nettype_tls_params->ssl));
  192. mbedtls_ssl_config_free(&(nettype_tls_params->ssl_conf));
  193. mbedtls_ctr_drbg_free(&(nettype_tls_params->ctr_drbg));
  194. mbedtls_entropy_free(&(nettype_tls_params->entropy));
  195. platform_memory_free(nettype_tls_params);
  196. }
  197. int nettype_tls_write(network_t* n, unsigned char* buf, int len, int timeout) {
  198. int rc = 0;
  199. int write_len = 0;
  200. pika_platform_timer_t timer;
  201. if (NULL == n)
  202. RETURN_ERROR(MQTT_NULL_VALUE_ERROR);
  203. nettype_tls_params_t* nettype_tls_params =
  204. (nettype_tls_params_t*)n->nettype_tls_params;
  205. pika_platform_thread_timer_cutdown(&timer, timeout);
  206. do {
  207. rc = mbedtls_ssl_write(&(nettype_tls_params->ssl),
  208. (unsigned char*)(buf + write_len),
  209. len - write_len);
  210. if (rc > 0) {
  211. write_len += rc;
  212. } else if ((rc == 0) || ((rc != MBEDTLS_ERR_SSL_WANT_WRITE) &&
  213. (rc != MBEDTLS_ERR_SSL_WANT_READ) &&
  214. (rc != MBEDTLS_ERR_SSL_TIMEOUT))) {
  215. MQTT_LOG_E("%s:%d %s()... mbedtls_ssl_write failed: 0x%04x",
  216. __FILE__, __LINE__, __FUNCTION__, (rc < 0) ? -rc : rc);
  217. break;
  218. }
  219. } while ((!pika_platform_thread_timer_is_expired(&timer)) &&
  220. (write_len < len));
  221. return write_len;
  222. }
  223. int nettype_tls_read(network_t* n, unsigned char* buf, int len, int timeout) {
  224. int rc = 0;
  225. int read_len = 0;
  226. pika_platform_timer_t timer;
  227. if (NULL == n)
  228. RETURN_ERROR(MQTT_NULL_VALUE_ERROR);
  229. nettype_tls_params_t* nettype_tls_params =
  230. (nettype_tls_params_t*)n->nettype_tls_params;
  231. pika_platform_thread_timer_cutdown(&timer, timeout);
  232. do {
  233. rc = mbedtls_ssl_read(&(nettype_tls_params->ssl),
  234. (unsigned char*)(buf + read_len), len - read_len);
  235. if (rc > 0) {
  236. read_len += rc;
  237. } else if ((rc == 0) || ((rc != MBEDTLS_ERR_SSL_WANT_WRITE) &&
  238. (rc != MBEDTLS_ERR_SSL_WANT_READ) &&
  239. (rc != MBEDTLS_ERR_SSL_TIMEOUT))) {
  240. // MQTT_LOG_E("%s:%d %s()... mbedtls_ssl_read failed: 0x%04x",
  241. // __FILE__, __LINE__, __FUNCTION__, (rc < 0 )? -rc : rc);
  242. break;
  243. }
  244. } while ((!pika_platform_thread_timer_is_expired(&timer)) &&
  245. (read_len < len));
  246. return read_len;
  247. }
  248. #endif /* MQTT_NETWORK_TYPE_NO_TLS */