app_test.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. from __future__ import print_function
  2. from __future__ import unicode_literals
  3. import re
  4. import os
  5. import socket
  6. import select
  7. import subprocess
  8. from threading import Thread, Event
  9. import ttfw_idf
  10. import ssl
  11. def _path(f):
  12. return os.path.join(os.path.dirname(os.path.realpath(__file__)),f)
  13. def set_server_cert_cn(ip):
  14. arg_list = [
  15. ['openssl', 'req', '-out', _path('srv.csr'), '-key', _path('server.key'),'-subj', "/CN={}".format(ip), '-new'],
  16. ['openssl', 'x509', '-req', '-in', _path('srv.csr'), '-CA', _path('ca.crt'),
  17. '-CAkey', _path('ca.key'), '-CAcreateserial', '-out', _path('srv.crt'), '-days', '360']]
  18. for args in arg_list:
  19. if subprocess.check_call(args) != 0:
  20. raise("openssl command {} failed".format(args))
  21. def get_my_ip():
  22. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  23. try:
  24. # doesn't even have to be reachable
  25. s.connect(('10.255.255.255', 1))
  26. IP = s.getsockname()[0]
  27. except Exception:
  28. IP = '127.0.0.1'
  29. finally:
  30. s.close()
  31. return IP
  32. # Simple server for mqtt over TLS connection
  33. class TlsServer:
  34. def __init__(self, port, client_cert=False, refuse_connection=False, use_alpn=False):
  35. self.port = port
  36. self.socket = socket.socket()
  37. self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  38. self.socket.settimeout(10.0)
  39. self.shutdown = Event()
  40. self.client_cert = client_cert
  41. self.refuse_connection = refuse_connection
  42. self.ssl_error = None
  43. self.use_alpn = use_alpn
  44. self.negotiated_protocol = None
  45. def __enter__(self):
  46. try:
  47. self.socket.bind(('', self.port))
  48. except socket.error as e:
  49. print("Bind failed:{}".format(e))
  50. raise
  51. self.socket.listen(1)
  52. self.server_thread = Thread(target=self.run_server)
  53. self.server_thread.start()
  54. return self
  55. def __exit__(self, exc_type, exc_value, traceback):
  56. self.shutdown.set()
  57. self.server_thread.join()
  58. self.socket.close()
  59. if (self.conn is not None):
  60. self.conn.close()
  61. def get_last_ssl_error(self):
  62. return self.ssl_error
  63. def get_negotiated_protocol(self):
  64. return self.negotiated_protocol
  65. def run_server(self):
  66. context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
  67. if self.client_cert:
  68. context.verify_mode = ssl.CERT_REQUIRED
  69. context.load_verify_locations(cafile=_path("ca.crt"))
  70. context.load_cert_chain(certfile=_path("srv.crt"), keyfile=_path("server.key"))
  71. if self.use_alpn:
  72. context.set_alpn_protocols(["mymqtt", "http/1.1"])
  73. self.socket = context.wrap_socket(self.socket, server_side=True)
  74. try:
  75. self.conn, address = self.socket.accept() # accept new connection
  76. self.socket.settimeout(10.0)
  77. print(" - connection from: {}".format(address))
  78. if self.use_alpn:
  79. self.negotiated_protocol = self.conn.selected_alpn_protocol()
  80. print(" - negotiated_protocol: {}".format(self.negotiated_protocol))
  81. self.handle_conn()
  82. except ssl.SSLError as e:
  83. self.conn = None
  84. self.ssl_error = str(e)
  85. print(" - SSLError: {}".format(str(e)))
  86. def handle_conn(self):
  87. while not self.shutdown.is_set():
  88. r,w,e = select.select([self.conn], [], [], 1)
  89. try:
  90. if self.conn in r:
  91. self.process_mqtt_connect()
  92. except socket.error as err:
  93. print(" - error: {}".format(err))
  94. raise
  95. def process_mqtt_connect(self):
  96. try:
  97. data = bytearray(self.conn.recv(1024))
  98. message = ''.join(format(x, '02x') for x in data)
  99. if message[0:16] == '101800044d515454':
  100. if self.refuse_connection is False:
  101. print(" - received mqtt connect, sending ACK")
  102. self.conn.send(bytearray.fromhex("20020000"))
  103. else:
  104. # injecting connection not authorized error
  105. print(" - received mqtt connect, sending NAK")
  106. self.conn.send(bytearray.fromhex("20020005"))
  107. else:
  108. raise Exception(" - error process_mqtt_connect unexpected connect received: {}".format(message))
  109. finally:
  110. # stop the server after the connect message in happy flow, or if any exception occur
  111. self.shutdown.set()
  112. @ttfw_idf.idf_custom_test(env_tag="Example_WIFI", group="test-apps")
  113. def test_app_protocol_mqtt_publish_connect(env, extra_data):
  114. """
  115. steps:
  116. 1. join AP
  117. 2. connect to uri specified in the config
  118. 3. send and receive data
  119. """
  120. dut1 = env.get_dut("mqtt_publish_connect_test", "tools/test_apps/protocols/mqtt/publish_connect_test", dut_class=ttfw_idf.ESP32DUT)
  121. # check and log bin size
  122. binary_file = os.path.join(dut1.app.binary_path, "mqtt_publish_connect_test.bin")
  123. bin_size = os.path.getsize(binary_file)
  124. ttfw_idf.log_performance("mqtt_publish_connect_test_bin_size", "{}KB".format(bin_size // 1024))
  125. ttfw_idf.check_performance("mqtt_publish_connect_test_bin_size_vin_size", bin_size // 1024, dut1.TARGET)
  126. # Look for test case symbolic names
  127. cases = {}
  128. try:
  129. for i in ["CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT",
  130. "CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT",
  131. "CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH",
  132. "CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT",
  133. "CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT",
  134. "CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD",
  135. "CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT",
  136. "CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN"]:
  137. cases[i] = dut1.app.get_sdkconfig()[i]
  138. except Exception:
  139. print('ENV_TEST_FAILURE: Some mandatory test case not found in sdkconfig')
  140. raise
  141. dut1.start_app()
  142. esp_ip = dut1.expect(re.compile(r" IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)"), timeout=30)
  143. print("Got IP={}".format(esp_ip[0]))
  144. #
  145. # start connection test
  146. ip = get_my_ip()
  147. set_server_cert_cn(ip)
  148. server_port = 2222
  149. def start_case(case, desc):
  150. print("Starting {}: {}".format(case, desc))
  151. case_id = cases[case]
  152. dut1.write("conn {} {} {}".format(ip, server_port, case_id))
  153. dut1.expect("Test case:{} started".format(case_id))
  154. return case_id
  155. for case in ["CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT", "CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT", "CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT"]:
  156. # All these cases connect to the server with no server verification or with server only verification
  157. with TlsServer(server_port):
  158. test_nr = start_case(case, "default server - expect to connect normally")
  159. dut1.expect("MQTT_EVENT_CONNECTED: Test={}".format(test_nr), timeout=30)
  160. with TlsServer(server_port, refuse_connection=True):
  161. test_nr = start_case(case, "ssl shall connect, but mqtt sends connect refusal")
  162. dut1.expect("MQTT_EVENT_ERROR: Test={}".format(test_nr), timeout=30)
  163. dut1.expect("MQTT ERROR: 0x5") # expecting 0x5 ... connection not authorized error
  164. with TlsServer(server_port, client_cert=True) as s:
  165. test_nr = start_case(case, "server with client verification - handshake error since client presents no client certificate")
  166. dut1.expect("MQTT_EVENT_ERROR: Test={}".format(test_nr), timeout=30)
  167. dut1.expect("ESP-TLS ERROR: 0x8010") # expect ... handshake error (PEER_DID_NOT_RETURN_A_CERTIFICATE)
  168. if "PEER_DID_NOT_RETURN_A_CERTIFICATE" not in s.get_last_ssl_error():
  169. raise("Unexpected ssl error from the server {}".format(s.get_last_ssl_error()))
  170. for case in ["CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH", "CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD"]:
  171. # These cases connect to server with both server and client verification (client key might be password protected)
  172. with TlsServer(server_port, client_cert=True):
  173. test_nr = start_case(case, "server with client verification - expect to connect normally")
  174. dut1.expect("MQTT_EVENT_CONNECTED: Test={}".format(test_nr), timeout=30)
  175. case = "CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT"
  176. with TlsServer(server_port) as s:
  177. test_nr = start_case(case, "invalid server certificate on default server - expect ssl handshake error")
  178. dut1.expect("MQTT_EVENT_ERROR: Test={}".format(test_nr), timeout=30)
  179. dut1.expect("ESP-TLS ERROR: 0x8010") # expect ... handshake error (TLSV1_ALERT_UNKNOWN_CA)
  180. if "alert unknown ca" not in s.get_last_ssl_error():
  181. raise Exception("Unexpected ssl error from the server {}".format(s.get_last_ssl_error()))
  182. case = "CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT"
  183. with TlsServer(server_port, client_cert=True) as s:
  184. test_nr = start_case(case, "Invalid client certificate on server with client verification - expect ssl handshake error")
  185. dut1.expect("MQTT_EVENT_ERROR: Test={}".format(test_nr), timeout=30)
  186. dut1.expect("ESP-TLS ERROR: 0x8010") # expect ... handshake error (CERTIFICATE_VERIFY_FAILED)
  187. if "CERTIFICATE_VERIFY_FAILED" not in s.get_last_ssl_error():
  188. raise Exception("Unexpected ssl error from the server {}".format(s.get_last_ssl_error()))
  189. for case in ["CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT", "CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN"]:
  190. with TlsServer(server_port, use_alpn=True) as s:
  191. test_nr = start_case(case, "server with alpn - expect connect, check resolved protocol")
  192. dut1.expect("MQTT_EVENT_CONNECTED: Test={}".format(test_nr), timeout=30)
  193. if case == "CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT" and s.get_negotiated_protocol() is None:
  194. print(" - client with alpn off, no negotiated protocol: OK")
  195. elif case == "CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN" and s.get_negotiated_protocol() == "mymqtt":
  196. print(" - client with alpn on, negotiated protocol resolved: OK")
  197. else:
  198. raise Exception("Unexpected negotiated protocol {}".format(s.get_negotiated_protocol()))
  199. if __name__ == '__main__':
  200. test_app_protocol_mqtt_publish_connect()