Просмотр исходного кода

Merge branch 'bugfix/fix_mbedtls_send_alert_crash_v4.2' into 'release/v4.2'

mbedtls: fix mbedtls dynamic resource memory leaks and mbedtls_ssl_send_alert_message crash due to ssl->out_iv is NULL[backport v4.2]

See merge request espressif/esp-idf!13301
Mahavir Jain 4 лет назад
Родитель
Сommit
e8f5b76112

+ 111 - 73
components/mbedtls/port/dynamic/esp_mbedtls_dynamic_impl.c

@@ -23,6 +23,33 @@
 
 static const char *TAG = "Dynamic Impl";
 
+static void esp_mbedtls_set_buf_state(unsigned char *buf, esp_mbedtls_ssl_buf_states state)
+{
+    struct esp_mbedtls_ssl_buf *temp = __containerof(buf, struct esp_mbedtls_ssl_buf, buf[0]);
+    temp->state = state;
+}
+
+static esp_mbedtls_ssl_buf_states esp_mbedtls_get_buf_state(unsigned char *buf)
+{
+    struct esp_mbedtls_ssl_buf *temp = __containerof(buf, struct esp_mbedtls_ssl_buf, buf[0]);
+    return temp->state;
+}
+
+void esp_mbedtls_free_buf(unsigned char *buf)
+{
+    struct esp_mbedtls_ssl_buf *temp = __containerof(buf, struct esp_mbedtls_ssl_buf, buf[0]);
+    ESP_LOGV(TAG, "free buffer @ %p", temp);
+    mbedtls_free(temp);
+}
+
+static void esp_mbedtls_init_ssl_buf(struct esp_mbedtls_ssl_buf *buf, unsigned int len)
+{
+    if (buf) {
+        buf->state = ESP_MBEDTLS_SSL_BUF_CACHED;
+        buf->len = len;
+    }
+}
+
 static void esp_mbedtls_parse_record_header(mbedtls_ssl_context *ssl)
 {
     ssl->in_msgtype =  ssl->in_hdr[0];
@@ -118,21 +145,22 @@ static void init_rx_buffer(mbedtls_ssl_context *ssl, unsigned char *buf)
 
 static int esp_mbedtls_alloc_tx_buf(mbedtls_ssl_context *ssl, int len)
 {
-    unsigned char *buf;
+    struct esp_mbedtls_ssl_buf *esp_buf;
 
     if (ssl->out_buf) {
-        mbedtls_free(ssl->out_buf);
+        esp_mbedtls_free_buf(ssl->out_buf);
         ssl->out_buf = NULL;
     }
 
-    buf = mbedtls_calloc(1, len);
-    if (!buf) {
-        ESP_LOGE(TAG, "alloc(%d bytes) failed", len);
+    esp_buf = mbedtls_calloc(1, SSL_BUF_HEAD_OFFSET_SIZE + len);
+    if (!esp_buf) {
+        ESP_LOGE(TAG, "alloc(%d bytes) failed", SSL_BUF_HEAD_OFFSET_SIZE + len);
         return MBEDTLS_ERR_SSL_ALLOC_FAILED;
     }
 
-    ESP_LOGV(TAG, "add out buffer %d bytes @ %p", len, buf);
+    ESP_LOGV(TAG, "add out buffer %d bytes @ %p", len, esp_buf->buf);
 
+    esp_mbedtls_init_ssl_buf(esp_buf, len);
     /**
      * Mark the out_msg offset from ssl->out_buf.
      * 
@@ -140,7 +168,7 @@ static int esp_mbedtls_alloc_tx_buf(mbedtls_ssl_context *ssl, int len)
      */
     ssl->out_msg = (unsigned char *)MBEDTLS_SSL_HEADER_LEN;
 
-    init_tx_buffer(ssl, buf);
+    init_tx_buffer(ssl, esp_buf->buf);
 
     return 0;
 }
@@ -150,7 +178,7 @@ int esp_mbedtls_setup_tx_buffer(mbedtls_ssl_context *ssl)
     CHECK_OK(esp_mbedtls_alloc_tx_buf(ssl, TX_IDLE_BUFFER_SIZE));
 
     /* mark the out buffer has no data cached */
-    ssl->out_iv = NULL;
+    esp_mbedtls_set_buf_state(ssl->out_buf, ESP_MBEDTLS_SSL_BUF_NO_CACHED);
 
     return 0;
 }
@@ -168,10 +196,7 @@ int esp_mbedtls_reset_add_tx_buffer(mbedtls_ssl_context *ssl)
 
 int esp_mbedtls_reset_free_tx_buffer(mbedtls_ssl_context *ssl)
 {
-    ESP_LOGV(TAG, "free out buffer @ %p", ssl->out_buf);
-
-    mbedtls_free(ssl->out_buf);
-
+    esp_mbedtls_free_buf(ssl->out_buf);
     init_tx_buffer(ssl, NULL);
 
     CHECK_OK(esp_mbedtls_setup_tx_buffer(ssl));
@@ -181,21 +206,22 @@ int esp_mbedtls_reset_free_tx_buffer(mbedtls_ssl_context *ssl)
 
 int esp_mbedtls_reset_add_rx_buffer(mbedtls_ssl_context *ssl)
 {
-    unsigned char *buf;
+    struct esp_mbedtls_ssl_buf *esp_buf;
 
     if (ssl->in_buf) {
-        mbedtls_free(ssl->in_buf);
+        esp_mbedtls_free_buf(ssl->in_buf);
         ssl->in_buf = NULL;
     }
 
-    buf = mbedtls_calloc(1, MBEDTLS_SSL_IN_BUFFER_LEN);
-    if (!buf) {
-        ESP_LOGE(TAG, "alloc(%d bytes) failed", MBEDTLS_SSL_IN_BUFFER_LEN);
+    esp_buf = mbedtls_calloc(1, SSL_BUF_HEAD_OFFSET_SIZE + MBEDTLS_SSL_IN_BUFFER_LEN);
+    if (!esp_buf) {
+        ESP_LOGE(TAG, "alloc(%d bytes) failed", SSL_BUF_HEAD_OFFSET_SIZE + MBEDTLS_SSL_IN_BUFFER_LEN);
         return MBEDTLS_ERR_SSL_ALLOC_FAILED;
     }
 
-    ESP_LOGV(TAG, "add in buffer %d bytes @ %p", MBEDTLS_SSL_IN_BUFFER_LEN, buf);
+    ESP_LOGV(TAG, "add in buffer %d bytes @ %p", MBEDTLS_SSL_IN_BUFFER_LEN, esp_buf->buf);
 
+    esp_mbedtls_init_ssl_buf(esp_buf, MBEDTLS_SSL_IN_BUFFER_LEN);
     /**
      * Mark the in_msg offset from ssl->in_buf.
      * 
@@ -203,38 +229,34 @@ int esp_mbedtls_reset_add_rx_buffer(mbedtls_ssl_context *ssl)
      */
     ssl->in_msg = (unsigned char *)MBEDTLS_SSL_HEADER_LEN;
 
-    init_rx_buffer(ssl, buf);
+    init_rx_buffer(ssl, esp_buf->buf);
 
     return 0;  
 }
 
 void esp_mbedtls_reset_free_rx_buffer(mbedtls_ssl_context *ssl)
 {
-    ESP_LOGV(TAG, "free in buffer @ %p", ssl->in_buf);
-
-    mbedtls_free(ssl->in_buf);
-
-    init_rx_buffer(ssl, NULL);    
+    esp_mbedtls_free_buf(ssl->in_buf);
+    init_rx_buffer(ssl, NULL);
 }
 
 int esp_mbedtls_add_tx_buffer(mbedtls_ssl_context *ssl, size_t buffer_len)
 {
     int ret = 0;
     int cached = 0;
-    unsigned char *buf;
+    struct esp_mbedtls_ssl_buf *esp_buf;
     unsigned char cache_buf[CACHE_BUFFER_SIZE];
 
     ESP_LOGV(TAG, "--> add out");
 
     if (ssl->out_buf) {
-        if (ssl->out_iv) {
+        if (esp_mbedtls_get_buf_state(ssl->out_buf) == ESP_MBEDTLS_SSL_BUF_CACHED) {
             ESP_LOGV(TAG, "out buffer is not empty");
             ret = 0;
             goto exit;
         } else {
             memcpy(cache_buf, ssl->out_buf, CACHE_BUFFER_SIZE);
-
-            mbedtls_free(ssl->out_buf);
+            esp_mbedtls_free_buf(ssl->out_buf);
             init_tx_buffer(ssl, NULL);
             cached = 1;
         }
@@ -242,15 +264,17 @@ int esp_mbedtls_add_tx_buffer(mbedtls_ssl_context *ssl, size_t buffer_len)
 
     buffer_len = tx_buffer_len(ssl, buffer_len);
 
-    buf = mbedtls_calloc(1, buffer_len);
-    if (!buf) {
-        ESP_LOGE(TAG, "alloc(%d bytes) failed", buffer_len);
+    esp_buf = mbedtls_calloc(1, SSL_BUF_HEAD_OFFSET_SIZE + buffer_len);
+    if (!esp_buf) {
+        ESP_LOGE(TAG, "alloc(%d bytes) failed", SSL_BUF_HEAD_OFFSET_SIZE + buffer_len);
         ret = MBEDTLS_ERR_SSL_ALLOC_FAILED;
         goto exit;
     }
 
-    ESP_LOGV(TAG, "add out buffer %d bytes @ %p", buffer_len, buf);
-    init_tx_buffer(ssl, buf);
+    ESP_LOGV(TAG, "add out buffer %d bytes @ %p", buffer_len, esp_buf->buf);
+
+    esp_mbedtls_init_ssl_buf(esp_buf, buffer_len);
+    init_tx_buffer(ssl, esp_buf->buf);
 
     if (cached) {
         memcpy(ssl->out_ctr, cache_buf, COUNTER_SIZE);
@@ -270,11 +294,11 @@ int esp_mbedtls_free_tx_buffer(mbedtls_ssl_context *ssl)
 {
     int ret = 0;
     unsigned char buf[CACHE_BUFFER_SIZE];
-    unsigned char *pdata;
+    struct esp_mbedtls_ssl_buf *esp_buf;
 
     ESP_LOGV(TAG, "--> free out");
 
-    if (!ssl->out_buf || (ssl->out_buf && !ssl->out_iv)) {
+    if (!ssl->out_buf || (ssl->out_buf && (esp_mbedtls_get_buf_state(ssl->out_buf) == ESP_MBEDTLS_SSL_BUF_NO_CACHED))) {
         ret = 0;
         goto exit;
     }
@@ -282,22 +306,19 @@ int esp_mbedtls_free_tx_buffer(mbedtls_ssl_context *ssl)
     memcpy(buf, ssl->out_ctr, COUNTER_SIZE);
     memcpy(buf + COUNTER_SIZE, ssl->out_iv, CACHE_IV_SIZE);
 
-    ESP_LOGV(TAG, "free out buffer @ %p", ssl->out_buf);
-
-    mbedtls_free(ssl->out_buf);
-
+    esp_mbedtls_free_buf(ssl->out_buf);
     init_tx_buffer(ssl, NULL);
 
-    pdata = mbedtls_calloc(1, TX_IDLE_BUFFER_SIZE);
-    if (!pdata) {
-        ESP_LOGE(TAG, "alloc(%d bytes) failed", TX_IDLE_BUFFER_SIZE);
+    esp_buf = mbedtls_calloc(1, SSL_BUF_HEAD_OFFSET_SIZE + TX_IDLE_BUFFER_SIZE);
+    if (!esp_buf) {
+        ESP_LOGE(TAG, "alloc(%d bytes) failed", SSL_BUF_HEAD_OFFSET_SIZE + TX_IDLE_BUFFER_SIZE);
         return MBEDTLS_ERR_SSL_ALLOC_FAILED;
     }
 
-    memcpy(pdata, buf, CACHE_BUFFER_SIZE);
-    init_tx_buffer(ssl, pdata);
-    ssl->out_iv = NULL;
-
+    esp_mbedtls_init_ssl_buf(esp_buf, TX_IDLE_BUFFER_SIZE);
+    memcpy(esp_buf->buf, buf, CACHE_BUFFER_SIZE);
+    init_tx_buffer(ssl, esp_buf->buf);
+    esp_mbedtls_set_buf_state(ssl->out_buf, ESP_MBEDTLS_SSL_BUF_NO_CACHED);
 exit:
     ESP_LOGV(TAG, "<-- free out");
 
@@ -309,7 +330,7 @@ int esp_mbedtls_add_rx_buffer(mbedtls_ssl_context *ssl)
     int cached = 0;
     int ret = 0;
     int buffer_len;
-    unsigned char *buf;
+    struct esp_mbedtls_ssl_buf *esp_buf;
     unsigned char cache_buf[16];
     unsigned char msg_head[5];
     size_t in_msglen, in_left;
@@ -317,9 +338,13 @@ int esp_mbedtls_add_rx_buffer(mbedtls_ssl_context *ssl)
     ESP_LOGV(TAG, "--> add rx");
 
     if (ssl->in_buf) {
-        ESP_LOGV(TAG, "in buffer is not empty");
-        ret = 0;
-        goto exit;
+        if (esp_mbedtls_get_buf_state(ssl->in_buf) == ESP_MBEDTLS_SSL_BUF_CACHED) {
+            ESP_LOGV(TAG, "in buffer is not empty");
+            ret = 0;
+            goto exit;
+        } else {
+            cached = 1;
+        }
     }
 
     ssl->in_hdr = msg_head;
@@ -346,22 +371,23 @@ int esp_mbedtls_add_rx_buffer(mbedtls_ssl_context *ssl)
     ESP_LOGV(TAG, "message length is %d RX buffer length should be %d left is %d",
                 (int)in_msglen, (int)buffer_len, (int)ssl->in_left);
 
-    buf = mbedtls_calloc(1, buffer_len);
-    if (!buf) {
-        ESP_LOGE(TAG, "alloc(%d bytes) failed", buffer_len);
+    if (cached) {
+        memcpy(cache_buf, ssl->in_buf, 16);
+        esp_mbedtls_free_buf(ssl->in_buf);
+        init_rx_buffer(ssl, NULL);
+    }
+
+    esp_buf = mbedtls_calloc(1, SSL_BUF_HEAD_OFFSET_SIZE + buffer_len);
+    if (!esp_buf) {
+        ESP_LOGE(TAG, "alloc(%d bytes) failed", SSL_BUF_HEAD_OFFSET_SIZE + buffer_len);
         ret = MBEDTLS_ERR_SSL_ALLOC_FAILED;
         goto exit;
     }
 
-    ESP_LOGV(TAG, "add in buffer %d bytes @ %p", buffer_len, buf);
-
-    if (ssl->in_ctr) {
-        memcpy(cache_buf, ssl->in_ctr, 16);
-        mbedtls_free(ssl->in_ctr);
-        cached = 1;
-    }
+    ESP_LOGV(TAG, "add in buffer %d bytes @ %p", buffer_len, esp_buf->buf);
 
-    init_rx_buffer(ssl, buf);
+    esp_mbedtls_init_ssl_buf(esp_buf, buffer_len);
+    init_rx_buffer(ssl, esp_buf->buf);
 
     if (cached) {
         memcpy(ssl->in_ctr, cache_buf, 8);
@@ -382,14 +408,15 @@ int esp_mbedtls_free_rx_buffer(mbedtls_ssl_context *ssl)
 {
     int ret = 0;
     unsigned char buf[16];
-    unsigned char *pdata;
+    struct esp_mbedtls_ssl_buf *esp_buf;
 
     ESP_LOGV(TAG, "--> free rx");
 
     /**
      * When have read multi messages once, can't free the input buffer directly.
      */
-    if (!ssl->in_buf || (ssl->in_hslen && (ssl->in_hslen < ssl->in_msglen))) {
+    if (!ssl->in_buf || (ssl->in_hslen && (ssl->in_hslen < ssl->in_msglen)) ||
+        (ssl->in_buf && (esp_mbedtls_get_buf_state(ssl->in_buf) == ESP_MBEDTLS_SSL_BUF_NO_CACHED))) {
         ret = 0;
         goto exit;
     }
@@ -404,22 +431,20 @@ int esp_mbedtls_free_rx_buffer(mbedtls_ssl_context *ssl)
     memcpy(buf, ssl->in_ctr, 8);
     memcpy(buf + 8, ssl->in_iv, 8);
 
-    ESP_LOGV(TAG, "free in buffer @ %p", ssl->out_buf);
-
-    mbedtls_free(ssl->in_buf);
-
+    esp_mbedtls_free_buf(ssl->in_buf);
     init_rx_buffer(ssl, NULL);
 
-    pdata = mbedtls_calloc(1, 16);
-    if (!pdata) {
-        ESP_LOGE(TAG, "alloc(%d bytes) failed", 16);
+    esp_buf = mbedtls_calloc(1, SSL_BUF_HEAD_OFFSET_SIZE + 16);
+    if (!esp_buf) {
+        ESP_LOGE(TAG, "alloc(%d bytes) failed", SSL_BUF_HEAD_OFFSET_SIZE + 16);
         ret = MBEDTLS_ERR_SSL_ALLOC_FAILED;
         goto exit;
     }
 
-    memcpy(pdata, buf, 16);
-    ssl->in_ctr = pdata;
-
+    esp_mbedtls_init_ssl_buf(esp_buf, 16);
+    memcpy(esp_buf->buf, buf, 16);
+    init_rx_buffer(ssl, esp_buf->buf);
+    esp_mbedtls_set_buf_state(ssl->in_buf, ESP_MBEDTLS_SSL_BUF_NO_CACHED);
 exit:
     ESP_LOGV(TAG, "<-- free rx");
 
@@ -516,4 +541,17 @@ void esp_mbedtls_free_peer_cert(mbedtls_ssl_context *ssl)
         ssl->session_negotiate->peer_cert = NULL;
     }
 }
+
+bool esp_mbedtls_ssl_is_rsa(mbedtls_ssl_context *ssl)
+{
+    const mbedtls_ssl_ciphersuite_t *ciphersuite_info =
+        ssl->transform_negotiate->ciphersuite_info;
+
+    if (ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_RSA ||
+        ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_RSA_PSK) {
+        return true;
+    } else {
+        return false;
+    }
+}
 #endif

+ 17 - 3
components/mbedtls/port/dynamic/esp_mbedtls_dynamic_impl.h

@@ -33,9 +33,6 @@
  \
     if ((_ret = _fn) != 0) { \
         ESP_LOGV(TAG, "\"%s\" result is -0x%x", # _fn, -_ret); \
-        if (_ret == MBEDTLS_ERR_SSL_CONN_EOF) {\
-            return 0; \
-        } \
         TRACE_CHECK(_fn, "fail"); \
         return _ret; \
     } \
@@ -44,6 +41,21 @@
  \
 })
 
+typedef enum {
+    ESP_MBEDTLS_SSL_BUF_CACHED,
+    ESP_MBEDTLS_SSL_BUF_NO_CACHED,
+} esp_mbedtls_ssl_buf_states;
+
+struct esp_mbedtls_ssl_buf {
+    esp_mbedtls_ssl_buf_states state;
+    unsigned int len;
+    unsigned char buf[];
+};
+
+#define SSL_BUF_HEAD_OFFSET_SIZE offsetof(struct esp_mbedtls_ssl_buf, buf)
+
+void esp_mbedtls_free_buf(unsigned char *buf);
+
 int esp_mbedtls_setup_tx_buffer(mbedtls_ssl_context *ssl);
 
 void esp_mbedtls_setup_rx_buffer(mbedtls_ssl_context *ssl);
@@ -82,6 +94,8 @@ void esp_mbedtls_free_cacert(mbedtls_ssl_context *ssl);
 
 #ifdef CONFIG_MBEDTLS_DYNAMIC_FREE_PEER_CERT
 void esp_mbedtls_free_peer_cert(mbedtls_ssl_context *ssl);
+
+bool esp_mbedtls_ssl_is_rsa(mbedtls_ssl_context *ssl);
 #endif
 
 #endif /* _DYNAMIC_IMPL_H_ */

+ 17 - 1
components/mbedtls/port/dynamic/esp_ssl_cli.c

@@ -73,7 +73,17 @@ static int manage_resource(mbedtls_ssl_context *ssl, bool add)
                     CHECK_OK(esp_mbedtls_free_rx_buffer(ssl));
                 }
 #ifdef CONFIG_MBEDTLS_DYNAMIC_FREE_PEER_CERT
-                esp_mbedtls_free_peer_cert(ssl);
+                /**
+                 * If current ciphersuite is RSA, we should free peer'
+                 * certificate at step  MBEDTLS_SSL_CLIENT_KEY_EXCHANGE.
+                 *
+                 * And if it is other kinds of ciphersuite, we can free
+                 * peer certificate here.
+                 */
+
+                if (esp_mbedtls_ssl_is_rsa(ssl) == false) {
+                    esp_mbedtls_free_peer_cert(ssl);
+                }
 #endif
             }
             break;
@@ -123,6 +133,12 @@ static int manage_resource(mbedtls_ssl_context *ssl, bool add)
                 size_t buffer_len = MBEDTLS_SSL_OUT_BUFFER_LEN;
 
                 CHECK_OK(esp_mbedtls_add_tx_buffer(ssl, buffer_len));
+            } else {
+#ifdef CONFIG_MBEDTLS_DYNAMIC_FREE_PEER_CERT
+                if (esp_mbedtls_ssl_is_rsa(ssl) == true) {
+                    esp_mbedtls_free_peer_cert(ssl);
+                }
+#endif
             }
             break;
         case MBEDTLS_SSL_CERTIFICATE_VERIFY:

+ 12 - 3
components/mbedtls/port/dynamic/esp_ssl_tls.c

@@ -85,7 +85,16 @@ int __wrap_mbedtls_ssl_read(mbedtls_ssl_context *ssl, unsigned char *buf, size_t
 {
     int ret;
 
-    CHECK_OK(esp_mbedtls_add_rx_buffer(ssl));
+    ESP_LOGD(TAG, "add mbedtls RX buffer");
+    ret = esp_mbedtls_add_rx_buffer(ssl);
+    if (ret == MBEDTLS_ERR_SSL_CONN_EOF) {
+        ESP_LOGD(TAG, "fail, the connection indicated an EOF");
+        return 0;
+    } else if (ret < 0) {
+        ESP_LOGD(TAG, "fail, error=-0x%x", -ret);
+        return ret;
+    }
+    ESP_LOGD(TAG, "end");
 
     ret = __real_mbedtls_ssl_read(ssl, buf, len);
 
@@ -99,12 +108,12 @@ int __wrap_mbedtls_ssl_read(mbedtls_ssl_context *ssl, unsigned char *buf, size_t
 void __wrap_mbedtls_ssl_free(mbedtls_ssl_context *ssl)
 {
     if (ssl->out_buf) {
-        mbedtls_free(ssl->out_buf);
+        esp_mbedtls_free_buf(ssl->out_buf);
         ssl->out_buf = NULL;
     }
 
     if (ssl->in_buf) {
-        mbedtls_free(ssl->in_buf);
+        esp_mbedtls_free_buf(ssl->in_buf);
         ssl->in_buf = NULL;
     }