Parcourir la source

sockets: task #14247, add CMSG and IP_PKTINFO

This commit adds CMSG infrastructure (currently used with recvmsg) and
the IP_PKTINFO socket option.

In order to use IP_PKTINFO, set LWIP_NETBUF_RECVINFO to 1

Unit test is added to verify this feature
Joel Cunningham il y a 8 ans
Parent
commit
2f117add7a

+ 6 - 4
src/api/api_msg.c

@@ -245,12 +245,10 @@ recv_udp(void *arg, struct udp_pcb *pcb, struct pbuf *p,
     ip_addr_set(&buf->addr, addr);
     buf->port = port;
 #if LWIP_NETBUF_RECVINFO
-    {
+    if (conn->flags & NETCONN_FLAG_PKTINFO) {
       /* get the UDP header - always in the first pbuf, ensured by udp_input */
       const struct udp_hdr* udphdr = (const struct udp_hdr*)ip_next_header_ptr();
-#if LWIP_CHECKSUM_ON_COPY
       buf->flags = NETBUF_FLAG_DESTADDR;
-#endif /* LWIP_CHECKSUM_ON_COPY */
       ip_addr_set(&buf->toaddr, ip_current_dest_addr());
       buf->toport_chksum = udphdr->dest;
     }
@@ -694,6 +692,7 @@ netconn_alloc(enum netconn_type t, netconn_callback callback)
 {
   struct netconn *conn;
   int size;
+  u8_t init_flags = 0;
 
   conn = (struct netconn *)memp_malloc(MEMP_NETCONN);
   if (conn == NULL) {
@@ -714,6 +713,9 @@ netconn_alloc(enum netconn_type t, netconn_callback callback)
 #if LWIP_UDP
   case NETCONN_UDP:
     size = DEFAULT_UDP_RECVMBOX_SIZE;
+#if LWIP_NETBUF_RECVINFO
+    init_flags |= NETCONN_FLAG_PKTINFO;
+#endif /* LWIP_NETBUF_RECVINFO */
     break;
 #endif /* LWIP_UDP */
 #if LWIP_TCP
@@ -761,7 +763,7 @@ netconn_alloc(enum netconn_type t, netconn_callback callback)
 #if LWIP_SO_LINGER
   conn->linger = -1;
 #endif /* LWIP_SO_LINGER */
-  conn->flags = 0;
+  conn->flags = init_flags;
   return conn;
 free_and_return:
   memp_free(MEMP_NETCONN, conn);

+ 71 - 14
src/api/sockets.c

@@ -1030,8 +1030,7 @@ lwip_recv_tcp_from(struct lwip_sock *sock, struct sockaddr *from, socklen_t *fro
  * Keeps sock->lastdata for peeking.
  */
 static err_t
-lwip_recvfrom_udp_raw(struct lwip_sock *sock, int flags, const struct iovec *iov, int iovcnt,
-                      struct sockaddr *from, socklen_t *fromlen, u16_t *datagram_len, int dbg_s)
+lwip_recvfrom_udp_raw(struct lwip_sock *sock, int flags, struct msghdr *msg, u16_t *datagram_len, int dbg_s)
 {
   struct netbuf *buf;
   u8_t apiflags;
@@ -1040,7 +1039,7 @@ lwip_recvfrom_udp_raw(struct lwip_sock *sock, int flags, const struct iovec *iov
   int i;
 
   LWIP_UNUSED_ARG(dbg_s);
-  LWIP_ERROR("lwip_recvfrom_udp_raw: invalid arguments", (iov != NULL) || (iovcnt <= 0), return ERR_ARG;);
+  LWIP_ERROR("lwip_recvfrom_udp_raw: invalid arguments", (msg->msg_iov != NULL) || (msg->msg_iovlen <= 0), return ERR_ARG;);
 
   if (flags & MSG_DONTWAIT) {
     apiflags = NETCONN_DONTBLOCK;
@@ -1069,30 +1068,63 @@ lwip_recvfrom_udp_raw(struct lwip_sock *sock, int flags, const struct iovec *iov
 
   copied = 0;
   /* copy the pbuf payload into the iovs */
-  for (i = 0; (i < iovcnt) && (copied < buflen); i++) {
+  for (i = 0; (i < msg->msg_iovlen) && (copied < buflen); i++) {
     u16_t len_left = buflen - copied;
-    if (iov[i].iov_len > len_left) {
+    if (msg->msg_iov[i].iov_len > len_left) {
       copylen = len_left;
     } else {
-      copylen = (u16_t)iov[i].iov_len;
+      copylen = (u16_t)msg->msg_iov[i].iov_len;
     }
 
     /* copy the contents of the received buffer into
         the supplied memory buffer */
-    pbuf_copy_partial(buf->p, (u8_t*)iov[i].iov_base, copylen, copied);
+    pbuf_copy_partial(buf->p, (u8_t*)msg->msg_iov[i].iov_base, copylen, copied);
     copied += copylen;
   }
 
   /* Check to see from where the data was.*/
 #if !SOCKETS_DEBUG
-  if (from && fromlen)
+  if (msg->msg_name && msg->msg_namelen)
 #endif /* !SOCKETS_DEBUG */
   {
     LWIP_DEBUGF(SOCKETS_DEBUG, ("lwip_recvfrom_udp_raw(%d):  addr=", dbg_s));
     ip_addr_debug_print(SOCKETS_DEBUG, netbuf_fromaddr(buf));
     LWIP_DEBUGF(SOCKETS_DEBUG, (" port=%"U16_F" len=%d\n", netbuf_fromport(buf), copied));
-    if (from && fromlen) {
-      lwip_sock_make_addr(sock->conn, netbuf_fromaddr(buf), netbuf_fromport(buf), from, fromlen);
+    if (msg->msg_name && msg->msg_namelen) {
+      lwip_sock_make_addr(sock->conn, netbuf_fromaddr(buf), netbuf_fromport(buf),
+                          (struct sockaddr *)msg->msg_name, &msg->msg_namelen);
+    }
+  }
+
+  /* Initialize flag output */
+  msg->msg_flags = 0;
+
+  if (msg->msg_control){
+    u8_t wrote_msg = 0;
+#if LWIP_NETBUF_RECVINFO
+    /* Check if packet info was recorded */
+    if (buf->flags & NETBUF_FLAG_DESTADDR) {
+      if (IP_IS_V4(&buf->toaddr)) {
+#if LWIP_IPV4
+        if (msg->msg_controllen >= CMSG_SPACE(sizeof(struct in_pktinfo))) {
+          struct cmsghdr *chdr = CMSG_FIRSTHDR(msg); /* This will always return a header!! */
+          struct in_pktinfo *pkti = (struct in_pktinfo *)CMSG_DATA(chdr);
+          chdr->cmsg_level = IPPROTO_IP;
+          chdr->cmsg_type = IP_PKTINFO;
+          chdr->cmsg_len = CMSG_LEN(sizeof(struct in_pktinfo));
+          pkti->ipi_ifindex = buf->p->if_idx;
+          inet_addr_from_ip4addr(&pkti->ipi_addr, ip_2_ip4(netbuf_destaddr(buf)));
+          wrote_msg = 1;
+        } else {
+          msg->msg_flags |= MSG_CTRUNC;
+        }
+#endif /* LWIP_IPV4 */
+      }
+    }
+#endif /* LWIP_NETBUF_RECVINFO */
+
+    if (!wrote_msg) {
+      msg->msg_controllen = 0;
     }
   }
 
@@ -1130,10 +1162,18 @@ lwip_recvfrom(int s, void *mem, size_t len, int flags,
   {
     u16_t datagram_len = 0;
     struct iovec vec;
+    struct msghdr msg;
     err_t err;
     vec.iov_base = mem;
     vec.iov_len = len;
-    err = lwip_recvfrom_udp_raw(sock, flags, &vec, 1, from, fromlen, &datagram_len, s);
+    msg.msg_control = NULL;
+    msg.msg_controllen = 0;
+    msg.msg_flags = 0;
+    msg.msg_iov = &vec;
+    msg.msg_iovlen = 1;
+    msg.msg_name = from;
+    msg.msg_namelen = (fromlen ? *fromlen : 0);
+    err = lwip_recvfrom_udp_raw(sock, flags, &msg, &datagram_len, s);
     if (err != ERR_OK) {
       LWIP_DEBUGF(SOCKETS_DEBUG, ("lwip_recvfrom[UDP/RAW](%d): buf == NULL, error is \"%s\"!\n",
         s, lwip_strerr(err)));
@@ -1142,6 +1182,9 @@ lwip_recvfrom(int s, void *mem, size_t len, int flags,
       return -1;
     }
     ret = (ssize_t)LWIP_MIN(LWIP_MIN(len, datagram_len), SSIZE_MAX);
+    if (fromlen) {
+      *fromlen = msg.msg_namelen;
+    }
   }
 
   sock_set_errno(sock, 0);
@@ -1240,8 +1283,7 @@ lwip_recvmsg(int s, struct msghdr *message, int flags)
   {
     u16_t datagram_len = 0;
     err_t err;
-    err = lwip_recvfrom_udp_raw(sock, flags, message->msg_iov, message->msg_iovlen,
-      (struct sockaddr *)message->msg_name, &message->msg_namelen, &datagram_len, s);
+    err = lwip_recvfrom_udp_raw(sock, flags, message, &datagram_len, s);
     if (err != ERR_OK) {
       LWIP_DEBUGF(SOCKETS_DEBUG, ("lwip_recvmsg[UDP/RAW](%d): buf == NULL, error is \"%s\"!\n",
         s, lwip_strerr(err)));
@@ -1249,7 +1291,6 @@ lwip_recvmsg(int s, struct msghdr *message, int flags)
       done_socket(sock);
       return -1;
     }
-    message->msg_flags = 0;
     if (datagram_len > buflen) {
       message->msg_flags |= MSG_TRUNC;
     }
@@ -1590,6 +1631,12 @@ lwip_socket(int domain, int type, int protocol)
                  DEFAULT_SOCKET_EVENTCB);
     LWIP_DEBUGF(SOCKETS_DEBUG, ("lwip_socket(%s, SOCK_DGRAM, %d) = ",
                                  domain == PF_INET ? "PF_INET" : "UNKNOWN", protocol));
+#if LWIP_NETBUF_RECVINFO
+    if (conn) {
+      /* netconn layer enables pktinfo by default, sockets default to off */
+      conn->flags &= ~NETCONN_FLAG_PKTINFO;
+    }
+#endif /* LWIP_NETBUF_RECVINFO */
     break;
   case SOCK_STREAM:
     conn = netconn_new_with_callback(DOMAIN_TO_NETCONN_TYPE(domain, NETCONN_TCP), DEFAULT_SOCKET_EVENTCB);
@@ -2896,6 +2943,16 @@ lwip_setsockopt_impl(int s, int level, int optname, const void *optval, socklen_
       LWIP_DEBUGF(SOCKETS_DEBUG, ("lwip_setsockopt(%d, IPPROTO_IP, IP_TOS, ..)-> %d\n",
                   s, sock->conn->pcb.ip->tos));
       break;
+#if LWIP_NETBUF_RECVINFO
+    case IP_PKTINFO:
+      LWIP_SOCKOPT_CHECK_OPTLEN_CONN_PCB_TYPE(sock, optlen, int, NETCONN_UDP);
+      if (*(const int*)optval) {
+        sock->conn->flags |= NETCONN_FLAG_PKTINFO;
+      } else {
+        sock->conn->flags &= ~NETCONN_FLAG_PKTINFO;
+      }
+      break;
+#endif /* LWIP_NETBUF_RECVINFO */
 #if LWIP_IPV4 && LWIP_MULTICAST_TX_OPTIONS
     case IP_MULTICAST_TTL:
       LWIP_SOCKOPT_CHECK_OPTLEN_CONN_PCB_TYPE(sock, optlen, u8_t, NETCONN_UDP);

+ 4 - 0
src/include/lwip/api.h

@@ -81,6 +81,10 @@ extern "C" {
     dual-stack usage by default. */
 #define NETCONN_FLAG_IPV6_V6ONLY              0x20
 #endif /* LWIP_IPV6 */
+#if LWIP_NETBUF_RECVINFO
+/** Received packet info will be recorded for this netconn */
+#define NETCONN_FLAG_PKTINFO                  0x40
+#endif /* LWIP_NETBUF_RECVINFO */
 
 
 /* Helpers to process several netconn_types by the same code */

+ 0 - 2
src/include/lwip/netbuf.h

@@ -62,9 +62,7 @@ struct netbuf {
   ip_addr_t addr;
   u16_t port;
 #if LWIP_NETBUF_RECVINFO || LWIP_CHECKSUM_ON_COPY
-#if LWIP_CHECKSUM_ON_COPY
   u8_t flags;
-#endif /* LWIP_CHECKSUM_ON_COPY */
   u16_t toport_chksum;
 #if LWIP_NETBUF_RECVINFO
   ip_addr_t toaddr;

+ 44 - 0
src/include/lwip/sockets.h

@@ -135,6 +135,42 @@ struct msghdr {
 #define MSG_TRUNC   0x04
 #define MSG_CTRUNC  0x08
 
+/* RFC 3542, Section 20: Ancillary Data */
+struct cmsghdr {
+  socklen_t  cmsg_len;   /* number of bytes, including header */
+  int        cmsg_level; /* originating protocol */
+  int        cmsg_type;  /* protocol-specific type */
+};
+/* Data section follows header and possible padding, typically referred to as
+      unsigned char cmsg_data[]; */
+
+/* cmsg header/data alignment */
+#define ALIGN_H(size) LWIP_MEM_ALIGN_SIZE(size)
+#define ALIGN_D(size) LWIP_MEM_ALIGN_SIZE(size)
+
+#define CMSG_FIRSTHDR(mhdr) \
+          ((mhdr)->msg_controllen >= sizeof(struct cmsghdr) ? \
+           (struct cmsghdr *)(mhdr)->msg_control : \
+           (struct cmsghdr *)NULL)
+
+#define CMSG_NXTHDR(mhdr, cmsg) \
+        (((cmsg) == NULL) ? CMSG_FIRSTHDR(mhdr) : \
+         (((u8_t *)(cmsg) + ALIGN_H((cmsg)->cmsg_len) \
+                            + ALIGN_D(sizeof(struct cmsghdr)) > \
+           (u8_t *)((mhdr)->msg_control) + (mhdr)->msg_controllen) ? \
+          (struct cmsghdr *)NULL : \
+          (struct cmsghdr *)((u8_t *)(cmsg) + \
+                                      ALIGN_H((cmsg)->cmsg_len))))
+
+#define CMSG_DATA(cmsg) ((u8_t *)(cmsg) + \
+                         ALIGN_D(sizeof(struct cmsghdr)))
+
+#define CMSG_SPACE(length) (ALIGN_D(sizeof(struct cmsghdr)) + \
+                            ALIGN_H(length))
+
+#define CMSG_LEN(length) (ALIGN_D(sizeof(struct cmsghdr)) + \
+                           length)
+
 /* Socket protocol types (TCP/UDP/RAW) */
 #define SOCK_STREAM     1
 #define SOCK_DGRAM      2
@@ -221,6 +257,7 @@ struct linger {
  */
 #define IP_TOS             1
 #define IP_TTL             2
+#define IP_PKTINFO         8
 
 #if LWIP_TCP
 /*
@@ -272,6 +309,13 @@ typedef struct ip_mreq {
 } ip_mreq;
 #endif /* LWIP_IGMP */
 
+#if LWIP_IPV4
+struct in_pktinfo {
+  unsigned int   ipi_ifindex;  /* Interface index */
+  struct in_addr ipi_addr;     /* Destination (from header) address */
+};
+#endif /* LWIP_IPV4 */
+
 /*
  * The Type of Service provides an indication of the abstract
  * parameters of the quality of service desired.  These parameters are

+ 91 - 0
test/unit/api/test_sockets.c

@@ -468,12 +468,103 @@ static void test_sockets_msgapi_udp(int domain)
   fail_unless(ret == 0);
 }
 
+#if LWIP_IPV4
+static void test_sockets_msgapi_cmsg(int domain)
+{
+  int s, ret, enable;
+  struct sockaddr_storage addr_storage;
+  socklen_t addr_size;
+  struct iovec iov;
+  struct msghdr msg;
+  struct cmsghdr *cmsg;
+  struct in_pktinfo *pktinfo;
+  u8_t rcv_buf[4];
+  u8_t snd_buf[4] = {0xDE, 0xAD, 0xBE, 0xEF};
+  u8_t cmsg_buf[CMSG_SPACE(sizeof(struct in_pktinfo))];
+
+  test_sockets_init_loopback_addr(domain, &addr_storage, &addr_size);
+
+  s = test_sockets_alloc_socket_nonblocking(domain, SOCK_DGRAM);
+  fail_unless(s >= 0);
+
+  ret = lwip_bind(s, (struct sockaddr*)&addr_storage, addr_size);
+  fail_unless(ret == 0);
+
+  /* Update addr with epehermal port */
+  ret = lwip_getsockname(s, (struct sockaddr*)&addr_storage, &addr_size);
+  fail_unless(ret == 0);
+
+  enable = 1;
+  ret = lwip_setsockopt(s, IPPROTO_IP, IP_PKTINFO, &enable, sizeof(enable));
+  fail_unless(ret == 0);
+
+  /* Receive full message, including control message */
+  iov.iov_base = rcv_buf;
+  iov.iov_len = sizeof(rcv_buf);
+  msg.msg_control = cmsg_buf;
+  msg.msg_controllen = sizeof(cmsg_buf);
+  msg.msg_flags = 0;
+  msg.msg_iov = &iov;
+  msg.msg_iovlen = 1;
+  msg.msg_name = NULL;
+  msg.msg_namelen = 0;
+
+  memset(rcv_buf, 0, sizeof(rcv_buf));
+  ret = lwip_sendto(s, snd_buf, sizeof(snd_buf), 0, (struct sockaddr*)&addr_storage, addr_size);
+  fail_unless(ret == sizeof(snd_buf));
+  
+  tcpip_thread_poll_one();
+
+  ret = lwip_recvmsg(s, &msg, 0);
+  fail_unless(ret == sizeof(rcv_buf));
+  fail_unless(!memcmp(rcv_buf, snd_buf, sizeof(rcv_buf)));
+  
+  /* Verify message header */
+  cmsg = CMSG_FIRSTHDR(&msg);
+  fail_unless(cmsg);
+  fail_unless(cmsg->cmsg_len > 0);
+  fail_unless(cmsg->cmsg_level == IPPROTO_IP);
+  fail_unless(cmsg->cmsg_type = IP_PKTINFO);
+
+  /* Verify message data */
+  pktinfo = (struct in_pktinfo*)CMSG_DATA(cmsg);
+  /* We only have loopback interface enabled */
+  fail_unless(pktinfo->ipi_ifindex == 1);
+  fail_unless(pktinfo->ipi_addr.s_addr == PP_HTONL(INADDR_LOOPBACK));
+
+  /* Verify there are no additional messages */
+  cmsg = CMSG_NXTHDR(&msg, cmsg);
+  fail_unless(cmsg == NULL);
+
+  /* Send datagram again, testing truncation */
+  memset(rcv_buf, 0, sizeof(rcv_buf));
+  ret = lwip_sendto(s, snd_buf, sizeof(snd_buf), 0, (struct sockaddr*)&addr_storage, addr_size);
+  fail_unless(ret == sizeof(snd_buf));
+
+  tcpip_thread_poll_one();
+
+  msg.msg_controllen = 1;
+  msg.msg_flags = 0;
+  ret = lwip_recvmsg(s, &msg, 0);
+  fail_unless(ret == sizeof(rcv_buf));
+  fail_unless(!memcmp(rcv_buf, snd_buf, sizeof(rcv_buf)));
+  /* Ensure truncation was returned */
+  fail_unless(msg.msg_flags & MSG_CTRUNC);
+  /* Ensure no control messages were returned */
+  fail_unless(msg.msg_controllen == 0);
+
+  ret = lwip_close(s);
+  fail_unless(ret == 0);
+}
+#endif /* LWIP_IPV4 */
+
 START_TEST(test_sockets_msgapis)
 {
   LWIP_UNUSED_ARG(_i);
 #if LWIP_IPV4
   test_sockets_msgapi_udp(AF_INET);
   test_sockets_msgapi_tcp(AF_INET);
+  test_sockets_msgapi_cmsg(AF_INET);
 #endif
 #if LWIP_IPV6
   test_sockets_msgapi_udp(AF_INET6);

+ 1 - 0
test/unit/lwipopts.h

@@ -40,6 +40,7 @@
 #define LWIP_NETCONN                    !NO_SYS
 #define LWIP_SOCKET                     !NO_SYS
 #define LWIP_NETCONN_FULLDUPLEX         LWIP_SOCKET
+#define LWIP_NETBUF_RECVINFO            1
 #define LWIP_HAVE_LOOPIF                1
 #define TCPIP_THREAD_TEST