Просмотр исходного кода

Merge branch 'bugfix/ws_client_fixes' into 'master'

ws_client: various fixes

See merge request espressif/esp-idf!5879
Ivan Grokhotkov 6 лет назад
Родитель
Сommit
d77a7c23da

+ 24 - 13
components/esp_websocket_client/esp_websocket_client.c

@@ -63,6 +63,7 @@ typedef struct {
     bool                        auto_reconnect;
     void                        *user_context;
     int                         network_timeout_ms;
+    char                        *subprotocol;
 } websocket_config_storage_t;
 
 typedef enum {
@@ -172,6 +173,11 @@ static esp_err_t esp_websocket_client_set_config(esp_websocket_client_handle_t c
         cfg->path = strdup(config->path);
         ESP_WS_CLIENT_MEM_CHECK(TAG, cfg->path, return ESP_ERR_NO_MEM);
     }
+    if (config->subprotocol) {
+        free(cfg->subprotocol);
+        cfg->subprotocol = strdup(config->subprotocol);
+        ESP_WS_CLIENT_MEM_CHECK(TAG, cfg->subprotocol, return ESP_ERR_NO_MEM);
+    }
 
     cfg->network_timeout_ms = WEBSOCKET_NETWORK_TIMEOUT_MS;
     cfg->user_context = config->user_context;
@@ -199,12 +205,23 @@ static esp_err_t esp_websocket_client_destroy_config(esp_websocket_client_handle
     free(cfg->scheme);
     free(cfg->username);
     free(cfg->password);
+    free(cfg->subprotocol);
     memset(cfg, 0, sizeof(websocket_config_storage_t));
     free(client->config);
     client->config = NULL;
     return ESP_OK;
 }
 
+static void set_websocket_transport_optional_settings(esp_websocket_client_handle_t client, esp_transport_handle_t trans)
+{
+    if (trans && client->config->path) {
+        esp_transport_ws_set_path(trans, client->config->path);
+    }
+    if (trans && client->config->subprotocol) {
+        esp_transport_ws_set_subprotocol(trans, client->config->subprotocol);
+    }
+}
+
 esp_websocket_client_handle_t esp_websocket_client_init(const esp_websocket_client_config_t *config)
 {
     esp_websocket_client_handle_t client = calloc(1, sizeof(struct esp_websocket_client));
@@ -224,6 +241,9 @@ esp_websocket_client_handle_t esp_websocket_client_init(const esp_websocket_clie
     client->lock = xSemaphoreCreateMutex();
     ESP_WS_CLIENT_MEM_CHECK(TAG, client->lock, goto _websocket_init_fail);
 
+    client->config = calloc(1, sizeof(websocket_config_storage_t));
+    ESP_WS_CLIENT_MEM_CHECK(TAG, client->config, goto _websocket_init_fail);
+
     client->transport_list = esp_transport_list_init();
     ESP_WS_CLIENT_MEM_CHECK(TAG, client->transport_list, goto _websocket_init_fail);
 
@@ -259,14 +279,11 @@ esp_websocket_client_handle_t esp_websocket_client_init(const esp_websocket_clie
     esp_transport_set_default_port(wss, WEBSOCKET_SSL_DEFAULT_PORT);
 
     esp_transport_list_add(client->transport_list, wss, "wss");
-    if (config->transport == WEBSOCKET_TRANSPORT_OVER_TCP) {
+    if (config->transport == WEBSOCKET_TRANSPORT_OVER_SSL) {
         asprintf(&client->config->scheme, "wss");
         ESP_WS_CLIENT_MEM_CHECK(TAG, client->config->scheme, goto _websocket_init_fail);
     }
 
-    client->config = calloc(1, sizeof(websocket_config_storage_t));
-    ESP_WS_CLIENT_MEM_CHECK(TAG, client->config, goto _websocket_init_fail);
-
     if (config->uri) {
         if (esp_websocket_client_set_uri(client, config->uri) != ESP_OK) {
             ESP_LOGE(TAG, "Invalid uri");
@@ -284,6 +301,9 @@ esp_websocket_client_handle_t esp_websocket_client_init(const esp_websocket_clie
         ESP_WS_CLIENT_MEM_CHECK(TAG, client->config->scheme, goto _websocket_init_fail);
     }
 
+    set_websocket_transport_optional_settings(client, esp_transport_list_get_transport(client->transport_list, "ws"));
+    set_websocket_transport_optional_settings(client, esp_transport_list_get_transport(client->transport_list, "wss"));
+
     client->keepalive_tick_ms = _tick_get_ms();
     client->reconnect_tick_ms = _tick_get_ms();
     client->ping_tick_ms = _tick_get_ms();
@@ -366,15 +386,6 @@ esp_err_t esp_websocket_client_set_uri(esp_websocket_client_handle_t client, con
         free(client->config->path);
         asprintf(&client->config->path, "%.*s", puri.field_data[UF_PATH].len, uri + puri.field_data[UF_PATH].off);
         ESP_WS_CLIENT_MEM_CHECK(TAG, client->config->path, return ESP_ERR_NO_MEM);
-
-        esp_transport_handle_t trans = esp_transport_list_get_transport(client->transport_list, "ws");
-        if (trans) {
-            esp_transport_ws_set_path(trans, client->config->path);
-        }
-        trans = esp_transport_list_get_transport(client->transport_list, "wss");
-        if (trans) {
-            esp_transport_ws_set_path(trans, client->config->path);
-        }
     }
     if (puri.field_data[UF_PORT].off) {
         client->config->port = strtol((const char*)(uri + puri.field_data[UF_PORT].off), NULL, 10);

+ 1 - 0
components/esp_websocket_client/include/esp_websocket_client.h

@@ -91,6 +91,7 @@ typedef struct {
     int                         buffer_size;                /*!< Websocket buffer size */
     const char                  *cert_pem;                  /*!< SSL Certification, PEM format as string, if the client requires to verify server */
     esp_websocket_transport_t   transport;                  /*!< Websocket transport type, see `esp_websocket_transport_t */
+    char                        *subprotocol;               /*!< Websocket subprotocol */
 } esp_websocket_client_config_t;
 
 /**

+ 0 - 1
components/tcp_transport/include/esp_transport_ws.h

@@ -43,7 +43,6 @@ void esp_transport_ws_set_path(esp_transport_handle_t t, const char *path);
  */
 esp_err_t esp_transport_ws_set_subprotocol(esp_transport_handle_t t, const char *sub_protocol);
 
-
 #ifdef __cplusplus
 }
 #endif

+ 11 - 12
components/tcp_transport/transport_ws.c

@@ -188,18 +188,17 @@ static int _ws_write(esp_transport_handle_t t, int opcode, int mask_flag, const
         ws_header[header_len++] = (uint8_t)((len >> 8) & 0xFF);
         ws_header[header_len++] = (uint8_t)((len >> 0) & 0xFF);
     }
-    if (len) {
-        if (mask_flag) {
-            mask = &ws_header[header_len];
-            getrandom(ws_header + header_len, 4, 0);
-            header_len += 4;
-
-            for (i = 0; i < len; ++i) {
-                buffer[i] = (buffer[i] ^ mask[i % 4]);
-            }
-        }
 
+    if (mask_flag) {
+        mask = &ws_header[header_len];
+        getrandom(ws_header + header_len, 4, 0);
+        header_len += 4;
+
+        for (i = 0; i < len; ++i) {
+            buffer[i] = (buffer[i] ^ mask[i % 4]);
+        }
     }
+
     if (esp_transport_write(ws->parent, ws_header, header_len, timeout_ms) != header_len) {
         ESP_LOGE(TAG, "Error write header");
         return -1;
@@ -224,7 +223,7 @@ static int ws_write(esp_transport_handle_t t, const char *b, int len, int timeou
 {
     if (len == 0) {
         ESP_LOGD(TAG, "Write PING message");
-        return _ws_write(t, WS_OPCODE_PING | WS_FIN, 0, NULL, 0, timeout_ms);
+        return _ws_write(t, WS_OPCODE_PING | WS_FIN, WS_MASK, NULL, 0, timeout_ms);
     }
     return _ws_write(t, WS_OPCODE_BINARY | WS_FIN, WS_MASK, b, len, timeout_ms);
 }
@@ -282,7 +281,7 @@ static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_
     }
 
     // Then receive and process payload
-    if ((rlen = esp_transport_read(ws->parent, buffer, payload_len, timeout_ms)) <= 0) {
+    if (payload_len != 0 && (rlen = esp_transport_read(ws->parent, buffer, payload_len, timeout_ms)) <= 0) {
         ESP_LOGE(TAG, "Error read data");
         return rlen;
     }