example_test.py 7.9 KB


  1. from __future__ import print_function
  2. from __future__ import unicode_literals
  3. import re
  4. import os
  5. import socket
  6. import select
  7. import hashlib
  8. import base64
  9. import queue
  10. import random
  11. import string
  12. from threading import Thread, Event
  13. import ttfw_idf
  14. def get_my_ip():
  15. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  16. try:
  17. # doesn't even have to be reachable
  18. s.connect(('10.255.255.255', 1))
  19. IP = s.getsockname()[0]
  20. except Exception:
  21. IP = '127.0.0.1'
  22. finally:
  23. s.close()
  24. return IP
  25. # Simple Websocket server for testing purposes
  26. class Websocket:
  27. HEADER_LEN = 6
  28. def __init__(self, port):
  29. self.port = port
  30. self.socket = socket.socket()
  31. self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  32. self.socket.settimeout(10.0)
  33. self.send_q = queue.Queue()
  34. self.shutdown = Event()
  35. def __enter__(self):
  36. try:
  37. self.socket.bind(('', self.port))
  38. except socket.error as e:
  39. print("Bind failed:{}".format(e))
  40. raise
  41. self.socket.listen(1)
  42. self.server_thread = Thread(target=self.run_server)
  43. self.server_thread.start()
  44. return self
  45. def __exit__(self, exc_type, exc_value, traceback):
  46. self.shutdown.set()
  47. self.server_thread.join()
  48. self.socket.close()
  49. self.conn.close()
  50. def run_server(self):
  51. self.conn, address = self.socket.accept() # accept new connection
  52. self.socket.settimeout(10.0)
  53. print("Connection from: {}".format(address))
  54. self.establish_connection()
  55. print("WS established")
  56. # Handle connection until client closes it, will echo any data received and send data from send_q queue
  57. self.handle_conn()
  58. def establish_connection(self):
  59. while not self.shutdown.is_set():
  60. try:
  61. # receive data stream. it won't accept data packet greater than 1024 bytes
  62. data = self.conn.recv(1024).decode()
  63. if not data:
  64. # exit if data is not received
  65. raise
  66. if "Upgrade: websocket" in data and "Connection: Upgrade" in data:
  67. self.handshake(data)
  68. return
  69. except socket.error as err:
  70. print("Unable to establish a websocket connection: {}, {}".format(err))
  71. raise
  72. def handshake(self, data):
  73. # Magic string from RFC
  74. MAGIC_STRING = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
  75. headers = data.split("\r\n")
  76. for header in headers:
  77. if "Sec-WebSocket-Key" in header:
  78. client_key = header.split()[1]
  79. if client_key:
  80. resp_key = client_key + MAGIC_STRING
  81. resp_key = base64.standard_b64encode(hashlib.sha1(resp_key.encode()).digest())
  82. resp = "HTTP/1.1 101 Switching Protocols\r\n" + \
  83. "Upgrade: websocket\r\n" + \
  84. "Connection: Upgrade\r\n" + \
  85. "Sec-WebSocket-Accept: {}\r\n\r\n".format(resp_key.decode())
  86. self.conn.send(resp.encode())
  87. def handle_conn(self):
  88. while not self.shutdown.is_set():
  89. r,w,e = select.select([self.conn], [], [], 1)
  90. try:
  91. if self.conn in r:
  92. self.echo_data()
  93. if not self.send_q.empty():
  94. self._send_data_(self.send_q.get())
  95. except socket.error as err:
  96. print("Stopped echoing data: {}".format(err))
  97. raise
  98. def echo_data(self):
  99. header = bytearray(self.conn.recv(self.HEADER_LEN, socket.MSG_WAITALL))
  100. if not header:
  101. # exit if socket closed by peer
  102. return
  103. # Remove mask bit
  104. payload_len = ~(1 << 7) & header[1]
  105. payload = bytearray(self.conn.recv(payload_len, socket.MSG_WAITALL))
  106. if not payload:
  107. # exit if socket closed by peer
  108. return
  109. frame = header + payload
  110. decoded_payload = self.decode_frame(frame)
  111. print("Sending echo...")
  112. self._send_data_(decoded_payload)
  113. def _send_data_(self, data):
  114. frame = self.encode_frame(data)
  115. self.conn.send(frame)
  116. def send_data(self, data):
  117. self.send_q.put(data.encode())
  118. def decode_frame(self, frame):
  119. # Mask out MASK bit from payload length, this len is only valid for short messages (<126)
  120. payload_len = ~(1 << 7) & frame[1]
  121. mask = frame[2:self.HEADER_LEN]
  122. encrypted_payload = frame[self.HEADER_LEN:self.HEADER_LEN + payload_len]
  123. payload = bytearray()
  124. for i in range(payload_len):
  125. payload.append(encrypted_payload[i] ^ mask[i % 4])
  126. return payload
  127. def encode_frame(self, payload):
  128. # Set FIN = 1 and OP_CODE = 1 (text)
  129. header = (1 << 7) | (1 << 0)
  130. frame = bytearray([header])
  131. payload_len = len(payload)
  132. # If payload len is longer than 125 then the next 16 bits are used to encode length
  133. if payload_len > 125:
  134. frame.append(126)
  135. frame.append(payload_len >> 8)
  136. frame.append(0xFF & payload_len)
  137. else:
  138. frame.append(payload_len)
  139. frame += payload
  140. return frame
  141. def test_echo(dut):
  142. dut.expect("WEBSOCKET_EVENT_CONNECTED")
  143. for i in range(0, 10):
  144. dut.expect(re.compile(r"Received=hello (\d)"), timeout=30)
  145. print("All echos received")
  146. def test_recv_long_msg(dut, websocket, msg_len, repeats):
  147. send_msg = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(msg_len))
  148. for _ in range(repeats):
  149. websocket.send_data(send_msg)
  150. recv_msg = ''
  151. while len(recv_msg) < msg_len:
  152. # Filter out color encoding
  153. match = dut.expect(re.compile(r"Received=([a-zA-Z0-9]*).*\n"), timeout=30)[0]
  154. recv_msg += match
  155. if recv_msg == send_msg:
  156. print("Sent message and received message are equal")
  157. else:
  158. raise ValueError("DUT received string do not match sent string, \nexpected: {}\nwith length {}\
  159. \nreceived: {}\nwith length {}".format(send_msg, len(send_msg), recv_msg, len(recv_msg)))
  160. @ttfw_idf.idf_example_test(env_tag="Example_WIFI")
  161. def test_examples_protocol_websocket(env, extra_data):
  162. """
  163. steps:
  164. 1. join AP
  165. 2. connect to uri specified in the config
  166. 3. send and receive data
  167. """
  168. dut1 = env.get_dut("websocket", "examples/protocols/websocket", dut_class=ttfw_idf.ESP32DUT)
  169. # check and log bin size
  170. binary_file = os.path.join(dut1.app.binary_path, "websocket-example.bin")
  171. bin_size = os.path.getsize(binary_file)
  172. ttfw_idf.log_performance("websocket_bin_size", "{}KB".format(bin_size // 1024))
  173. ttfw_idf.check_performance("websocket_bin_size", bin_size // 1024)
  174. try:
  175. if "CONFIG_WEBSOCKET_URI_FROM_STDIN" in dut1.app.get_sdkconfig():
  176. uri_from_stdin = True
  177. else:
  178. uri = dut1.app.get_sdkconfig()["CONFIG_WEBSOCKET_URI"].strip('"')
  179. uri_from_stdin = False
  180. except Exception:
  181. print('ENV_TEST_FAILURE: Cannot find uri settings in sdkconfig')
  182. raise
  183. # start test
  184. dut1.start_app()
  185. if uri_from_stdin:
  186. server_port = 4455
  187. with Websocket(server_port) as ws:
  188. uri = "ws://{}:{}".format(get_my_ip(), server_port)
  189. print("DUT connecting to {}".format(uri))
  190. dut1.expect("Please enter uri of websocket endpoint", timeout=30)
  191. dut1.write(uri)
  192. test_echo(dut1)
  193. # Message length should exceed DUT's buffer size to test fragmentation, default is 1024 byte
  194. test_recv_long_msg(dut1, ws, 2000, 3)
  195. else:
  196. print("DUT connecting to {}".format(uri))
  197. test_echo(dut1)
if __name__ == '__main__':
    # Run the example test directly. It is called without arguments; presumably
    # the ttfw_idf.idf_example_test decorator supplies env/extra_data (confirm).
    test_examples_protocol_websocket()