# example_test.py (7.9 KB)
  1. from __future__ import print_function
  2. from __future__ import unicode_literals
  3. import re
  4. import os
  5. import socket
  6. import select
  7. import hashlib
  8. import base64
  9. import queue
  10. import random
  11. import string
  12. from threading import Thread, Event
  13. import ttfw_idf
  14. def get_my_ip():
  15. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  16. try:
  17. # doesn't even have to be reachable
  18. s.connect(('10.255.255.255', 1))
  19. IP = s.getsockname()[0]
  20. except Exception:
  21. IP = '127.0.0.1'
  22. finally:
  23. s.close()
  24. return IP
  25. # Simple Websocket server for testing purposes
  26. class Websocket:
  27. HEADER_LEN = 6
  28. def __init__(self, port):
  29. self.port = port
  30. self.socket = socket.socket()
  31. self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  32. self.socket.settimeout(10.0)
  33. self.send_q = queue.Queue()
  34. self.shutdown = Event()
  35. self.conn = None
  36. def __enter__(self):
  37. try:
  38. self.socket.bind(('', self.port))
  39. except socket.error as e:
  40. print("Bind failed:{}".format(e))
  41. raise
  42. self.socket.listen(1)
  43. self.server_thread = Thread(target=self.run_server)
  44. self.server_thread.start()
  45. return self
  46. def __exit__(self, exc_type, exc_value, traceback):
  47. self.shutdown.set()
  48. self.server_thread.join()
  49. self.socket.close()
  50. if self.conn:
  51. self.conn.close()
  52. def run_server(self):
  53. self.conn, address = self.socket.accept() # accept new connection
  54. self.socket.settimeout(10.0)
  55. print("Connection from: {}".format(address))
  56. self.establish_connection()
  57. print("WS established")
  58. # Handle connection until client closes it, will echo any data received and send data from send_q queue
  59. self.handle_conn()
  60. def establish_connection(self):
  61. while not self.shutdown.is_set():
  62. try:
  63. # receive data stream. it won't accept data packet greater than 1024 bytes
  64. data = self.conn.recv(1024).decode()
  65. if not data:
  66. # exit if data is not received
  67. raise
  68. if "Upgrade: websocket" in data and "Connection: Upgrade" in data:
  69. self.handshake(data)
  70. return
  71. except socket.error as err:
  72. print("Unable to establish a websocket connection: {}, {}".format(err))
  73. raise
  74. def handshake(self, data):
  75. # Magic string from RFC
  76. MAGIC_STRING = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
  77. headers = data.split("\r\n")
  78. for header in headers:
  79. if "Sec-WebSocket-Key" in header:
  80. client_key = header.split()[1]
  81. if client_key:
  82. resp_key = client_key + MAGIC_STRING
  83. resp_key = base64.standard_b64encode(hashlib.sha1(resp_key.encode()).digest())
  84. resp = "HTTP/1.1 101 Switching Protocols\r\n" + \
  85. "Upgrade: websocket\r\n" + \
  86. "Connection: Upgrade\r\n" + \
  87. "Sec-WebSocket-Accept: {}\r\n\r\n".format(resp_key.decode())
  88. self.conn.send(resp.encode())
  89. def handle_conn(self):
  90. while not self.shutdown.is_set():
  91. r,w,e = select.select([self.conn], [], [], 1)
  92. try:
  93. if self.conn in r:
  94. self.echo_data()
  95. if not self.send_q.empty():
  96. self._send_data_(self.send_q.get())
  97. except socket.error as err:
  98. print("Stopped echoing data: {}".format(err))
  99. raise
  100. def echo_data(self):
  101. header = bytearray(self.conn.recv(self.HEADER_LEN, socket.MSG_WAITALL))
  102. if not header:
  103. # exit if socket closed by peer
  104. return
  105. # Remove mask bit
  106. payload_len = ~(1 << 7) & header[1]
  107. payload = bytearray(self.conn.recv(payload_len, socket.MSG_WAITALL))
  108. if not payload:
  109. # exit if socket closed by peer
  110. return
  111. frame = header + payload
  112. decoded_payload = self.decode_frame(frame)
  113. print("Sending echo...")
  114. self._send_data_(decoded_payload)
  115. def _send_data_(self, data):
  116. frame = self.encode_frame(data)
  117. self.conn.send(frame)
  118. def send_data(self, data):
  119. self.send_q.put(data.encode())
  120. def decode_frame(self, frame):
  121. # Mask out MASK bit from payload length, this len is only valid for short messages (<126)
  122. payload_len = ~(1 << 7) & frame[1]
  123. mask = frame[2:self.HEADER_LEN]
  124. encrypted_payload = frame[self.HEADER_LEN:self.HEADER_LEN + payload_len]
  125. payload = bytearray()
  126. for i in range(payload_len):
  127. payload.append(encrypted_payload[i] ^ mask[i % 4])
  128. return payload
  129. def encode_frame(self, payload):
  130. # Set FIN = 1 and OP_CODE = 1 (text)
  131. header = (1 << 7) | (1 << 0)
  132. frame = bytearray([header])
  133. payload_len = len(payload)
  134. # If payload len is longer than 125 then the next 16 bits are used to encode length
  135. if payload_len > 125:
  136. frame.append(126)
  137. frame.append(payload_len >> 8)
  138. frame.append(0xFF & payload_len)
  139. else:
  140. frame.append(payload_len)
  141. frame += payload
  142. return frame
  143. def test_echo(dut):
  144. dut.expect("WEBSOCKET_EVENT_CONNECTED")
  145. for i in range(0, 10):
  146. dut.expect(re.compile(r"Received=hello (\d)"), timeout=30)
  147. print("All echos received")
  148. def test_recv_long_msg(dut, websocket, msg_len, repeats):
  149. send_msg = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(msg_len))
  150. for _ in range(repeats):
  151. websocket.send_data(send_msg)
  152. recv_msg = ''
  153. while len(recv_msg) < msg_len:
  154. # Filter out color encoding
  155. match = dut.expect(re.compile(r"Received=([a-zA-Z0-9]*).*\n"), timeout=30)[0]
  156. recv_msg += match
  157. if recv_msg == send_msg:
  158. print("Sent message and received message are equal")
  159. else:
  160. raise ValueError("DUT received string do not match sent string, \nexpected: {}\nwith length {}\
  161. \nreceived: {}\nwith length {}".format(send_msg, len(send_msg), recv_msg, len(recv_msg)))
  162. @ttfw_idf.idf_example_test(env_tag="Example_WIFI")
  163. def test_examples_protocol_websocket(env, extra_data):
  164. """
  165. steps:
  166. 1. join AP
  167. 2. connect to uri specified in the config
  168. 3. send and receive data
  169. """
  170. dut1 = env.get_dut("websocket", "examples/protocols/websocket")
  171. # check and log bin size
  172. binary_file = os.path.join(dut1.app.binary_path, "websocket-example.bin")
  173. bin_size = os.path.getsize(binary_file)
  174. ttfw_idf.log_performance("websocket_bin_size", "{}KB".format(bin_size // 1024))
  175. ttfw_idf.check_performance("websocket_bin_size", bin_size // 1024)
  176. try:
  177. if "CONFIG_WEBSOCKET_URI_FROM_STDIN" in dut1.app.get_sdkconfig():
  178. uri_from_stdin = True
  179. else:
  180. uri = dut1.app.get_sdkconfig()["CONFIG_WEBSOCKET_URI"].strip('"')
  181. uri_from_stdin = False
  182. except Exception:
  183. print('ENV_TEST_FAILURE: Cannot find uri settings in sdkconfig')
  184. raise
  185. # start test
  186. dut1.start_app()
  187. if uri_from_stdin:
  188. server_port = 4455
  189. with Websocket(server_port) as ws:
  190. uri = "ws://{}:{}".format(get_my_ip(), server_port)
  191. print("DUT connecting to {}".format(uri))
  192. dut1.expect("Please enter uri of websocket endpoint", timeout=30)
  193. dut1.write(uri)
  194. test_echo(dut1)
  195. # Message length should exceed DUT's buffer size to test fragmentation, default is 1024 byte
  196. test_recv_long_msg(dut1, ws, 2000, 3)
  197. else:
  198. print("DUT connecting to {}".format(uri))
  199. test_echo(dut1)
# Run the example test when this file is executed as a script.
# NOTE(review): the test is called without arguments - presumably the
# ttfw_idf.idf_example_test decorator supplies env and extra_data; confirm.
if __name__ == '__main__':
    test_examples_protocol_websocket()