Parcourir la source

add webclient tls support

chenyong il y a 8 ans
Parent
commit
a646d54775
4 fichiers modifiés avec 272 ajouts et 40 suppressions
  1. 253 39
      webclient.c
  2. 10 0
      webclient.h
  3. 8 0
      webclient_file.c
  4. 1 1
      webclient_internal.h

+ 253 - 39
webclient.c

@@ -48,6 +48,33 @@ char *webclient_strdup(const char *s)
     return tmp;
 }
 
+static int webclient_send(struct webclient_session* session, const unsigned char *buffer, size_t len, int flag)
+{
+    if (!session) 
+        return -RT_ERROR;
+
+#ifdef PKG_USING_WEBCLIENT_TLS
+    if(session->tls_session)
+        return mbedtls_client_write(session->tls_session, buffer, len);
+#endif
+
+    return send(session->socket, buffer, len, flag);     
+}
+
+static int webclient_recv(struct webclient_session* session, unsigned char *buffer, size_t len, int flag)
+{
+    if (!session) 
+        return -RT_ERROR;
+
+#ifdef PKG_USING_WEBCLIENT_TLS
+    if(session->tls_session)
+        return mbedtls_client_read(session->tls_session, buffer, len);
+#endif 
+
+    return recv(session->socket, buffer, len, flag);
+}
+
+
 static char *webclient_header_skip_prefix(char *line, const char *prefix)
 {
     char *ptr;
@@ -78,7 +105,7 @@ static char *webclient_header_skip_prefix(char *line, const char *prefix)
  * before the data.  We need to read exactly to the end of the headers
  * and no more data.  This readline reads a single char at a time.
  */
-static int webclient_read_line(int socket, char *buffer, int size)
+static int webclient_read_line(struct webclient_session* session, char *buffer, int size)
 {
     int rc;
     char *ptr = buffer;
@@ -87,7 +114,11 @@ static int webclient_read_line(int socket, char *buffer, int size)
     /* Keep reading until we fill the buffer. */
     while (count < size)
     {
-        rc = recv(socket, ptr, 1, 0);
+        rc = webclient_recv(session, (unsigned char *)ptr, 1, 0);
+#ifdef PKG_USING_WEBCLIENT_TLS
+        if(session->tls_session && rc == MBEDTLS_ERR_SSL_WANT_READ)
+            continue;
+#endif 
         if (rc <= 0)
             return rc;
 
@@ -132,19 +163,27 @@ static int webclient_resolve_address(struct webclient_session *session, struct a
     int rc = WEBCLIENT_OK;
     char *ptr;
     char port_str[6] = "80"; /* default port of 80(http) */
+    char port_tls_str[6] = "443"; /* default port of 443(https) */
 
     const char *host_addr = 0;
     int url_len, host_addr_len = 0;
 
     url_len = strlen(url);
 
-    /* strip protocol(http) */
-    if (strncmp(url, "http://", 7) != 0)
+    /* strip protocol(http or https) */
+    if (strncmp(url, "http://", 7) == 0)
+    {
+        host_addr = url + 7;
+    }
+    else if(strncmp(url, "https://", 8) == 0)
+    {
+        host_addr = url + 8;
+    }
+    else
     {
         rc = -1;
-        goto _exit;
+        goto _exit;  
     }
-    host_addr = url + 7;
 
     /* ipv6 address */
     if (host_addr[0] == '[')
@@ -190,7 +229,37 @@ static int webclient_resolve_address(struct webclient_session *session, struct a
         }
         host_addr_len = ptr - host_addr;
         *request = (char *)ptr;
+        
+#ifdef PKG_USING_WEBCLIENT_TLS
+        char *port_tls_ptr;
+
+        if(session->tls_session)
+        {
+            port_tls_ptr = strstr(host_addr, ":");
+            if (port_tls_ptr)
+            {
+                int port_tls_len = ptr - port_tls_ptr - 1;
 
+                strncpy(port_tls_str, port_tls_ptr + 1, port_tls_len);
+                port_str[port_tls_len] = '\0';
+
+                host_addr_len = port_tls_ptr - host_addr;
+            }
+        }
+        else 
+        {
+            port_ptr = strstr(host_addr, ":");
+            if (port_ptr)
+            {
+                int port_len = ptr - port_ptr - 1;
+
+                strncpy(port_str, port_ptr + 1, port_len);
+                port_str[port_len] = '\0';
+
+                host_addr_len = port_ptr - host_addr;
+            }
+        }
+#else
         port_ptr = strstr(host_addr, ":");
         if (port_ptr)
         {
@@ -201,6 +270,7 @@ static int webclient_resolve_address(struct webclient_session *session, struct a
 
             host_addr_len = port_ptr - host_addr;
         }
+#endif
     }
 
     if ((host_addr_len < 1) || (host_addr_len > url_len))
@@ -223,7 +293,11 @@ static int webclient_resolve_address(struct webclient_session *session, struct a
         memcpy(host_addr_new, host_addr, host_addr_len);
         host_addr_new[host_addr_len] = '\0';
         session->host = host_addr_new;
-        //rt_kprintf("session->host: %s\n", session->host);
+        
+#ifdef PKG_USING_WEBCLIENT_TLS
+        if(session->tls_session)
+            session->tls_session->host = rt_strdup(host_addr_new);
+#endif
     }
 
     {
@@ -232,6 +306,30 @@ static int webclient_resolve_address(struct webclient_session *session, struct a
         int ret;
 
         memset(&hint, 0, sizeof(hint));
+        
+#ifdef PKG_USING_WEBCLIENT_TLS
+        if(session->tls_session)
+        {
+            session->tls_session->port = rt_strdup(port_tls_str);
+            ret = getaddrinfo(session->tls_session->host, port_tls_str, &hint, res);
+            if (ret != 0)
+            {
+                rt_kprintf("getaddrinfo err: %d '%s'\n", ret, session->host);
+                rc = -1;
+                goto _exit;
+            }
+        }
+        else 
+        {
+            ret = getaddrinfo(session->host, port_str, &hint, res);
+            if (ret != 0)
+            {
+                rt_kprintf("getaddrinfo err: %d '%s'\n", ret, session->host);
+                rc = -1;
+                goto _exit;
+            }
+        }
+#else
         ret = getaddrinfo(session->host, port_str, &hint, res);
         if (ret != 0)
         {
@@ -239,8 +337,8 @@ static int webclient_resolve_address(struct webclient_session *session, struct a
             rc = -1;
             goto _exit;
         }
+#endif
     }
-
 _exit:
     if (rc != WEBCLIENT_OK)
     {
@@ -387,7 +485,7 @@ int webclient_handle_response(struct webclient_session *session)
         int i;
 
         /* read a line from the header information. */
-        rc = webclient_read_line(session->socket, mimeBuffer, WEBCLIENT_RESPONSE_BUFSZ);
+        rc = webclient_read_line(session, mimeBuffer, WEBCLIENT_RESPONSE_BUFSZ);
         if (rc < 0)
             break;
 
@@ -469,7 +567,7 @@ int webclient_handle_response(struct webclient_session *session)
             && strcmp(session->transfer_encoding, "chunked") == 0)
     {
         /* chunk mode, we should get the first chunk size */
-        webclient_read_line(session->socket, mimeBuffer, WEBCLIENT_RESPONSE_BUFSZ);
+        webclient_read_line(session, mimeBuffer, WEBCLIENT_RESPONSE_BUFSZ);
         session->chunk_sz = strtol(mimeBuffer, RT_NULL, 16);
         session->chunk_offset = 0;
     }
@@ -521,28 +619,62 @@ int webclient_connect(struct webclient_session *session, const char *URI)
     else
         session->request = RT_NULL;
 
-    socket_handle = socket(res->ai_family, SOCK_STREAM, IPPROTO_TCP); //
-    if (socket_handle < 0)
+#ifdef PKG_USING_WEBCLIENT_TLS
+    if(session->tls_session)
     {
-        rc = -WEBCLIENT_NOSOCKET;
+       int tls_ret = 0;
+
+        if((tls_ret = mbedtls_client_context(session->tls_session)) < 0)
+        {
+            rt_kprintf("webclient mbedtls_client_context err return : -0x%x\n", -tls_ret);
+            return -RT_ERROR;
+        }
+        
+        if((tls_ret = mbedtls_client_connect(session->tls_session)) < 0)
+    	{
+    		rt_kprintf("webclient mbedtls_client_connect err return : -0x%x\n", -tls_ret);
+            rc = -WEBCLIENT_CONNECT_FAILED;
+            goto _exit;
+    	}
+        
+        socket_handle = session->tls_session->server_fd.fd;
+
+        /* set recv timeout option */
+        setsockopt(socket_handle, SOL_SOCKET, SO_RCVTIMEO, (void*) &timeout,
+                sizeof(timeout));
+        setsockopt(socket_handle, SOL_SOCKET, SO_SNDTIMEO, (void*) &timeout,
+                sizeof(timeout));
+
+        session->socket = socket_handle;
+        rc = WEBCLIENT_OK;
         goto _exit;
     }
+#endif
 
-    /* set recv timeout option */
-    setsockopt(socket_handle, SOL_SOCKET, SO_RCVTIMEO, (void *) &timeout,
-               sizeof(timeout));
-    setsockopt(socket_handle, SOL_SOCKET, SO_SNDTIMEO, (void *) &timeout,
-               sizeof(timeout));
+    {       
+        socket_handle = socket(res->ai_family, SOCK_STREAM, IPPROTO_TCP); //
+        if (socket_handle < 0)
+        {
+            rc = -WEBCLIENT_NOSOCKET;
+            goto _exit;
+        }
 
-    if (connect(socket_handle, res->ai_addr, res->ai_addrlen) != 0)
-    {
-        /* connect failed, close socket handle */
-        closesocket(socket_handle);
-        rc = -WEBCLIENT_CONNECT_FAILED;
-        goto _exit;
-    }
+        /* set recv timeout option */
+        setsockopt(socket_handle, SOL_SOCKET, SO_RCVTIMEO, (void *) &timeout,
+                   sizeof(timeout));
+        setsockopt(socket_handle, SOL_SOCKET, SO_SNDTIMEO, (void *) &timeout,
+                   sizeof(timeout));
+
+        if (connect(socket_handle, res->ai_addr, res->ai_addrlen) != 0)
+        {
+            /* connect failed, close socket handle */
+            closesocket(socket_handle);
+            rc = -WEBCLIENT_CONNECT_FAILED;
+            goto _exit;
+        }
 
-    session->socket = socket_handle;
+        session->socket = socket_handle;
+    }
 
 _exit:
     if (res)
@@ -553,6 +685,42 @@ _exit:
     return rc;
 }
 
+int webclient_open_tls(struct webclient_session * session, const char *URI)
+{
+#ifdef PKG_USING_WEBCLIENT_TLS
+    int tls_ret = 0;
+    const char *pers = "wenclient";
+
+    if(!session)
+        return -RT_ERROR;
+
+    session->tls_session = (MbedTLSSession *)web_malloc(sizeof(MbedTLSSession));
+    if (session->tls_session == RT_NULL)
+        return -RT_ERROR;
+    memset(session->tls_session, 0x0, sizeof(MbedTLSSession));    
+    
+    session->tls_session->buffer_len = WEBCLIENT_TLS_READ_BUFFER;
+    session->tls_session->buffer = web_malloc(session->tls_session->buffer_len);
+    if(session->tls_session->buffer == RT_NULL)
+    {
+        rt_kprintf("no memory for webclient tls_session buffer malloc\n");
+        return -RT_ERROR;
+    }
+    
+    if((tls_ret = mbedtls_client_init(session->tls_session, (void *)pers, strlen(pers))) < 0)
+    {
+        rt_kprintf("webclient mbedtls_client_init err return : -0x%x\n", -tls_ret);
+        return -RT_ERROR;
+    }
+    
+    return RT_EOK;  
+#else
+    rt_kprintf("don't support TLS protocol, check your menuconfig!\n");
+    return -RT_ERROR;
+    
+#endif
+}
+
 struct webclient_session *webclient_open(const char *URI)
 {
     struct webclient_session *session;
@@ -562,6 +730,16 @@ struct webclient_session *webclient_open(const char *URI)
     if (session == RT_NULL)
         return RT_NULL;
     memset(session, 0x0, sizeof(struct webclient_session));
+    session->socket = -1;
+    
+    if(strncmp(URI, "https://", 8) == 0)
+    {
+        if(webclient_open_tls(session, URI) < 0)
+        {   
+           webclient_close(session);
+           return RT_NULL;
+        }
+    }
 
     if (webclient_connect(session, URI) < 0)
     {
@@ -611,6 +789,15 @@ struct webclient_session *webclient_open_position(const char *URI, int position)
         return RT_NULL;
     memset(session, 0x0, sizeof(struct webclient_session));
 
+    if(strncmp(URI, "https://", 8) == 0)
+    {
+        if(webclient_open_tls(session, URI) < 0)
+        {   
+           webclient_close(session);
+           return RT_NULL;
+        }
+    }
+
     if (webclient_connect(session, URI) < 0)
     {
         /* connect to webclient server failed. */
@@ -671,6 +858,15 @@ struct webclient_session *webclient_open_header(const char *URI, int method,
         return RT_NULL;
     memset(session, 0, sizeof(struct webclient_session));
 
+    if(strncmp(URI, "https://", 8) == 0)
+    {
+        if(webclient_open_tls(session, URI) < 0)
+        {   
+           webclient_close(session);
+           return RT_NULL;
+        }
+    }
+
     if (webclient_connect(session, URI) < 0)
     {
         /* connect to webclient server failed. */
@@ -715,12 +911,12 @@ static int webclient_next_chunk(struct webclient_session *session)
     char line[64];
     int length;
 
-    length = webclient_read_line(session->socket, line, sizeof(line));
+    length = webclient_read_line(session, line, sizeof(line));
     if (length)
     {
         if (strcmp(line, "\r\n") == 0)
         {
-            length = webclient_read_line(session->socket, line, sizeof(line));
+            length = webclient_read_line(session, line, sizeof(line));
             if (length <= 0)
             {
                 closesocket(session->socket);
@@ -766,7 +962,7 @@ int webclient_read(struct webclient_session *session, unsigned char *buffer,
         if (length > (session->chunk_sz - session->chunk_offset))
             length = session->chunk_sz - session->chunk_offset;
 
-        bytesRead = recv(session->socket, buffer, length, 0);
+        bytesRead = webclient_recv(session, buffer, length, 0);
         if (bytesRead <= 0)
         {
             if (errno == EWOULDBLOCK || errno == EAGAIN)
@@ -811,9 +1007,13 @@ int webclient_read(struct webclient_session *session, unsigned char *buffer,
     left = length;
     do
     {
-        bytesRead = recv(session->socket, buffer + totalRead, left, 0);
+        bytesRead = webclient_recv(session, buffer + totalRead, left, 0);
         if (bytesRead <= 0)
         {
+#ifdef PKG_USING_WEBCLIENT_TLS
+            if(session->tls_session && bytesRead == MBEDTLS_ERR_SSL_WANT_READ)
+                continue;
+#endif  
             rt_kprintf("errno=%d\n", bytesRead);
 
             if (totalRead)
@@ -868,9 +1068,13 @@ int webclient_write(struct webclient_session *session,
      */
     do
     {
-        bytesWrite = send(session->socket, buffer + totalWrite, left, 0);
+        bytesWrite = webclient_send(session, buffer + totalWrite, left, 0);
         if (bytesWrite <= 0)
         {
+#ifdef PKG_USING_WEBCLIENT_TLS
+            if(session->tls_session && bytesWrite == MBEDTLS_ERR_SSL_WANT_WRITE)
+                continue;
+#endif
             if (errno == EWOULDBLOCK || errno == EAGAIN)
             {
                 /* send timeout */
@@ -905,15 +1109,25 @@ int webclient_write(struct webclient_session *session,
 int webclient_close(struct webclient_session *session)
 {
     RT_ASSERT(session != RT_NULL);
-
+    
+#ifdef PKG_USING_WEBCLIENT_TLS
+    if(session->tls_session)
+        mbedtls_client_close(session->tls_session);
+#endif
     if (session->socket >= 0)
-        closesocket(session->socket);
-    web_free(session->transfer_encoding);
-    web_free(session->content_type);
-    web_free(session->last_modified);
-    web_free(session->host);
-    web_free(session->request);
-    web_free(session);
+        closesocket(session->socket);    
+    if(session->transfer_encoding)
+        web_free(session->transfer_encoding);    
+    if(session->content_type)
+        web_free(session->content_type);
+    if(session->last_modified)
+        web_free(session->last_modified);
+    if(session->host)
+        web_free(session->host);
+    if(session->request)
+        web_free(session->request);
+    if(session)
+        web_free(session);
 
     return 0;
 }

+ 10 - 0
webclient.h

@@ -18,8 +18,13 @@
 
 #include <rtthread.h>
 
+#ifdef PKG_USING_WEBCLIENT_TLS
+#include <tls_client.h>
+#endif
+
 #define WEBCLIENT_HEADER_BUFSZ      4096
 #define WEBCLIENT_RESPONSE_BUFSZ    4096
+#define WEBCLIENT_TLS_READ_BUFFER   4096
 
 //typedef unsigned int size_t;
 
@@ -77,6 +82,11 @@ struct webclient_session
 
     /* remainder of content reading */
     size_t content_length_remainder;
+    
+#ifdef PKG_USING_WEBCLIENT_TLS
+        /* mbedtls session struct*/
+        MbedTLSSession *tls_session;
+#endif
 };
 
 struct webclient_session *webclient_open(const char *URI);

+ 8 - 0
webclient_file.c

@@ -131,6 +131,14 @@ int webclient_post_file(const char* URI, const char* filename,
     }
     memset(session, 0x0, sizeof(struct webclient_session));
 
+    if(strncmp(URI, "https://", 8) == 0)
+    {
+        if(webclient_open_tls(session, URI) < 0)
+        {   
+           goto __exit;
+        }
+    }
+
     rc = webclient_connect(session, URI);
     if (rc < 0)
         goto __exit;

+ 1 - 1
webclient_internal.h

@@ -3,7 +3,7 @@
 
 #include <rtthread.h>
 
-#ifdef RT_USING_ESP_PSRAM
+#ifdef RT_USING_PSRAM
 #include <drv_sdram.h>
 
 #define web_malloc  sdram_malloc