Просмотр исходного кода

Merge branch 'feature/ws_client_close_frame' into 'master'

ws_client: Added support for close frame, closing connection gracefully

Closes IDF-1915

See merge request espressif/esp-idf!9677
David Čermák 5 лет назад
Родитель
Сommit
a80b25ebbb

+ 121 - 26
components/esp_websocket_client/esp_websocket_client.c

@@ -52,6 +52,8 @@ static const char *TAG = "WEBSOCKET_CLIENT";
         }
 
 const static int STOPPED_BIT = BIT0;
+const static int CLOSE_FRAME_SENT_BIT = BIT1;   // Indicates that a close frame was sent by the client
+                                        // and we are waiting for the server to continue with clean close
 
 ESP_EVENT_DEFINE_BASE(WEBSOCKET_EVENTS);
 
@@ -80,6 +82,7 @@ typedef enum {
     WEBSOCKET_STATE_INIT,
     WEBSOCKET_STATE_CONNECTED,
     WEBSOCKET_STATE_WAIT_TIMEOUT,
+    WEBSOCKET_STATE_CLOSING,
 } websocket_client_state_t;
 
 struct esp_websocket_client {
@@ -493,14 +496,20 @@ static esp_err_t esp_websocket_client_recv(esp_websocket_client_handle_t client)
         const char *data = (client->payload_len == 0) ? NULL : client->rx_buffer;
         esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PONG | WS_TRANSPORT_OPCODES_FIN, data, client->payload_len,
                                   client->config->network_timeout_ms);
-    }
-    else if (client->last_opcode == WS_TRANSPORT_OPCODES_PONG) {
+    } else if (client->last_opcode == WS_TRANSPORT_OPCODES_PONG) {
         client->wait_for_pong_resp = false;
+    } else if (client->last_opcode == WS_TRANSPORT_OPCODES_CLOSE) {
+        ESP_LOGD(TAG, "Received close frame");
+        client->state = WEBSOCKET_STATE_CLOSING;
     }
 
     return ESP_OK;
 }
 
+static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t client, ws_transport_opcodes_t opcode, const uint8_t *data, int len, TickType_t timeout);
+
+static int esp_websocket_client_send_close(esp_websocket_client_handle_t client, int code, const char *additional_data, int total_len, TickType_t timeout);
+
 static void esp_websocket_client_task(void *pv)
 {
     const int lock_timeout = portMAX_DELAY;
@@ -520,7 +529,7 @@ static void esp_websocket_client_task(void *pv)
     }
 
     client->state = WEBSOCKET_STATE_INIT;
-    xEventGroupClearBits(client->status_bits, STOPPED_BIT);
+    xEventGroupClearBits(client->status_bits, STOPPED_BIT | CLOSE_FRAME_SENT_BIT);
     int read_select = 0;
     while (client->run) {
         if (xSemaphoreTakeRecursive(client->lock, lock_timeout) != pdPASS) {
@@ -550,22 +559,25 @@ static void esp_websocket_client_task(void *pv)
 
                 break;
             case WEBSOCKET_STATE_CONNECTED:
-                if (_tick_get_ms() - client->ping_tick_ms > WEBSOCKET_PING_TIMEOUT_MS) {
-                    client->ping_tick_ms = _tick_get_ms();
-                    ESP_LOGD(TAG, "Sending PING...");
-                    esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PING | WS_TRANSPORT_OPCODES_FIN, NULL, 0, client->config->network_timeout_ms);
-
-                    if (!client->wait_for_pong_resp && client->config->pingpong_timeout_sec) {
-                        client->pingpong_tick_ms = _tick_get_ms();
-                        client->wait_for_pong_resp = true;
+                if ((CLOSE_FRAME_SENT_BIT & xEventGroupGetBits(client->status_bits)) == 0) { // only send and check for PING
+                                                                                                          // if closing hasn't been initiated
+                    if (_tick_get_ms() - client->ping_tick_ms > WEBSOCKET_PING_TIMEOUT_MS) {
+                        client->ping_tick_ms = _tick_get_ms();
+                        ESP_LOGD(TAG, "Sending PING...");
+                        esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PING | WS_TRANSPORT_OPCODES_FIN, NULL, 0, client->config->network_timeout_ms);
+
+                        if (!client->wait_for_pong_resp && client->config->pingpong_timeout_sec) {
+                            client->pingpong_tick_ms = _tick_get_ms();
+                            client->wait_for_pong_resp = true;
+                        }
                     }
-                }
 
-                if ( _tick_get_ms() - client->pingpong_tick_ms > client->config->pingpong_timeout_sec*1000 ) {
-                    if (client->wait_for_pong_resp) {
-                        ESP_LOGE(TAG, "Error, no PONG received for more than %d seconds after PING", client->config->pingpong_timeout_sec);
-                        esp_websocket_client_abort_connection(client);
-                        break;
+                    if ( _tick_get_ms() - client->pingpong_tick_ms > client->config->pingpong_timeout_sec*1000 ) {
+                        if (client->wait_for_pong_resp) {
+                            ESP_LOGE(TAG, "Error, no PONG received for more than %d seconds after PING", client->config->pingpong_timeout_sec);
+                            esp_websocket_client_abort_connection(client);
+                            break;
+                        }
                     }
                 }
 
@@ -593,6 +605,17 @@ static void esp_websocket_client_task(void *pv)
                     ESP_LOGD(TAG, "Reconnecting...");
                 }
                 break;
+            case WEBSOCKET_STATE_CLOSING:
+                // if closing not initiated by the client echo the close message back
+                if ((CLOSE_FRAME_SENT_BIT & xEventGroupGetBits(client->status_bits)) == 0) {
+                    ESP_LOGD(TAG, "Closing initiated by the server, sending close frame");
+                    esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_CLOSE | WS_TRANSPORT_OPCODES_FIN, NULL, 0, client->config->network_timeout_ms);
+                    xEventGroupSetBits(client->status_bits, CLOSE_FRAME_SENT_BIT);
+                }
+                break;
+            default:
+                ESP_LOGD(TAG, "Client run iteration in a default state: %d", client->state);
+                break;
         }
         xSemaphoreGiveRecursive(client->lock);
         if (WEBSOCKET_STATE_CONNECTED == client->state) {
@@ -604,6 +627,21 @@ static void esp_websocket_client_task(void *pv)
         } else if (WEBSOCKET_STATE_WAIT_TIMEOUT == client->state) {
             // waiting for reconnecting...
             vTaskDelay(client->wait_timeout_ms / 2 / portTICK_RATE_MS);
+        } else if (WEBSOCKET_STATE_CLOSING == client->state &&
+                  (CLOSE_FRAME_SENT_BIT & xEventGroupGetBits(client->status_bits))) {
+            ESP_LOGD(TAG, " Waiting for TCP connection to be closed by the server");
+            int ret = esp_transport_ws_poll_connection_closed(client->transport, 1000);
+            if (ret == 0) {
+                // still waiting
+                break;
+            }
+            if (ret < 0) {
+                ESP_LOGW(TAG, "Connection terminated while waiting for clean TCP close");
+            }
+            client->run = false;
+            client->state = WEBSOCKET_STATE_UNKNOW;
+            esp_websocket_client_dispatch_event(client, WEBSOCKET_EVENT_CLOSED, NULL, 0);
+            break;
         }
     }
 
@@ -626,7 +664,7 @@ esp_err_t esp_websocket_client_start(esp_websocket_client_handle_t client)
         ESP_LOGE(TAG, "Error create websocket task");
         return ESP_FAIL;
     }
-    xEventGroupClearBits(client->status_bits, STOPPED_BIT);
+    xEventGroupClearBits(client->status_bits, STOPPED_BIT | CLOSE_FRAME_SENT_BIT);
     return ESP_OK;
 }
 
@@ -645,30 +683,87 @@ esp_err_t esp_websocket_client_stop(esp_websocket_client_handle_t client)
     return ESP_OK;
 }
 
-static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t client, ws_transport_opcodes_t opcode, const char *data, int len, TickType_t timeout);
+static int esp_websocket_client_send_close(esp_websocket_client_handle_t client, int code, const char *additional_data, int total_len, TickType_t timeout)
+{
+    uint8_t *close_status_data = NULL;
+    // RFC6455#section-5.5.1: The Close frame MAY contain a body (indicated by total_len >= 2)
+    if (total_len >= 2) {
+        close_status_data = calloc(1, total_len);
+        ESP_WS_CLIENT_MEM_CHECK(TAG, close_status_data, return -1);
+        // RFC6455#section-5.5.1: The first two bytes of the body MUST be a 2-byte representing a status
+        uint16_t *code_network_order = (uint16_t *) close_status_data;
+        *code_network_order = htons(code);
+        memcpy(close_status_data + 2, additional_data, total_len - 2);
+    }
+    int ret = esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_CLOSE, close_status_data, total_len, timeout);
+    free(close_status_data);
+    return ret;
+}
+
+
+static esp_err_t esp_websocket_client_close_with_optional_body(esp_websocket_client_handle_t client, bool send_body, int code, const char *data, int len, TickType_t timeout)
+{
+    if (client == NULL) {
+        return ESP_ERR_INVALID_ARG;
+    }
+    if (!client->run) {
+        ESP_LOGW(TAG, "Client was not started");
+        return ESP_FAIL;
+    }
+
+    if (send_body) {
+        esp_websocket_client_send_close(client, code, data, len + 2, portMAX_DELAY); // len + 2 -> always sending the code
+    } else {
+        esp_websocket_client_send_close(client, 0, NULL, 0, portMAX_DELAY); // only opcode frame
+    }
+
+    // Set closing bit to prevent from sending PING frames while connected
+    xEventGroupSetBits(client->status_bits, CLOSE_FRAME_SENT_BIT);
+
+    if (STOPPED_BIT & xEventGroupWaitBits(client->status_bits, STOPPED_BIT, false, true, timeout)) {
+        return ESP_OK;
+    }
+
+    // If could not close gracefully within timeout, stop the client and disconnect
+    client->run = false;
+    xEventGroupWaitBits(client->status_bits, STOPPED_BIT, false, true, portMAX_DELAY);
+    client->state = WEBSOCKET_STATE_UNKNOW;
+    return ESP_OK;
+}
+
+esp_err_t esp_websocket_client_close_with_code(esp_websocket_client_handle_t client, int code, const char *data, int len, TickType_t timeout)
+{
+    return esp_websocket_client_close_with_optional_body(client, true, code, data, len, timeout);
+}
+
+esp_err_t esp_websocket_client_close(esp_websocket_client_handle_t client, TickType_t timeout)
+{
+    return esp_websocket_client_close_with_optional_body(client, false, 0, NULL, 0, timeout);
+}
 
 int esp_websocket_client_send_text(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout)
 {
-    return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_TEXT, data, len, timeout);
+    return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_TEXT, (const uint8_t *)data, len, timeout);
 }
 
 int esp_websocket_client_send(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout)
 {
-    return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_BINARY, data, len, timeout);
+    return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_BINARY, (const uint8_t *)data, len, timeout);
 }
 
 int esp_websocket_client_send_bin(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout)
 {
-    return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_BINARY, data, len, timeout);
+    return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_BINARY, (const uint8_t *)data, len, timeout);
 }
 
-static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t client, ws_transport_opcodes_t opcode, const char *data, int len, TickType_t timeout)
+static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t client, ws_transport_opcodes_t opcode, const uint8_t *data, int len, TickType_t timeout)
 {
     int need_write = len;
     int wlen = 0, widx = 0;
     int ret = ESP_FAIL;
 
-    if (client == NULL || data == NULL || len <= 0) {
+    if (client == NULL || len < 0 ||
+        (opcode != WS_TRANSPORT_OPCODES_CLOSE && (data == NULL || len <= 0))) {
         ESP_LOGE(TAG, "Invalid arguments");
         return ESP_FAIL;
     }
@@ -688,7 +783,7 @@ static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t c
         goto unlock_and_return;
     }
     uint32_t current_opcode = opcode;
-    while (widx < len) {
+    while (widx < len || current_opcode) {  // allow for sending "current_opcode" only message with len==0
         if (need_write > client->buffer_size) {
             need_write = client->buffer_size;
         } else {
@@ -698,7 +793,7 @@ static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t c
         // send with ws specific way and specific opcode
         wlen = esp_transport_ws_send_raw(client->transport, current_opcode, (char *)client->tx_buffer, need_write,
                                         (timeout==portMAX_DELAY)? -1 : timeout * portTICK_PERIOD_MS);
-        if (wlen <= 0) {
+        if (wlen < 0 || (wlen == 0 && need_write != 0)) {
             ret = wlen;
             ESP_LOGE(TAG, "Network error: esp_transport_write() returned %d, errno=%d", ret, errno);
             esp_websocket_client_abort_connection(client);

+ 36 - 1
components/esp_websocket_client/include/esp_websocket_client.h

@@ -40,6 +40,7 @@ typedef enum {
     WEBSOCKET_EVENT_CONNECTED,      /*!< Once the Websocket has been connected to the server, no data exchange has been performed */
     WEBSOCKET_EVENT_DISCONNECTED,   /*!< The connection has been disconnected */
     WEBSOCKET_EVENT_DATA,           /*!< When receiving data from the server, possibly multiple portions of the packet */
+    WEBSOCKET_EVENT_CLOSED,         /*!< The connection has been closed cleanly */
     WEBSOCKET_EVENT_MAX
 } esp_websocket_event_id_t;
 
@@ -125,7 +126,11 @@ esp_err_t esp_websocket_client_set_uri(esp_websocket_client_handle_t client, con
 esp_err_t esp_websocket_client_start(esp_websocket_client_handle_t client);
 
 /**
- * @brief      Close the WebSocket connection
+ * @brief      Stops the WebSocket connection without websocket closing handshake
+ *
+ * This API stops ws client and closes TCP connection directly without sending
+ * close frames. It is a good practice to close the connection in a clean way
+ * using esp_websocket_client_close().
  *
  * @param[in]  client  The client
  *
@@ -187,6 +192,36 @@ int esp_websocket_client_send_bin(esp_websocket_client_handle_t client, const ch
  */
 int esp_websocket_client_send_text(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout);
 
+/**
+ * @brief      Close the WebSocket connection in a clean way
+ *
+ * Sequence of clean close initiated by client:
+ * * Client sends CLOSE frame
+ * * Client waits until server echos the CLOSE frame
+ * * Client waits until server closes the connection
+ * * Client is stopped the same way as by the `esp_websocket_client_stop()`
+ *
+ * @param[in]  client  The client
+ * @param[in]  timeout Timeout in RTOS ticks for waiting
+ *
+ * @return     esp_err_t
+ */
+esp_err_t esp_websocket_client_close(esp_websocket_client_handle_t client, TickType_t timeout);
+
+/**
+ * @brief      Close the WebSocket connection in a clean way with custom code/data
+ *             Closing sequence is the same as for esp_websocket_client_close()
+ *
+ * @param[in]  client  The client
+ * @param[in]  code    Close status code as defined in RFC6455 section-7.4
+ * @param[in]  data    Additional data to closing message
+ * @param[in]  len     The length of the additional data
+ * @param[in]  timeout Timeout in RTOS ticks for waiting
+ *
+ * @return     esp_err_t
+ */
+esp_err_t esp_websocket_client_close_with_code(esp_websocket_client_handle_t client, int code, const char *data, int len, TickType_t timeout);
+
 /**
  * @brief      Check the WebSocket client connection state
  *

+ 1 - 1
components/tcp_transport/include/esp_transport.h

@@ -310,7 +310,7 @@ esp_err_t esp_transport_set_parent_transport_func(esp_transport_handle_t t, payl
  * @return
  *            - valid pointer of esp_error_handle_t
  *            - NULL if invalid transport handle
-  */
+ */
 esp_tls_error_handle_t esp_transport_get_error_handle(esp_transport_handle_t t);
 
 

+ 15 - 0
components/tcp_transport/include/esp_transport_ws.h

@@ -117,6 +117,21 @@ ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t
  */
 int esp_transport_ws_get_read_payload_len(esp_transport_handle_t t);
 
+/**
+ * @brief               Polls the active connection for termination
+ *
+ * This API is typically used by the client to wait for clean connection closure
+ * by websocket server
+ *
+ * @param t             Websocket transport handle
+ * @param[in] timeout_ms The timeout milliseconds
+ *
+ * @return
+ *      0 - no activity on read and error socket descriptor within timeout
+ *      1 - Success: either connection terminated by FIN or the most common RST err codes
+ *      -1 - Failure: Unexpected error code or socket is normally readable
+ */
+int esp_transport_ws_poll_connection_closed(esp_transport_handle_t t, int timeout_ms);
 
 #ifdef __cplusplus
 }

+ 56 - 0
components/tcp_transport/private_include/esp_transport_internal.h

@@ -0,0 +1,56 @@
+// Copyright 2020 Espressif Systems (Shanghai) PTE LTD
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef _ESP_TRANSPORT_INTERNAL_H_
+#define _ESP_TRANSPORT_INTERNAL_H_
+
+#include "esp_transport.h"
+#include "sys/queue.h"
+
+typedef int (*get_socket_func)(esp_transport_handle_t t);
+
+/**
+ * Transport layer structure, which will provide functions, basic properties for transport types
+ */
+struct esp_transport_item_t {
+    int             port;
+    char            *scheme;        /*!< Tag name */
+    void            *data;          /*!< Additional transport data */
+    connect_func    _connect;       /*!< Connect function of this transport */
+    io_read_func    _read;          /*!< Read */
+    io_func         _write;         /*!< Write */
+    trans_func      _close;         /*!< Close */
+    poll_func       _poll_read;     /*!< Poll and read */
+    poll_func       _poll_write;    /*!< Poll and write */
+    trans_func      _destroy;       /*!< Destroy and free transport */
+    connect_async_func _connect_async;      /*!< non-blocking connect function of this transport */
+    payload_transfer_func  _parent_transfer;        /*!< Function returning underlying transport layer */
+    get_socket_func        _get_socket;
+    esp_tls_error_handle_t     error_handle;            /*!< Pointer to esp-tls error handle */
+
+    STAILQ_ENTRY(esp_transport_item_t) next;
+};
+
+/**
+ * @brief Returns underlying socket for the supplied transport handle
+ *
+ * @param t Transport handle
+ *
+ * @return Socket file descriptor in case of success
+ *         -1 in case of error
+ */
+int esp_transport_get_socket(esp_transport_handle_t t);
+
+
+#endif //_ESP_TRANSPORT_INTERNAL_H_

+ 3 - 3
components/tcp_transport/private_include/esp_transport_ssl_internal.h

@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#ifndef _ESP_TRANSPORT_INTERNAL_H_
-#define _ESP_TRANSPORT_INTERNAL_H_
+#ifndef _ESP_TRANSPORT_SSL_INTERNAL_H_
+#define _ESP_TRANSPORT_SSL_INTERNAL_H_
 
 /**
  * @brief      Sets error to common transport handle
@@ -27,4 +27,4 @@
 void esp_transport_set_errors(esp_transport_handle_t t, const esp_tls_error_handle_t error_handle);
 
 
-#endif /* _ESP_TRANSPORT_INTERNAL_H_ */
+#endif /* _ESP_TRANSPORT_SSL_INTERNAL_H_ */

+ 10 - 23
components/tcp_transport/transport.c

@@ -21,32 +21,11 @@
 #include "esp_log.h"
 
 #include "esp_transport.h"
+#include "esp_transport_internal.h"
 #include "esp_transport_utils.h"
 
 static const char *TAG = "TRANSPORT";
 
-/**
- * Transport layer structure, which will provide functions, basic properties for transport types
- */
-struct esp_transport_item_t {
-    int             port;
-    int             socket;         /*!< Socket to use in this transport */
-    char            *scheme;        /*!< Tag name */
-    void            *context;       /*!< Context data */
-    void            *data;          /*!< Additional transport data */
-    connect_func    _connect;       /*!< Connect function of this transport */
-    io_read_func    _read;          /*!< Read */
-    io_func         _write;         /*!< Write */
-    trans_func      _close;         /*!< Close */
-    poll_func       _poll_read;     /*!< Poll and read */
-    poll_func       _poll_write;    /*!< Poll and write */
-    trans_func      _destroy;       /*!< Destroy and free transport */
-    connect_async_func _connect_async;      /*!< non-blocking connect function of this transport */
-    payload_transfer_func  _parent_transfer;        /*!< Function returning underlying transport layer */
-    esp_tls_error_handle_t     error_handle;            /*!< Pointer to esp-tls error handle */
-
-    STAILQ_ENTRY(esp_transport_item_t) next;
-};
 
 
 /**
@@ -305,4 +284,12 @@ void esp_transport_set_errors(esp_transport_handle_t t, const esp_tls_error_hand
     if (t)  {
         memcpy(t->error_handle, error_handle, sizeof(esp_tls_last_error_t));
     }
-}
+}
+
+int esp_transport_get_socket(esp_transport_handle_t t)
+{
+    if (t && t->_get_socket)  {
+        return  t->_get_socket(t);
+    }
+    return -1;
+}

+ 13 - 0
components/tcp_transport/transport_ssl.c

@@ -25,6 +25,7 @@
 #include "esp_transport_ssl.h"
 #include "esp_transport_utils.h"
 #include "esp_transport_ssl_internal.h"
+#include "esp_transport_internal.h"
 
 static const char *TAG = "TRANS_SSL";
 
@@ -288,6 +289,17 @@ void esp_transport_ssl_use_secure_element(esp_transport_handle_t t)
     }
 }
 
+static int ssl_get_socket(esp_transport_handle_t t)
+{
+    if (t) {
+        transport_ssl_t *ssl = t->data;
+        if (ssl && ssl->tls) {
+            return ssl->tls->sockfd;
+        }
+    }
+    return -1;
+}
+
 esp_transport_handle_t esp_transport_ssl_init(void)
 {
     esp_transport_handle_t t = esp_transport_init();
@@ -296,6 +308,7 @@ esp_transport_handle_t esp_transport_ssl_init(void)
     esp_transport_set_context_data(t, ssl);
     esp_transport_set_func(t, ssl_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy);
     esp_transport_set_async_connect_func(t, ssl_connect_async);
+    t->_get_socket = ssl_get_socket;
     return t;
 }
 

+ 13 - 0
components/tcp_transport/transport_tcp.c

@@ -25,6 +25,7 @@
 
 #include "esp_transport_utils.h"
 #include "esp_transport.h"
+#include "esp_transport_internal.h"
 
 static const char *TAG = "TRANS_TCP";
 
@@ -234,6 +235,17 @@ static esp_err_t tcp_destroy(esp_transport_handle_t t)
     return 0;
 }
 
+static int tcp_get_socket(esp_transport_handle_t t)
+{
+    if (t) {
+        transport_tcp_t *tcp = t->data;
+        if (tcp) {
+            return tcp->sock;
+        }
+    }
+    return -1;
+}
+
 esp_transport_handle_t esp_transport_tcp_init(void)
 {
     esp_transport_handle_t t = esp_transport_init();
@@ -242,6 +254,7 @@ esp_transport_handle_t esp_transport_tcp_init(void)
     tcp->sock = -1;
     esp_transport_set_func(t, tcp_connect, tcp_read, tcp_write, tcp_close, tcp_poll_read, tcp_poll_write, tcp_destroy);
     esp_transport_set_context_data(t, tcp);
+    t->_get_socket = tcp_get_socket;
 
     return t;
 }

+ 52 - 0
components/tcp_transport/transport_ws.c

@@ -2,6 +2,7 @@
 #include <string.h>
 #include <ctype.h>
 #include <sys/random.h>
+#include <sys/socket.h>
 #include "esp_log.h"
 #include "esp_transport.h"
 #include "esp_transport_tcp.h"
@@ -9,6 +10,8 @@
 #include "esp_transport_utils.h"
 #include "mbedtls/base64.h"
 #include "mbedtls/sha1.h"
+#include "esp_transport_internal.h"
+#include "errno.h"
 
 static const char *TAG = "TRANSPORT_WS";
 
@@ -449,6 +452,17 @@ void esp_transport_ws_set_path(esp_transport_handle_t t, const char *path)
     strcpy(ws->path, path);
 }
 
+static int ws_get_socket(esp_transport_handle_t t)
+{
+    if (t) {
+        transport_ws_t *ws = t->data;
+        if (ws && ws->parent && ws->parent->_get_socket) {
+            return ws->parent->_get_socket(ws->parent);
+        }
+    }
+    return -1;
+}
+
 esp_transport_handle_t esp_transport_ws_init(esp_transport_handle_t parent_handle)
 {
     esp_transport_handle_t t = esp_transport_init();
@@ -473,6 +487,7 @@ esp_transport_handle_t esp_transport_ws_init(esp_transport_handle_t parent_handl
     esp_transport_set_parent_transport_func(t, ws_get_payload_transport_handle);
 
     esp_transport_set_context_data(t, ws);
+    t->_get_socket = ws_get_socket;
     return t;
 }
 
@@ -548,4 +563,41 @@ int esp_transport_ws_get_read_payload_len(esp_transport_handle_t t)
     return ws->frame_state.payload_len;
 }
 
+int esp_transport_ws_poll_connection_closed(esp_transport_handle_t t, int timeout_ms)
+{
+    struct timeval timeout;
+    int sock = esp_transport_get_socket(t);
+    fd_set readset;
+    fd_set errset;
+    FD_ZERO(&readset);
+    FD_ZERO(&errset);
+    FD_SET(sock, &readset);
+    FD_SET(sock, &errset);
+
+    int ret = select(sock + 1, &readset, NULL, &errset, esp_transport_utils_ms_to_timeval(timeout_ms, &timeout));
+    if (ret > 0) {
+        if (FD_ISSET(sock, &readset)) {
+            uint8_t buffer;
+            if (recv(sock, &buffer, 1, MSG_PEEK) <= 0) {
+                // socket is readable, but reads zero bytes -- connection cleanly closed by FIN flag
+                return 1;
+            }
+            ESP_LOGW(TAG, "esp_transport_ws_poll_connection_closed: unexpected data readable on socket=%d", sock);
+        } else if (FD_ISSET(sock, &errset)) {
+            int sock_errno = 0;
+            uint32_t optlen = sizeof(sock_errno);
+            getsockopt(sock, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen);
+            ESP_LOGD(TAG, "esp_transport_ws_poll_connection_closed select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), sock);
+            if (sock_errno == ENOTCONN || sock_errno == ECONNRESET || sock_errno == ECONNABORTED) {
+                // the three err codes above might be caused by connection termination by RTS flag
+                // which we still assume as expected closing sequence of ws-transport connection
+                return 1;
+            }
+            ESP_LOGE(TAG, "esp_transport_ws_poll_connection_closed: unexpected errno=%d on socket=%d", sock_errno, sock);
+        }
+        return -1; // indicates error: socket unexpectedly reads an actual data, or unexpected errno code
+    }
+    return ret;
+
+}
 

+ 36 - 146
examples/protocols/websocket/example_test.py

@@ -3,12 +3,10 @@ from __future__ import unicode_literals
 import re
 import os
 import socket
-import select
-import hashlib
-import base64
-import queue
 import random
 import string
+from SimpleWebSocketServer import SimpleWebSocketServer, WebSocket
+from tiny_test_fw import Utility
 from threading import Thread, Event
 import ttfw_idf
 
@@ -26,159 +24,45 @@ def get_my_ip():
     return IP
 
 
-# Simple Websocket server for testing purposes
-class Websocket:
-    HEADER_LEN = 6
+class TestEcho(WebSocket):
 
-    def __init__(self, port):
-        self.port = port
-        self.socket = socket.socket()
-        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        self.socket.settimeout(10.0)
-        self.send_q = queue.Queue()
-        self.shutdown = Event()
+    def handleMessage(self):
+        self.sendMessage(self.data)
+        print('Server sent: {}'.format(self.data))
 
-    def __enter__(self):
-        try:
-            self.socket.bind(('', self.port))
-        except socket.error as e:
-            print("Bind failed:{}".format(e))
-            raise
+    def handleConnected(self):
+        print('Connection from: {}'.format(self.address))
 
-        self.socket.listen(1)
-        self.server_thread = Thread(target=self.run_server)
-        self.server_thread.start()
+    def handleClose(self):
+        print('{} closed the connection'.format(self.address))
 
-        return self
 
-    def __exit__(self, exc_type, exc_value, traceback):
-        self.shutdown.set()
-        self.server_thread.join()
-        self.socket.close()
-        self.conn.close()
-
-    def run_server(self):
-        self.conn, address = self.socket.accept()  # accept new connection
-        self.socket.settimeout(10.0)
-
-        print("Connection from: {}".format(address))
-
-        self.establish_connection()
-        print("WS established")
-        # Handle connection until client closes it, will echo any data received and send data from send_q queue
-        self.handle_conn()
-
-    def establish_connection(self):
-        while not self.shutdown.is_set():
-            try:
-                # receive data stream. it won't accept data packet greater than 1024 bytes
-                data = self.conn.recv(1024).decode()
-                if not data:
-                    # exit if data is not received
-                    raise
-
-                if "Upgrade: websocket" in data and "Connection: Upgrade" in data:
-                    self.handshake(data)
-                    return
-
-            except socket.error as err:
-                print("Unable to establish a websocket connection: {}".format(err))
-                raise
-
-    def handshake(self, data):
-        # Magic string from RFC
-        MAGIC_STRING = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
-        headers = data.split("\r\n")
-
-        for header in headers:
-            if "Sec-WebSocket-Key" in header:
-                client_key = header.split()[1]
-
-        if client_key:
-            resp_key = client_key + MAGIC_STRING
-            resp_key = base64.standard_b64encode(hashlib.sha1(resp_key.encode()).digest())
-
-            resp = "HTTP/1.1 101 Switching Protocols\r\n" + \
-                "Upgrade: websocket\r\n" + \
-                "Connection: Upgrade\r\n" + \
-                "Sec-WebSocket-Accept: {}\r\n\r\n".format(resp_key.decode())
-
-            self.conn.send(resp.encode())
-
-    def handle_conn(self):
-        while not self.shutdown.is_set():
-            r,w,e = select.select([self.conn], [], [], 1)
-            try:
-                if self.conn in r:
-                    self.echo_data()
-
-                if not self.send_q.empty():
-                    self._send_data_(self.send_q.get())
-
-            except socket.error as err:
-                print("Stopped echoing data: {}".format(err))
-                raise
-
-    def echo_data(self):
-        header = bytearray(self.conn.recv(self.HEADER_LEN, socket.MSG_WAITALL))
-        if not header:
-            # exit if socket closed by peer
-            return
-
-        # Remove mask bit
-        payload_len = ~(1 << 7) & header[1]
-
-        payload = bytearray(self.conn.recv(payload_len, socket.MSG_WAITALL))
-
-        if not payload:
-            # exit if socket closed by peer
-            return
-        frame = header + payload
-
-        decoded_payload = self.decode_frame(frame)
-        print("Sending echo...")
-        self._send_data_(decoded_payload)
-
-    def _send_data_(self, data):
-        frame = self.encode_frame(data)
-        self.conn.send(frame)
+# Simple Websocket server for testing purposes
+class Websocket(object):
 
     def send_data(self, data):
-        self.send_q.put(data.encode())
-
-    def decode_frame(self, frame):
-        # Mask out MASK bit from payload length, this len is only valid for short messages (<126)
-        payload_len = ~(1 << 7) & frame[1]
-
-        mask = frame[2:self.HEADER_LEN]
-
-        encrypted_payload = frame[self.HEADER_LEN:self.HEADER_LEN + payload_len]
-        payload = bytearray()
-
-        for i in range(payload_len):
-            payload.append(encrypted_payload[i] ^ mask[i % 4])
+        for nr, conn in self.server.connections.items():
+            conn.sendMessage(data)
 
-        return payload
+    def run(self):
+        self.server = SimpleWebSocketServer('', self.port, TestEcho)
+        while not self.exit_event.is_set():
+            self.server.serveonce()
 
-    def encode_frame(self, payload):
-        # Set FIN = 1 and OP_CODE = 1 (text)
-        header = (1 << 7) | (1 << 0)
-
-        frame = bytearray([header])
-        payload_len = len(payload)
-
-        # If payload len is longer than 125 then the next 16 bits are used to encode length
-        if payload_len > 125:
-            frame.append(126)
-            frame.append(payload_len >> 8)
-            frame.append(0xFF & payload_len)
-
-        else:
-            frame.append(payload_len)
+    def __init__(self, port):
+        self.port = port
+        self.exit_event = Event()
+        self.thread = Thread(target=self.run)
+        self.thread.start()
 
-        frame += payload
+    def __enter__(self):
+        return self
 
-        return frame
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.exit_event.set()
+        self.thread.join(10)
+        if self.thread.is_alive():
+            Utility.console_log('Thread cannot be joined', 'orange')
 
 
 def test_echo(dut):
@@ -188,6 +72,11 @@ def test_echo(dut):
     print("All echos received")
 
 
+def test_close(dut):
+    code = dut.expect(re.compile(r"WEBSOCKET: Received closed message with code=(\d*)"), timeout=60)[0]
+    print("Received close frame with code {}".format(code))
+
+
 def test_recv_long_msg(dut, websocket, msg_len, repeats):
     send_msg = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(msg_len))
 
@@ -246,6 +135,7 @@ def test_examples_protocol_websocket(env, extra_data):
             test_echo(dut1)
             # Message length should exceed DUT's buffer size to test fragmentation, default is 1024 byte
             test_recv_long_msg(dut1, ws, 2000, 3)
+            test_close(dut1)
 
     else:
         print("DUT connecting to {}".format(uri))

+ 6 - 2
examples/protocols/websocket/main/websocket_example.c

@@ -69,7 +69,11 @@ static void websocket_event_handler(void *handler_args, esp_event_base_t base, i
     case WEBSOCKET_EVENT_DATA:
         ESP_LOGI(TAG, "WEBSOCKET_EVENT_DATA");
         ESP_LOGI(TAG, "Received opcode=%d", data->op_code);
-        ESP_LOGW(TAG, "Received=%.*s", data->data_len, (char *)data->data_ptr);
+        if (data->op_code == 0x08 && data->data_len == 2) {
+            ESP_LOGW(TAG, "Received closed message with code=%d", 256*data->data_ptr[0] + data->data_ptr[1]);
+        } else {
+            ESP_LOGW(TAG, "Received=%.*s", data->data_len, (char *)data->data_ptr);
+        }
         ESP_LOGW(TAG, "Total payload length=%d, data_len=%d, current payload offset=%d\r\n", data->payload_len, data->data_len, data->payload_offset);
 
         xTimerReset(shutdown_signal_timer, portMAX_DELAY);
@@ -121,7 +125,7 @@ static void websocket_app_start(void)
     }
 
     xSemaphoreTake(shutdown_sema, portMAX_DELAY);
-    esp_websocket_client_stop(client);
+    esp_websocket_client_close(client, portMAX_DELAY);
     ESP_LOGI(TAG, "Websocket Stopped");
     esp_websocket_client_destroy(client);
 }