|
|
@@ -23,6 +23,10 @@
|
|
|
#include "esp_transport_utils.h"
|
|
|
#include "esp_transport_internal.h"
|
|
|
|
|
|
+#define GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t) \
|
|
|
+ transport_esp_tls_t *ssl = ssl_get_context_data(t); \
|
|
|
+ if (!ssl) { return; }
|
|
|
+
|
|
|
static const char *TAG = "TRANSPORT_BASE";
|
|
|
|
|
|
typedef enum {
|
|
|
@@ -40,11 +44,30 @@ typedef struct transport_esp_tls {
|
|
|
transport_ssl_conn_state_t conn_state;
|
|
|
} transport_esp_tls_t;
|
|
|
|
|
|
+static inline struct transport_esp_tls * ssl_get_context_data(esp_transport_handle_t t)
|
|
|
+{
|
|
|
+ if (!t) {
|
|
|
+ return NULL;
|
|
|
+ }
|
|
|
+ if (t->data) { // Prefer internal ssl context (independent from the list)
|
|
|
+ return (transport_esp_tls_t*)t->data;
|
|
|
+ }
|
|
|
+ if (t->base && t->base->transport_esp_tls) { // Next one is the lists inherent context
|
|
|
+ t->data = t->base->transport_esp_tls; // Optimize: if we have base context, use it as internal
|
|
|
+ return t->base->transport_esp_tls;
|
|
|
+ }
|
|
|
+ // If we don't have a valid context, let's to create one
|
|
|
+ transport_esp_tls_t *ssl = esp_transport_esp_tls_create();
|
|
|
+ ESP_TRANSPORT_MEM_CHECK(TAG, ssl, return NULL)
|
|
|
+ t->data = ssl;
|
|
|
+ return ssl;
|
|
|
+}
|
|
|
+
|
|
|
static int ssl_close(esp_transport_handle_t t);
|
|
|
|
|
|
static int esp_tls_connect_async(esp_transport_handle_t t, const char *host, int port, int timeout_ms, bool is_plain_tcp)
|
|
|
{
|
|
|
- transport_esp_tls_t *ssl = t->base->transport_esp_tls;
|
|
|
+ transport_esp_tls_t *ssl = ssl_get_context_data(t);
|
|
|
if (ssl->conn_state == TRANS_SSL_INIT) {
|
|
|
ssl->cfg.timeout_ms = timeout_ms;
|
|
|
ssl->cfg.is_plain_tcp = is_plain_tcp;
|
|
|
@@ -74,7 +97,7 @@ static inline int tcp_connect_async(esp_transport_handle_t t, const char *host,
|
|
|
|
|
|
static int esp_tls_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms, bool is_plain_tcp)
|
|
|
{
|
|
|
- transport_esp_tls_t *ssl = t->base->transport_esp_tls;
|
|
|
+ transport_esp_tls_t *ssl = ssl_get_context_data(t);
|
|
|
|
|
|
ssl->cfg.timeout_ms = timeout_ms;
|
|
|
ssl->cfg.is_plain_tcp = is_plain_tcp;
|
|
|
@@ -103,7 +126,7 @@ static inline int tcp_connect(esp_transport_handle_t t, const char *host, int po
|
|
|
|
|
|
static int ssl_poll_read(esp_transport_handle_t t, int timeout_ms)
|
|
|
{
|
|
|
- transport_esp_tls_t *ssl = t->base->transport_esp_tls;
|
|
|
+ transport_esp_tls_t *ssl = ssl_get_context_data(t);
|
|
|
int ret = -1;
|
|
|
int remain = 0;
|
|
|
struct timeval timeout;
|
|
|
@@ -132,7 +155,7 @@ static int ssl_poll_read(esp_transport_handle_t t, int timeout_ms)
|
|
|
|
|
|
static int ssl_poll_write(esp_transport_handle_t t, int timeout_ms)
|
|
|
{
|
|
|
- transport_esp_tls_t *ssl = t->base->transport_esp_tls;
|
|
|
+ transport_esp_tls_t *ssl = ssl_get_context_data(t);
|
|
|
int ret = -1;
|
|
|
struct timeval timeout;
|
|
|
fd_set writeset;
|
|
|
@@ -156,7 +179,7 @@ static int ssl_poll_write(esp_transport_handle_t t, int timeout_ms)
|
|
|
static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms)
|
|
|
{
|
|
|
int poll, ret;
|
|
|
- transport_esp_tls_t *ssl = t->base->transport_esp_tls;
|
|
|
+ transport_esp_tls_t *ssl = ssl_get_context_data(t);
|
|
|
|
|
|
if ((poll = esp_transport_poll_write(t, timeout_ms)) <= 0) {
|
|
|
ESP_LOGW(TAG, "Poll timeout or error, errno=%s, fd=%d, timeout_ms=%d", strerror(errno), ssl->tls->sockfd, timeout_ms);
|
|
|
@@ -173,7 +196,7 @@ static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int
|
|
|
static int ssl_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
|
|
|
{
|
|
|
int poll, ret;
|
|
|
- transport_esp_tls_t *ssl = t->base->transport_esp_tls;
|
|
|
+ transport_esp_tls_t *ssl = ssl_get_context_data(t);
|
|
|
|
|
|
if ((poll = esp_transport_poll_read(t, timeout_ms)) <= 0) {
|
|
|
return poll;
|
|
|
@@ -196,8 +219,8 @@ static int ssl_read(esp_transport_handle_t t, char *buffer, int len, int timeout
|
|
|
static int ssl_close(esp_transport_handle_t t)
|
|
|
{
|
|
|
int ret = -1;
|
|
|
- if (t && t->base && t->base->transport_esp_tls && t->base->transport_esp_tls->ssl_initialized) {
|
|
|
- transport_esp_tls_t *ssl = t->base->transport_esp_tls;
|
|
|
+ transport_esp_tls_t *ssl = ssl_get_context_data(t);
|
|
|
+ if (ssl && ssl->ssl_initialized) {
|
|
|
ret = esp_tls_conn_destroy(ssl->tls);
|
|
|
ssl->conn_state = TRANS_SSL_INIT;
|
|
|
ssl->ssl_initialized = false;
|
|
|
@@ -207,127 +230,124 @@ static int ssl_close(esp_transport_handle_t t)
|
|
|
|
|
|
static int ssl_destroy(esp_transport_handle_t t)
|
|
|
{
|
|
|
- esp_transport_close(t);
|
|
|
+ transport_esp_tls_t *ssl = ssl_get_context_data(t);
|
|
|
+ if (ssl) {
|
|
|
+ esp_transport_close(t);
|
|
|
+ if (t->base && t->base->transport_esp_tls &&
|
|
|
+ t->data == t->base->transport_esp_tls) {
|
|
|
+ // if internal ssl the same as the foundation transport,
|
|
|
+ // just zero out, it will be freed on list destroy
|
|
|
+ t->data = NULL;
|
|
|
+ }
|
|
|
+ esp_transport_esp_tls_destroy(t->data); // okay to pass NULL
|
|
|
+ }
|
|
|
return 0;
|
|
|
}
|
|
|
|
|
|
+
|
|
|
void esp_transport_ssl_enable_global_ca_store(esp_transport_handle_t t)
|
|
|
{
|
|
|
- if (t && t->base && t->base->transport_esp_tls) {
|
|
|
- t->base->transport_esp_tls->cfg.use_global_ca_store = true;
|
|
|
- }
|
|
|
+ GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
|
|
|
+ ssl->cfg.use_global_ca_store = true;
|
|
|
}
|
|
|
|
|
|
void esp_transport_ssl_set_psk_key_hint(esp_transport_handle_t t, const psk_hint_key_t* psk_hint_key)
|
|
|
{
|
|
|
- if (t && t->base && t->base->transport_esp_tls) {
|
|
|
- t->base->transport_esp_tls->cfg.psk_hint_key = psk_hint_key;
|
|
|
- }
|
|
|
+ GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
|
|
|
+ ssl->cfg.psk_hint_key = psk_hint_key;
|
|
|
}
|
|
|
|
|
|
void esp_transport_ssl_set_cert_data(esp_transport_handle_t t, const char *data, int len)
|
|
|
{
|
|
|
- if (t && t->base && t->base->transport_esp_tls) {
|
|
|
- t->base->transport_esp_tls->cfg.cacert_pem_buf = (void *)data;
|
|
|
- t->base->transport_esp_tls->cfg.cacert_pem_bytes = len + 1;
|
|
|
- }
|
|
|
+ GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
|
|
|
+ ssl->cfg.cacert_pem_buf = (void *)data;
|
|
|
+ ssl->cfg.cacert_pem_bytes = len + 1;
|
|
|
}
|
|
|
|
|
|
void esp_transport_ssl_set_cert_data_der(esp_transport_handle_t t, const char *data, int len)
|
|
|
{
|
|
|
- if (t && t->base && t->base->transport_esp_tls) {
|
|
|
- t->base->transport_esp_tls->cfg.cacert_buf = (void *)data;
|
|
|
- t->base->transport_esp_tls->cfg.cacert_bytes = len;
|
|
|
- }
|
|
|
+ GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
|
|
|
+ ssl->cfg.cacert_buf = (void *)data;
|
|
|
+ ssl->cfg.cacert_bytes = len;
|
|
|
}
|
|
|
|
|
|
void esp_transport_ssl_set_client_cert_data(esp_transport_handle_t t, const char *data, int len)
|
|
|
{
|
|
|
- if (t && t->base && t->base->transport_esp_tls) {
|
|
|
- t->base->transport_esp_tls->cfg.clientcert_pem_buf = (void *)data;
|
|
|
- t->base->transport_esp_tls->cfg.clientcert_pem_bytes = len + 1;
|
|
|
- }
|
|
|
+ GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
|
|
|
+ ssl->cfg.clientcert_pem_buf = (void *)data;
|
|
|
+ ssl->cfg.clientcert_pem_bytes = len + 1;
|
|
|
}
|
|
|
|
|
|
void esp_transport_ssl_set_client_cert_data_der(esp_transport_handle_t t, const char *data, int len)
|
|
|
{
|
|
|
- if (t && t->base && t->base->transport_esp_tls) {
|
|
|
- t->base->transport_esp_tls->cfg.clientcert_buf = (void *)data;
|
|
|
- t->base->transport_esp_tls->cfg.clientcert_bytes = len;
|
|
|
- }
|
|
|
+ GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
|
|
|
+ ssl->cfg.clientcert_buf = (void *)data;
|
|
|
+ ssl->cfg.clientcert_bytes = len;
|
|
|
}
|
|
|
|
|
|
void esp_transport_ssl_set_client_key_data(esp_transport_handle_t t, const char *data, int len)
|
|
|
{
|
|
|
- if (t && t->base && t->base->transport_esp_tls) {
|
|
|
- t->base->transport_esp_tls->cfg.clientkey_pem_buf = (void *)data;
|
|
|
- t->base->transport_esp_tls->cfg.clientkey_pem_bytes = len + 1;
|
|
|
- }
|
|
|
+ GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
|
|
|
+ ssl->cfg.clientkey_pem_buf = (void *)data;
|
|
|
+ ssl->cfg.clientkey_pem_bytes = len + 1;
|
|
|
}
|
|
|
|
|
|
void esp_transport_ssl_set_client_key_password(esp_transport_handle_t t, const char *password, int password_len)
|
|
|
{
|
|
|
- if (t && t->base && t->base->transport_esp_tls) {
|
|
|
- t->base->transport_esp_tls->cfg.clientkey_password = (void *)password;
|
|
|
- t->base->transport_esp_tls->cfg.clientkey_password_len = password_len;
|
|
|
- }
|
|
|
+ GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
|
|
|
+ ssl->cfg.clientkey_password = (void *)password;
|
|
|
+ ssl->cfg.clientkey_password_len = password_len;
|
|
|
}
|
|
|
|
|
|
void esp_transport_ssl_set_client_key_data_der(esp_transport_handle_t t, const char *data, int len)
|
|
|
{
|
|
|
- if (t && t->base && t->base->transport_esp_tls) {
|
|
|
- t->base->transport_esp_tls->cfg.clientkey_buf = (void *)data;
|
|
|
- t->base->transport_esp_tls->cfg.clientkey_bytes = len;
|
|
|
- }
|
|
|
+ GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
|
|
|
+ ssl->cfg.clientkey_buf = (void *)data;
|
|
|
+ ssl->cfg.clientkey_bytes = len;
|
|
|
}
|
|
|
|
|
|
void esp_transport_ssl_set_alpn_protocol(esp_transport_handle_t t, const char **alpn_protos)
|
|
|
{
|
|
|
- if (t && t->base && t->base->transport_esp_tls) {
|
|
|
- t->base->transport_esp_tls->cfg.alpn_protos = alpn_protos;
|
|
|
- }
|
|
|
+ GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
|
|
|
+ ssl->cfg.alpn_protos = alpn_protos;
|
|
|
}
|
|
|
|
|
|
void esp_transport_ssl_skip_common_name_check(esp_transport_handle_t t)
|
|
|
{
|
|
|
- if (t && t->base && t->base->transport_esp_tls) {
|
|
|
- t->base->transport_esp_tls->cfg.skip_common_name = true;
|
|
|
- }
|
|
|
+ GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
|
|
|
+ ssl->cfg.skip_common_name = true;
|
|
|
}
|
|
|
|
|
|
void esp_transport_ssl_use_secure_element(esp_transport_handle_t t)
|
|
|
{
|
|
|
- if (t && t->base && t->base->transport_esp_tls) {
|
|
|
- t->base->transport_esp_tls->cfg.use_secure_element = true;
|
|
|
- }
|
|
|
+ GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
|
|
|
+ ssl->cfg.use_secure_element = true;
|
|
|
}
|
|
|
|
|
|
static int ssl_get_socket(esp_transport_handle_t t)
|
|
|
{
|
|
|
- if (t && t->base && t->base->transport_esp_tls && t->base->transport_esp_tls->tls) {
|
|
|
- return t->base->transport_esp_tls->tls->sockfd;
|
|
|
+ transport_esp_tls_t *ssl = ssl_get_context_data(t);
|
|
|
+ if (ssl && ssl->tls) {
|
|
|
+ return ssl->tls->sockfd;
|
|
|
}
|
|
|
return -1;
|
|
|
}
|
|
|
|
|
|
void esp_transport_ssl_set_ds_data(esp_transport_handle_t t, void *ds_data)
|
|
|
{
|
|
|
- if (t && t->base && t->base->transport_esp_tls) {
|
|
|
- t->base->transport_esp_tls->cfg.ds_data = ds_data;
|
|
|
- }
|
|
|
+ GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
|
|
|
+ ssl->cfg.ds_data = ds_data;
|
|
|
}
|
|
|
|
|
|
void esp_transport_ssl_set_keep_alive(esp_transport_handle_t t, esp_transport_keep_alive_t *keep_alive_cfg)
|
|
|
{
|
|
|
- if (t && t->base && t->base->transport_esp_tls) {
|
|
|
- t->base->transport_esp_tls->cfg.keep_alive_cfg = (tls_keep_alive_cfg_t *) keep_alive_cfg;
|
|
|
- }
|
|
|
+ GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t);
|
|
|
+ ssl->cfg.keep_alive_cfg = (tls_keep_alive_cfg_t *) keep_alive_cfg;
|
|
|
}
|
|
|
|
|
|
esp_transport_handle_t esp_transport_ssl_init(void)
|
|
|
{
|
|
|
esp_transport_handle_t t = esp_transport_init();
|
|
|
- esp_transport_set_context_data(t, NULL);
|
|
|
esp_transport_set_func(t, ssl_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy);
|
|
|
esp_transport_set_async_connect_func(t, ssl_connect_async);
|
|
|
t->_get_socket = ssl_get_socket;
|
|
|
@@ -348,7 +368,6 @@ void esp_transport_esp_tls_destroy(struct transport_esp_tls* transport_esp_tls)
|
|
|
esp_transport_handle_t esp_transport_tcp_init(void)
|
|
|
{
|
|
|
esp_transport_handle_t t = esp_transport_init();
|
|
|
- esp_transport_set_context_data(t, NULL);
|
|
|
esp_transport_set_func(t, tcp_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy);
|
|
|
esp_transport_set_async_connect_func(t, tcp_connect_async);
|
|
|
t->_get_socket = ssl_get_socket;
|