Przeglądaj źródła

esp_transport: Use tcp_connect from esp_tls for plain TCP

so we don't have to allocate esp_tls structure (~2KB) to save heap when using plain TCP connection

Closes https://github.com/espressif/esp-idf/issues/6940
David Cermak 4 lat temu
rodzic
commit
f249ddd9ae

+ 2 - 2
components/esp-tls/esp_tls.c

@@ -267,7 +267,7 @@ static esp_err_t esp_tls_set_socket_non_blocking(int fd, bool non_blocking)
     return ESP_OK;
 }
 
-static esp_err_t esp_tcp_connect(const char *host, int hostlen, int port, int *sockfd, esp_tls_error_handle_t error_handle, const esp_tls_cfg_t *cfg)
+esp_err_t esp_tls_tcp_connect(const char *host, int hostlen, int port, const esp_tls_cfg_t *cfg, esp_tls_error_handle_t error_handle, int *sockfd)
 {
     struct sockaddr_storage address;
     int fd;
@@ -371,7 +371,7 @@ static int esp_tls_low_level_conn(const char *hostname, int hostlen, int port, c
             _esp_tls_net_init(tls);
             tls->is_tls = true;
         }
-        if ((esp_ret = esp_tcp_connect(hostname, hostlen, port, &tls->sockfd, tls->error_handle, cfg)) != ESP_OK) {
+        if ((esp_ret = esp_tls_tcp_connect(hostname, hostlen, port, cfg, tls->error_handle, &tls->sockfd)) != ESP_OK) {
             ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ESP_TLS_ERR_TYPE_ESP, esp_ret);
             return -1;
         }

+ 14 - 0
components/esp-tls/esp_tls.h

@@ -599,6 +599,20 @@ int esp_tls_server_session_create(esp_tls_cfg_server_t *cfg, int sockfd, esp_tls
 void esp_tls_server_session_delete(esp_tls_t *tls);
 #endif /* ! CONFIG_ESP_TLS_SERVER */
 
+/**
+ * @brief Creates a plain TCP connection, returning a valid socket fd on success or an error handle
+ *
+ * @param[in]  host      Hostname of the host.
+ * @param[in]  hostlen   Length of hostname.
+ * @param[in]  port      Port number of the host.
+ * @param[in]  cfg       ESP-TLS configuration as esp_tls_cfg_t.
+ * @param[out] error_handle ESP-TLS error handle holding potential errors occurred during connection
+ * @param[out] sockfd    Socket descriptor if successfully connected on TCP layer
+ * @return     ESP_OK   on success
+ *             ESP-TLS based error codes on failure
+ */
+esp_err_t esp_tls_tcp_connect(const char *host, int hostlen, int port, const esp_tls_cfg_t *cfg, esp_tls_error_handle_t error_handle, int *sockfd);
+
 #ifdef __cplusplus
 }
 #endif

+ 1 - 0
components/tcp_transport/test/test_transport.c

@@ -315,6 +315,7 @@ static void socket_operation_test(esp_transport_handle_t transport_under_test,
     close(params.listen_sock);
     close(params.accepted_sock);
 
+    xEventGroupWaitBits(params.tcp_connect_done, TCP_LISTENER_DONE, true, true, max_wait);
     // Cleanup
     TEST_ASSERT_EQUAL(false, params.tcp_listener_failed);
     vEventGroupDelete(params.tcp_connect_done);

+ 90 - 27
components/tcp_transport/transport_ssl.c

@@ -42,6 +42,7 @@ typedef struct transport_esp_tls {
     esp_tls_cfg_t            cfg;
     bool                     ssl_initialized;
     transport_ssl_conn_state_t conn_state;
+    int                      sockfd;
 } transport_esp_tls_t;
 
 static inline struct transport_esp_tls * ssl_get_context_data(esp_transport_handle_t t)
@@ -95,12 +96,12 @@ static inline int tcp_connect_async(esp_transport_handle_t t, const char *host,
     return esp_tls_connect_async(t, host, port, timeout_ms, true);
 }
 
-static int esp_tls_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms, bool is_plain_tcp)
+static int ssl_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
 {
     transport_esp_tls_t *ssl = ssl_get_context_data(t);
 
     ssl->cfg.timeout_ms = timeout_ms;
-    ssl->cfg.is_plain_tcp = is_plain_tcp;
+    ssl->cfg.is_plain_tcp = false;
 
     ssl->ssl_initialized = true;
     ssl->tls = esp_tls_init();
@@ -114,19 +115,27 @@ static int esp_tls_connect(esp_transport_handle_t t, const char *host, int port,
         esp_transport_set_errors(t, ssl->tls->error_handle);
         esp_tls_conn_destroy(ssl->tls);
         ssl->tls = NULL;
+        ssl->sockfd = -1;
         return -1;
     }
+    ssl->sockfd = ssl->tls->sockfd;
     return 0;
 }
 
-static inline int ssl_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
+static int tcp_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
 {
-    return esp_tls_connect(t, host, port, timeout_ms, false);
-}
+    transport_esp_tls_t *ssl = ssl_get_context_data(t);
+    esp_tls_last_error_t *err_handle = esp_transport_get_error_handle(t);
 
-static inline int tcp_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
-{
-    return esp_tls_connect(t, host, port, timeout_ms, true);
+    ssl->cfg.timeout_ms = timeout_ms;
+    esp_err_t err = esp_tls_tcp_connect(host, strlen(host), port, &ssl->cfg, err_handle, &ssl->sockfd);
+    if (err != ESP_OK) {
+        ESP_LOGE(TAG, "Failed to open a new connection: %d", err);
+        err_handle->last_error = err;
+        ssl->sockfd = -1;
+        return -1;
+    }
+    return 0;
 }
 
 static int ssl_poll_read(esp_transport_handle_t t, int timeout_ms)
@@ -139,20 +148,20 @@ static int ssl_poll_read(esp_transport_handle_t t, int timeout_ms)
     fd_set errset;
     FD_ZERO(&readset);
     FD_ZERO(&errset);
-    FD_SET(ssl->tls->sockfd, &readset);
-    FD_SET(ssl->tls->sockfd, &errset);
+    FD_SET(ssl->sockfd, &readset);
+    FD_SET(ssl->sockfd, &errset);
 
-    if ((remain = esp_tls_get_bytes_avail(ssl->tls)) > 0) {
+    if (ssl->tls && (remain = esp_tls_get_bytes_avail(ssl->tls)) > 0) {
         ESP_LOGD(TAG, "remain data in cache, need to read again");
         return remain;
     }
-    ret = select(ssl->tls->sockfd + 1, &readset, NULL, &errset, esp_transport_utils_ms_to_timeval(timeout_ms, &timeout));
-    if (ret > 0 && FD_ISSET(ssl->tls->sockfd, &errset)) {
+    ret = select(ssl->sockfd + 1, &readset, NULL, &errset, esp_transport_utils_ms_to_timeval(timeout_ms, &timeout));
+    if (ret > 0 && FD_ISSET(ssl->sockfd, &errset)) {
         int sock_errno = 0;
         uint32_t optlen = sizeof(sock_errno);
-        getsockopt(ssl->tls->sockfd, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen);
+        getsockopt(ssl->sockfd, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen);
         esp_transport_capture_errno(t, sock_errno);
-        ESP_LOGE(TAG, "ssl_poll_read select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), ssl->tls->sockfd);
+        ESP_LOGE(TAG, "poll_read select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), ssl->tls->sockfd);
         ret = -1;
     }
     return ret;
@@ -167,15 +176,15 @@ static int ssl_poll_write(esp_transport_handle_t t, int timeout_ms)
     fd_set errset;
     FD_ZERO(&writeset);
     FD_ZERO(&errset);
-    FD_SET(ssl->tls->sockfd, &writeset);
-    FD_SET(ssl->tls->sockfd, &errset);
-    ret = select(ssl->tls->sockfd + 1, NULL, &writeset, &errset, esp_transport_utils_ms_to_timeval(timeout_ms, &timeout));
-    if (ret > 0 && FD_ISSET(ssl->tls->sockfd, &errset)) {
+    FD_SET(ssl->sockfd, &writeset);
+    FD_SET(ssl->sockfd, &errset);
+    ret = select(ssl->sockfd + 1, NULL, &writeset, &errset, esp_transport_utils_ms_to_timeval(timeout_ms, &timeout));
+    if (ret > 0 && FD_ISSET(ssl->sockfd, &errset)) {
         int sock_errno = 0;
         uint32_t optlen = sizeof(sock_errno);
-        getsockopt(ssl->tls->sockfd, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen);
+        getsockopt(ssl->sockfd, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen);
         esp_transport_capture_errno(t, sock_errno);
-        ESP_LOGE(TAG, "ssl_poll_write select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), ssl->tls->sockfd);
+        ESP_LOGE(TAG, "poll_write select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), ssl->tls->sockfd);
         ret = -1;
     }
     return ret;
@@ -183,14 +192,14 @@ static int ssl_poll_write(esp_transport_handle_t t, int timeout_ms)
 
 static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms)
 {
-    int poll, ret;
+    int poll;
     transport_esp_tls_t *ssl = ssl_get_context_data(t);
 
     if ((poll = esp_transport_poll_write(t, timeout_ms)) <= 0) {
         ESP_LOGW(TAG, "Poll timeout or error, errno=%s, fd=%d, timeout_ms=%d", strerror(errno), ssl->tls->sockfd, timeout_ms);
         return poll;
     }
-    ret = esp_tls_conn_write(ssl->tls, (const unsigned char *) buffer, len);
+    int ret = esp_tls_conn_write(ssl->tls, (const unsigned char *) buffer, len);
     if (ret < 0) {
         ESP_LOGE(TAG, "esp_tls_conn_write error, errno=%s", strerror(errno));
         esp_transport_set_errors(t, ssl->tls->error_handle);
@@ -198,15 +207,32 @@ static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int
     return ret;
 }
 
+static int tcp_write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms)
+{
+    int poll;
+    transport_esp_tls_t *ssl = ssl_get_context_data(t);
+
+    if ((poll = esp_transport_poll_write(t, timeout_ms)) <= 0) {
+        ESP_LOGW(TAG, "Poll timeout or error, errno=%s, fd=%d, timeout_ms=%d", strerror(errno), ssl->tls->sockfd, timeout_ms);
+        return poll;
+    }
+    int ret = send(ssl->sockfd,(const unsigned char *) buffer, len, 0);
+    if (ret < 0) {
+        ESP_LOGE(TAG, "tcp_write error, errno=%s", strerror(errno));
+        esp_transport_capture_errno(t, errno);
+    }
+    return ret;
+}
+
 static int ssl_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
 {
-    int poll, ret;
+    int poll;
     transport_esp_tls_t *ssl = ssl_get_context_data(t);
 
     if ((poll = esp_transport_poll_read(t, timeout_ms)) <= 0) {
         return poll;
     }
-    ret = esp_tls_conn_read(ssl->tls, (unsigned char *)buffer, len);
+    int ret = esp_tls_conn_read(ssl->tls, (unsigned char *)buffer, len);
     if (ret < 0) {
         ESP_LOGE(TAG, "esp_tls_conn_read error, errno=%s", strerror(errno));
         esp_transport_set_errors(t, ssl->tls->error_handle);
@@ -221,6 +247,29 @@ static int ssl_read(esp_transport_handle_t t, char *buffer, int len, int timeout
     return ret;
 }
 
+static int tcp_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
+{
+    int poll;
+    transport_esp_tls_t *ssl = ssl_get_context_data(t);
+
+    if ((poll = esp_transport_poll_read(t, timeout_ms)) <= 0) {
+        return poll;
+    }
+    int ret = recv(ssl->sockfd, (unsigned char *)buffer, len, 0);
+    if (ret < 0) {
+        ESP_LOGE(TAG, "tcp_read error, errno=%s", strerror(errno));
+        esp_transport_capture_errno(t, errno);
+    }
+    if (ret == 0) {
+        if (poll > 0) {
+            // no error, socket reads 0 while previously detected as readable -> connection has been closed cleanly
+            capture_tcp_transport_error(t, ERR_TCP_TRANSPORT_CONNECTION_CLOSED_BY_FIN);
+        }
+        ret = -1;
+    }
+    return ret;
+}
+
 static int ssl_close(esp_transport_handle_t t)
 {
     int ret = -1;
@@ -229,6 +278,10 @@ static int ssl_close(esp_transport_handle_t t)
         ret = esp_tls_conn_destroy(ssl->tls);
         ssl->conn_state = TRANS_SSL_INIT;
         ssl->ssl_initialized = false;
+        ssl->sockfd = -1;
+    } else if (ssl && ssl->sockfd >= 0) {
+        close(ssl->sockfd);
+        ssl->sockfd = -1;
     }
     return ret;
 }
@@ -344,6 +397,15 @@ static int ssl_get_socket(esp_transport_handle_t t)
     return -1;
 }
 
+static int tcp_get_socket(esp_transport_handle_t t)
+{
+    transport_esp_tls_t *ctx = ssl_get_context_data(t);
+    if (ctx) {
+        return ctx->sockfd;
+    }
+    return -1;
+}
+
 void esp_transport_ssl_set_ds_data(esp_transport_handle_t t, void *ds_data)
 {
     GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
@@ -378,6 +440,7 @@ esp_transport_handle_t esp_transport_ssl_init(void)
 struct transport_esp_tls* esp_transport_esp_tls_create(void)
 {
     transport_esp_tls_t *transport_esp_tls = calloc(1, sizeof(transport_esp_tls_t));
+    transport_esp_tls->sockfd = -1;
     return transport_esp_tls;
 }
 
@@ -392,9 +455,9 @@ esp_transport_handle_t esp_transport_tcp_init(void)
     if (t == NULL) {
         return NULL;
     }
-    esp_transport_set_func(t, tcp_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy);
+    esp_transport_set_func(t, tcp_connect, tcp_read, tcp_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy);
     esp_transport_set_async_connect_func(t, tcp_connect_async);
-    t->_get_socket = ssl_get_socket;
+    t->_get_socket = tcp_get_socket;
     return t;
 }