Bläddra i källkod

Add User-Agent and additional headers to esp_websocket_client

Merges https://github.com/espressif/esp-idf/pull/4345
David N. Junod 6 år sedan
förälder
incheckning
9200250f51

+ 20 - 0
components/esp_websocket_client/esp_websocket_client.c

@@ -63,6 +63,8 @@ typedef struct {
     void                        *user_context;
     int                         network_timeout_ms;
     char                        *subprotocol;
+    char                        *user_agent;
+    char                        *headers;
 } websocket_config_storage_t;
 
 typedef enum {
@@ -179,6 +181,16 @@ static esp_err_t esp_websocket_client_set_config(esp_websocket_client_handle_t c
         cfg->subprotocol = strdup(config->subprotocol);
         ESP_WS_CLIENT_MEM_CHECK(TAG, cfg->subprotocol, return ESP_ERR_NO_MEM);
     }
+    if (config->user_agent) {
+        free(cfg->user_agent);
+        cfg->user_agent = strdup(config->user_agent);
+        ESP_WS_CLIENT_MEM_CHECK(TAG, cfg->user_agent, return ESP_ERR_NO_MEM);
+    }
+    if (config->headers) {
+        free(cfg->headers);
+        cfg->headers = strdup(config->headers);
+        ESP_WS_CLIENT_MEM_CHECK(TAG, cfg->headers, return ESP_ERR_NO_MEM);
+    }
 
     cfg->network_timeout_ms = WEBSOCKET_NETWORK_TIMEOUT_MS;
     cfg->user_context = config->user_context;
@@ -207,6 +219,8 @@ static esp_err_t esp_websocket_client_destroy_config(esp_websocket_client_handle
     free(cfg->username);
     free(cfg->password);
     free(cfg->subprotocol);
+    free(cfg->user_agent);
+    free(cfg->headers);
     memset(cfg, 0, sizeof(websocket_config_storage_t));
     free(client->config);
     client->config = NULL;
@@ -221,6 +235,12 @@ static void set_websocket_transport_optional_settings(esp_websocket_client_handl
     if (trans && client->config->subprotocol) {
         esp_transport_ws_set_subprotocol(trans, client->config->subprotocol);
     }
+    if (trans && client->config->user_agent) {
+        esp_transport_ws_set_user_agent(trans, client->config->user_agent);
+    }
+    if (trans && client->config->headers) {
+        esp_transport_ws_set_headers(trans, client->config->headers);
+    }
 }
 
 esp_websocket_client_handle_t esp_websocket_client_init(const esp_websocket_client_config_t *config)

+ 2 - 0
components/esp_websocket_client/include/esp_websocket_client.h

@@ -92,6 +92,8 @@ typedef struct {
     const char                  *cert_pem;                  /*!< SSL Certification, PEM format as string, if the client requires to verify server */
     esp_websocket_transport_t   transport;                  /*!< Websocket transport type, see `esp_websocket_transport_t */
     char                        *subprotocol;               /*!< Websocket subprotocol */
+    char                        *user_agent;                /*!< Websocket user-agent */
+    char                        *headers;                   /*!< Websocket additional headers */
 } esp_websocket_client_config_t;
 
 /**

+ 24 - 0
components/tcp_transport/include/esp_transport_ws.h

@@ -50,6 +50,30 @@ void esp_transport_ws_set_path(esp_transport_handle_t t, const char *path);
  */
 esp_err_t esp_transport_ws_set_subprotocol(esp_transport_handle_t t, const char *sub_protocol);
 
+/**
+ * @brief               Set websocket user-agent header
+ *
+ * @param t             websocket transport handle
+ * @param sub_protocol  user-agent string
+ *
+ * @return
+ *      - ESP_OK on success
+ *      - One of the error codes
+ */
+esp_err_t esp_transport_ws_set_user_agent(esp_transport_handle_t t, const char *user_agent);
+
+/**
+ * @brief               Set websocket additional headers
+ *
+ * @param t             websocket transport handle
+ * @param sub_protocol  additional header strings each terminated with \r\n
+ *
+ * @return
+ *      - ESP_OK on success
+ *      - One of the error codes
+ */
+esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *headers);
+
 /**
  * @brief               Sends websocket raw message with custom opcode and payload
  *

+ 61 - 4
components/tcp_transport/transport_ws.c

@@ -31,6 +31,8 @@ typedef struct {
     char *path;
     char *buffer;
     char *sub_protocol;
+    char *user_agent;
+    char *headers;
     uint8_t read_opcode;
     esp_transport_handle_t parent;
 } transport_ws_t;
@@ -96,24 +98,27 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
     // Size of base64 coded string is equal '((input_size * 4) / 3) + (input_size / 96) + 6' including Z-term
     unsigned char client_key[28] = {0};
 
+    const char *user_agent_ptr = (ws->user_agent)?(ws->user_agent):"ESP32 Websocket Client";
+
     size_t outlen = 0;
     mbedtls_base64_encode(client_key, sizeof(client_key), &outlen, random_key, sizeof(random_key));
     int len = snprintf(ws->buffer, DEFAULT_WS_BUFFER,
                          "GET %s HTTP/1.1\r\n"
                          "Connection: Upgrade\r\n"
                          "Host: %s:%d\r\n"
+                         "User-Agent: %s\r\n"
                          "Upgrade: websocket\r\n"
                          "Sec-WebSocket-Version: 13\r\n"
-                         "Sec-WebSocket-Key: %s\r\n"
-                         "User-Agent: ESP32 Websocket Client\r\n",
+                         "Sec-WebSocket-Key: %s\r\n",
                          ws->path,
-                         host, port,
+                         host, port, user_agent_ptr,
                          client_key);
     if (len <= 0 || len >= DEFAULT_WS_BUFFER) {
         ESP_LOGE(TAG, "Error in request generation, %d", len);
         return -1;
     }
     if (ws->sub_protocol) {
+        ESP_LOGD(TAG, "sub_protocol: %s", ws->sub_protocol);
         int r = snprintf(ws->buffer + len, DEFAULT_WS_BUFFER - len, "Sec-WebSocket-Protocol: %s\r\n", ws->sub_protocol);
         len += r;
         if (r <= 0 || len >= DEFAULT_WS_BUFFER) {
@@ -122,6 +127,16 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
             return -1;
         }
     }
+    if (ws->headers) {
+        ESP_LOGD(TAG, "headers: %s", ws->headers);
+        int r = snprintf(ws->buffer + len, DEFAULT_WS_BUFFER - len, "%s", ws->headers);
+        len += r;
+        if (r <= 0 || len >= DEFAULT_WS_BUFFER) {
+            ESP_LOGE(TAG, "Error in request generation"
+                          "(strncpy of headers returned %d, desired request len: %d, buffer size: %d", r, len, DEFAULT_WS_BUFFER);
+            return -1;
+        }
+    }
     int r = snprintf(ws->buffer + len, DEFAULT_WS_BUFFER - len, "\r\n");
     len += r;
     if (r <= 0 || len >= DEFAULT_WS_BUFFER) {
@@ -233,7 +248,7 @@ static int _ws_write(esp_transport_handle_t t, int opcode, int mask_flag, const
         for (i = 0; i < len; ++i) {
             buffer[i] = (buffer[i] ^ mask[i % 4]);
         }
-    }    
+    }
     return ret;
 }
 
@@ -352,6 +367,8 @@ static esp_err_t ws_destroy(esp_transport_handle_t t)
     free(ws->buffer);
     free(ws->path);
     free(ws->sub_protocol);
+    free(ws->user_agent);
+    free(ws->headers);
     free(ws);
     return 0;
 }
@@ -409,6 +426,46 @@ esp_err_t esp_transport_ws_set_subprotocol(esp_transport_handle_t t, const char
     return ESP_OK;
 }
 
+esp_err_t esp_transport_ws_set_user_agent(esp_transport_handle_t t, const char *user_agent)
+{
+    if (t == NULL) {
+        return ESP_ERR_INVALID_ARG;
+    }
+    transport_ws_t *ws = esp_transport_get_context_data(t);
+    if (ws->user_agent) {
+        free(ws->user_agent);
+    }
+    if (user_agent == NULL) {
+        ws->user_agent = NULL;
+        return ESP_OK;
+    }
+    ws->user_agent = strdup(user_agent);
+    if (ws->user_agent == NULL) {
+        return ESP_ERR_NO_MEM;
+    }
+    return ESP_OK;
+}
+
+esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *headers)
+{
+    if (t == NULL) {
+        return ESP_ERR_INVALID_ARG;
+    }
+    transport_ws_t *ws = esp_transport_get_context_data(t);
+    if (ws->headers) {
+        free(ws->headers);
+    }
+    if (headers == NULL) {
+        ws->headers = NULL;
+        return ESP_OK;
+    }
+    ws->headers = strdup(headers);
+    if (ws->headers == NULL) {
+        return ESP_ERR_NO_MEM;
+    }
+    return ESP_OK;
+}
+
 ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t)
 {
     transport_ws_t *ws = esp_transport_get_context_data(t);