example_test.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. from __future__ import print_function
  2. from __future__ import unicode_literals
  3. import re
  4. import os
  5. import sys
  6. import socket
  7. import select
  8. import hashlib
  9. import base64
  10. import queue
  11. import random
  12. import string
  13. from threading import Thread, Event
  14. try:
  15. import IDF
  16. except Exception:
  17. # this is a test case write with tiny-test-fw.
  18. # to run test cases outside tiny-test-fw,
  19. # we need to set environment variable `TEST_FW_PATH`,
  20. # then get and insert `TEST_FW_PATH` to sys path before import FW module
  21. test_fw_path = os.getenv("TEST_FW_PATH")
  22. if test_fw_path and test_fw_path not in sys.path:
  23. sys.path.insert(0, test_fw_path)
  24. import IDF
  25. def get_my_ip():
  26. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  27. try:
  28. # doesn't even have to be reachable
  29. s.connect(('10.255.255.255', 1))
  30. IP = s.getsockname()[0]
  31. except Exception:
  32. IP = '127.0.0.1'
  33. finally:
  34. s.close()
  35. return IP
  36. # Simple Websocket server for testing purposes
  37. class Websocket:
  38. HEADER_LEN = 6
  39. def __init__(self, port):
  40. self.port = port
  41. self.socket = socket.socket()
  42. self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  43. self.socket.settimeout(10.0)
  44. self.send_q = queue.Queue()
  45. self.shutdown = Event()
  46. def __enter__(self):
  47. try:
  48. self.socket.bind(('', self.port))
  49. except socket.error as e:
  50. print("Bind failed:{}".format(e))
  51. raise
  52. self.socket.listen(1)
  53. self.server_thread = Thread(target=self.run_server)
  54. self.server_thread.start()
  55. return self
  56. def __exit__(self, exc_type, exc_value, traceback):
  57. self.shutdown.set()
  58. self.server_thread.join()
  59. self.socket.close()
  60. self.conn.close()
  61. def run_server(self):
  62. self.conn, address = self.socket.accept() # accept new connection
  63. self.socket.settimeout(10.0)
  64. print("Connection from: {}".format(address))
  65. self.establish_connection()
  66. print("WS established")
  67. # Handle connection until client closes it, will echo any data received and send data from send_q queue
  68. self.handle_conn()
  69. def establish_connection(self):
  70. while not self.shutdown.is_set():
  71. try:
  72. # receive data stream. it won't accept data packet greater than 1024 bytes
  73. data = self.conn.recv(1024).decode()
  74. if not data:
  75. # exit if data is not received
  76. raise
  77. if "Upgrade: websocket" in data and "Connection: Upgrade" in data:
  78. self.handshake(data)
  79. return
  80. except socket.error as err:
  81. print("Unable to establish a websocket connection: {}, {}".format(err))
  82. raise
  83. def handshake(self, data):
  84. # Magic string from RFC
  85. MAGIC_STRING = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
  86. headers = data.split("\r\n")
  87. for header in headers:
  88. if "Sec-WebSocket-Key" in header:
  89. client_key = header.split()[1]
  90. if client_key:
  91. resp_key = client_key + MAGIC_STRING
  92. resp_key = base64.standard_b64encode(hashlib.sha1(resp_key.encode()).digest())
  93. resp = "HTTP/1.1 101 Switching Protocols\r\n" + \
  94. "Upgrade: websocket\r\n" + \
  95. "Connection: Upgrade\r\n" + \
  96. "Sec-WebSocket-Accept: {}\r\n\r\n".format(resp_key.decode())
  97. self.conn.send(resp.encode())
  98. def handle_conn(self):
  99. while not self.shutdown.is_set():
  100. r,w,e = select.select([self.conn], [], [], 1)
  101. try:
  102. if self.conn in r:
  103. self.echo_data()
  104. if not self.send_q.empty():
  105. self._send_data_(self.send_q.get())
  106. except socket.error as err:
  107. print("Stopped echoing data: {}".format(err))
  108. raise
  109. def echo_data(self):
  110. header = bytearray(self.conn.recv(self.HEADER_LEN, socket.MSG_WAITALL))
  111. if not header:
  112. # exit if socket closed by peer
  113. return
  114. # Remove mask bit
  115. payload_len = ~(1 << 7) & header[1]
  116. payload = bytearray(self.conn.recv(payload_len, socket.MSG_WAITALL))
  117. if not payload:
  118. # exit if socket closed by peer
  119. return
  120. frame = header + payload
  121. decoded_payload = self.decode_frame(frame)
  122. print("Sending echo...")
  123. self._send_data_(decoded_payload)
  124. def _send_data_(self, data):
  125. frame = self.encode_frame(data)
  126. self.conn.send(frame)
  127. def send_data(self, data):
  128. self.send_q.put(data.encode())
  129. def decode_frame(self, frame):
  130. # Mask out MASK bit from payload length, this len is only valid for short messages (<126)
  131. payload_len = ~(1 << 7) & frame[1]
  132. mask = frame[2:self.HEADER_LEN]
  133. encrypted_payload = frame[self.HEADER_LEN:self.HEADER_LEN + payload_len]
  134. payload = bytearray()
  135. for i in range(payload_len):
  136. payload.append(encrypted_payload[i] ^ mask[i % 4])
  137. return payload
  138. def encode_frame(self, payload):
  139. # Set FIN = 1 and OP_CODE = 1 (text)
  140. header = (1 << 7) | (1 << 0)
  141. frame = bytearray([header])
  142. payload_len = len(payload)
  143. # If payload len is longer than 125 then the next 16 bits are used to encode length
  144. if payload_len > 125:
  145. frame.append(126)
  146. frame.append(payload_len >> 8)
  147. frame.append(0xFF & payload_len)
  148. else:
  149. frame.append(payload_len)
  150. frame += payload
  151. return frame
  152. def test_echo(dut):
  153. dut.expect("WEBSOCKET_EVENT_CONNECTED")
  154. for i in range(0, 10):
  155. dut.expect(re.compile(r"Received=hello (\d)"), timeout=30)
  156. print("All echos received")
  157. def test_recv_long_msg(dut, websocket, msg_len, repeats):
  158. send_msg = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(msg_len))
  159. for _ in range(repeats):
  160. websocket.send_data(send_msg)
  161. recv_msg = ''
  162. while len(recv_msg) < msg_len:
  163. # Filter out color encoding
  164. match = dut.expect(re.compile(r"Received=([a-zA-Z0-9]*).*\n"), timeout=30)[0]
  165. recv_msg += match
  166. if recv_msg == send_msg:
  167. print("Sent message and received message are equal")
  168. else:
  169. raise ValueError("DUT received string do not match sent string, \nexpected: {}\nwith length {}\
  170. \nreceived: {}\nwith length {}".format(send_msg, len(send_msg), recv_msg, len(recv_msg)))
  171. @IDF.idf_example_test(env_tag="Example_WIFI")
  172. def test_examples_protocol_websocket(env, extra_data):
  173. """
  174. steps:
  175. 1. join AP
  176. 2. connect to uri specified in the config
  177. 3. send and receive data
  178. """
  179. dut1 = env.get_dut("websocket", "examples/protocols/websocket")
  180. # check and log bin size
  181. binary_file = os.path.join(dut1.app.binary_path, "websocket-example.bin")
  182. bin_size = os.path.getsize(binary_file)
  183. IDF.log_performance("websocket_bin_size", "{}KB".format(bin_size // 1024))
  184. IDF.check_performance("websocket_bin_size", bin_size // 1024)
  185. try:
  186. if "CONFIG_WEBSOCKET_URI_FROM_STDIN" in dut1.app.get_sdkconfig():
  187. uri_from_stdin = True
  188. else:
  189. uri = dut1.app.get_sdkconfig()["CONFIG_WEBSOCKET_URI"].strip('"')
  190. uri_from_stdin = False
  191. except Exception:
  192. print('ENV_TEST_FAILURE: Cannot find uri settings in sdkconfig')
  193. raise
  194. # start test
  195. dut1.start_app()
  196. if uri_from_stdin:
  197. server_port = 4455
  198. with Websocket(server_port) as ws:
  199. uri = "ws://{}:{}".format(get_my_ip(), server_port)
  200. print("DUT connecting to {}".format(uri))
  201. dut1.expect("Please enter uri of websocket endpoint", timeout=30)
  202. dut1.write(uri)
  203. test_echo(dut1)
  204. # Message length should exceed DUT's buffer size to test fragmentation, default is 1024 byte
  205. test_recv_long_msg(dut1, ws, 2000, 3)
  206. else:
  207. print("DUT connecting to {}".format(uri))
  208. test_echo(dut1)
  209. if __name__ == '__main__':
  210. test_examples_protocol_websocket()