app_test.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. from __future__ import print_function, unicode_literals
  2. import os
  3. import random
  4. import re
  5. import select
  6. import socket
  7. import ssl
  8. import string
  9. import subprocess
  10. import sys
  11. from threading import Event, Thread
  12. import paho.mqtt.client as mqtt
  13. import ttfw_idf
  14. DEFAULT_MSG_SIZE = 16
  15. def _path(f):
  16. return os.path.join(os.path.dirname(os.path.realpath(__file__)),f)
  17. def set_server_cert_cn(ip):
  18. arg_list = [
  19. ['openssl', 'req', '-out', _path('srv.csr'), '-key', _path('server.key'),'-subj', '/CN={}'.format(ip), '-new'],
  20. ['openssl', 'x509', '-req', '-in', _path('srv.csr'), '-CA', _path('ca.crt'),
  21. '-CAkey', _path('ca.key'), '-CAcreateserial', '-out', _path('srv.crt'), '-days', '360']]
  22. for args in arg_list:
  23. if subprocess.check_call(args) != 0:
  24. raise('openssl command {} failed'.format(args))
  25. def get_my_ip():
  26. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  27. try:
  28. # doesn't even have to be reachable
  29. s.connect(('10.255.255.255', 1))
  30. IP = s.getsockname()[0]
  31. except Exception:
  32. IP = '127.0.0.1'
  33. finally:
  34. s.close()
  35. return IP
  36. # Publisher class creating a python client to send/receive published data from esp-mqtt client
  37. class MqttPublisher:
  38. def __init__(self, dut, transport, qos, repeat, published, queue, publish_cfg, log_details=False):
  39. # instance variables used as parameters of the publish test
  40. self.event_stop_client = Event()
  41. self.sample_string = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE))
  42. self.client = None
  43. self.dut = dut
  44. self.log_details = log_details
  45. self.repeat = repeat
  46. self.publish_cfg = publish_cfg
  47. self.publish_cfg['qos'] = qos
  48. self.publish_cfg['queue'] = queue
  49. self.publish_cfg['transport'] = transport
  50. # static variables used to pass options to and from static callbacks of paho-mqtt client
  51. MqttPublisher.event_client_connected = Event()
  52. MqttPublisher.event_client_got_all = Event()
  53. MqttPublisher.published = published
  54. MqttPublisher.event_client_connected.clear()
  55. MqttPublisher.event_client_got_all.clear()
  56. MqttPublisher.expected_data = self.sample_string * self.repeat
  57. def print_details(self, text):
  58. if self.log_details:
  59. print(text)
  60. def mqtt_client_task(self, client):
  61. while not self.event_stop_client.is_set():
  62. client.loop()
  63. # The callback for when the client receives a CONNACK response from the server (needs to be static)
  64. @staticmethod
  65. def on_connect(_client, _userdata, _flags, _rc):
  66. MqttPublisher.event_client_connected.set()
  67. # The callback for when a PUBLISH message is received from the server (needs to be static)
  68. @staticmethod
  69. def on_message(client, userdata, msg):
  70. payload = msg.payload.decode()
  71. if payload == MqttPublisher.expected_data:
  72. userdata += 1
  73. client.user_data_set(userdata)
  74. if userdata == MqttPublisher.published:
  75. MqttPublisher.event_client_got_all.set()
  76. def __enter__(self):
  77. qos = self.publish_cfg['qos']
  78. queue = self.publish_cfg['queue']
  79. transport = self.publish_cfg['transport']
  80. broker_host = self.publish_cfg['broker_host_' + transport]
  81. broker_port = self.publish_cfg['broker_port_' + transport]
  82. # Start the test
  83. self.print_details("PUBLISH TEST: transport:{}, qos:{}, sequence:{}, enqueue:{}, sample msg:'{}'"
  84. .format(transport, qos, MqttPublisher.published, queue, MqttPublisher.expected_data))
  85. try:
  86. if transport in ['ws', 'wss']:
  87. self.client = mqtt.Client(transport='websockets')
  88. else:
  89. self.client = mqtt.Client()
  90. self.client.on_connect = MqttPublisher.on_connect
  91. self.client.on_message = MqttPublisher.on_message
  92. self.client.user_data_set(0)
  93. if transport in ['ssl', 'wss']:
  94. self.client.tls_set(None, None, None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None)
  95. self.client.tls_insecure_set(True)
  96. self.print_details('Connecting...')
  97. self.client.connect(broker_host, broker_port, 60)
  98. except Exception:
  99. self.print_details('ENV_TEST_FAILURE: Unexpected error while connecting to broker {}'.format(broker_host))
  100. raise
  101. # Starting a py-client in a separate thread
  102. thread1 = Thread(target=self.mqtt_client_task, args=(self.client,))
  103. thread1.start()
  104. self.print_details('Connecting py-client to broker {}:{}...'.format(broker_host, broker_port))
  105. if not MqttPublisher.event_client_connected.wait(timeout=30):
  106. raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_host))
  107. self.client.subscribe(self.publish_cfg['subscribe_topic'], qos)
  108. self.dut.write(' '.join(str(x) for x in (transport, self.sample_string, self.repeat, MqttPublisher.published, qos, queue)), eol='\n')
  109. try:
  110. # waiting till subscribed to defined topic
  111. self.dut.expect(re.compile(r'MQTT_EVENT_SUBSCRIBED'), timeout=30)
  112. for _ in range(MqttPublisher.published):
  113. self.client.publish(self.publish_cfg['publish_topic'], self.sample_string * self.repeat, qos)
  114. self.print_details('Publishing...')
  115. self.print_details('Checking esp-client received msg published from py-client...')
  116. self.dut.expect(re.compile(r'Correct pattern received exactly x times'), timeout=60)
  117. if not MqttPublisher.event_client_got_all.wait(timeout=60):
  118. raise ValueError('Not all data received from ESP32')
  119. print(' - all data received from ESP32')
  120. finally:
  121. self.event_stop_client.set()
  122. thread1.join()
  123. def __exit__(self, exc_type, exc_value, traceback):
  124. self.client.disconnect()
  125. self.event_stop_client.clear()
  126. # Simple server for mqtt over TLS connection
  127. class TlsServer:
  128. def __init__(self, port, client_cert=False, refuse_connection=False, use_alpn=False):
  129. self.port = port
  130. self.socket = socket.socket()
  131. self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  132. self.socket.settimeout(10.0)
  133. self.shutdown = Event()
  134. self.client_cert = client_cert
  135. self.refuse_connection = refuse_connection
  136. self.ssl_error = None
  137. self.use_alpn = use_alpn
  138. self.negotiated_protocol = None
  139. def __enter__(self):
  140. try:
  141. self.socket.bind(('', self.port))
  142. except socket.error as e:
  143. print('Bind failed:{}'.format(e))
  144. raise
  145. self.socket.listen(1)
  146. self.server_thread = Thread(target=self.run_server)
  147. self.server_thread.start()
  148. return self
  149. def __exit__(self, exc_type, exc_value, traceback):
  150. self.shutdown.set()
  151. self.server_thread.join()
  152. self.socket.close()
  153. if (self.conn is not None):
  154. self.conn.close()
  155. def get_last_ssl_error(self):
  156. return self.ssl_error
  157. def get_negotiated_protocol(self):
  158. return self.negotiated_protocol
  159. def run_server(self):
  160. context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
  161. if self.client_cert:
  162. context.verify_mode = ssl.CERT_REQUIRED
  163. context.load_verify_locations(cafile=_path('ca.crt'))
  164. context.load_cert_chain(certfile=_path('srv.crt'), keyfile=_path('server.key'))
  165. if self.use_alpn:
  166. context.set_alpn_protocols(['mymqtt', 'http/1.1'])
  167. self.socket = context.wrap_socket(self.socket, server_side=True)
  168. try:
  169. self.conn, address = self.socket.accept() # accept new connection
  170. self.socket.settimeout(10.0)
  171. print(' - connection from: {}'.format(address))
  172. if self.use_alpn:
  173. self.negotiated_protocol = self.conn.selected_alpn_protocol()
  174. print(' - negotiated_protocol: {}'.format(self.negotiated_protocol))
  175. self.handle_conn()
  176. except ssl.SSLError as e:
  177. self.conn = None
  178. self.ssl_error = str(e)
  179. print(' - SSLError: {}'.format(str(e)))
  180. def handle_conn(self):
  181. while not self.shutdown.is_set():
  182. r,w,e = select.select([self.conn], [], [], 1)
  183. try:
  184. if self.conn in r:
  185. self.process_mqtt_connect()
  186. except socket.error as err:
  187. print(' - error: {}'.format(err))
  188. raise
  189. def process_mqtt_connect(self):
  190. try:
  191. data = bytearray(self.conn.recv(1024))
  192. message = ''.join(format(x, '02x') for x in data)
  193. if message[0:16] == '101800044d515454':
  194. if self.refuse_connection is False:
  195. print(' - received mqtt connect, sending ACK')
  196. self.conn.send(bytearray.fromhex('20020000'))
  197. else:
  198. # injecting connection not authorized error
  199. print(' - received mqtt connect, sending NAK')
  200. self.conn.send(bytearray.fromhex('20020005'))
  201. else:
  202. raise Exception(' - error process_mqtt_connect unexpected connect received: {}'.format(message))
  203. finally:
  204. # stop the server after the connect message in happy flow, or if any exception occur
  205. self.shutdown.set()
  206. def connection_tests(dut, cases):
  207. ip = get_my_ip()
  208. set_server_cert_cn(ip)
  209. server_port = 2222
  210. def teardown_connection_suite():
  211. dut.write('conn teardown 0 0')
  212. def start_connection_case(case, desc):
  213. print('Starting {}: {}'.format(case, desc))
  214. case_id = cases[case]
  215. dut.write('conn {} {} {}'.format(ip, server_port, case_id))
  216. dut.expect('Test case:{} started'.format(case_id))
  217. return case_id
  218. for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']:
  219. # All these cases connect to the server with no server verification or with server only verification
  220. with TlsServer(server_port):
  221. test_nr = start_connection_case(case, 'default server - expect to connect normally')
  222. dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
  223. with TlsServer(server_port, refuse_connection=True):
  224. test_nr = start_connection_case(case, 'ssl shall connect, but mqtt sends connect refusal')
  225. dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
  226. dut.expect('MQTT ERROR: 0x5') # expecting 0x5 ... connection not authorized error
  227. with TlsServer(server_port, client_cert=True) as s:
  228. test_nr = start_connection_case(case, 'server with client verification - handshake error since client presents no client certificate')
  229. dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
  230. dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (PEER_DID_NOT_RETURN_A_CERTIFICATE)
  231. if 'PEER_DID_NOT_RETURN_A_CERTIFICATE' not in s.get_last_ssl_error():
  232. raise('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error()))
  233. for case in ['CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']:
  234. # These cases connect to server with both server and client verification (client key might be password protected)
  235. with TlsServer(server_port, client_cert=True):
  236. test_nr = start_connection_case(case, 'server with client verification - expect to connect normally')
  237. dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
  238. case = 'CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT'
  239. with TlsServer(server_port) as s:
  240. test_nr = start_connection_case(case, 'invalid server certificate on default server - expect ssl handshake error')
  241. dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
  242. dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (TLSV1_ALERT_UNKNOWN_CA)
  243. if 'alert unknown ca' not in s.get_last_ssl_error():
  244. raise Exception('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error()))
  245. case = 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT'
  246. with TlsServer(server_port, client_cert=True) as s:
  247. test_nr = start_connection_case(case, 'Invalid client certificate on server with client verification - expect ssl handshake error')
  248. dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
  249. dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (CERTIFICATE_VERIFY_FAILED)
  250. if 'CERTIFICATE_VERIFY_FAILED' not in s.get_last_ssl_error():
  251. raise Exception('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error()))
  252. for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
  253. with TlsServer(server_port, use_alpn=True) as s:
  254. test_nr = start_connection_case(case, 'server with alpn - expect connect, check resolved protocol')
  255. dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
  256. if case == 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT' and s.get_negotiated_protocol() is None:
  257. print(' - client with alpn off, no negotiated protocol: OK')
  258. elif case == 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN' and s.get_negotiated_protocol() == 'mymqtt':
  259. print(' - client with alpn on, negotiated protocol resolved: OK')
  260. else:
  261. raise Exception('Unexpected negotiated protocol {}'.format(s.get_negotiated_protocol()))
  262. teardown_connection_suite()
  263. @ttfw_idf.idf_custom_test(env_tag='Example_WIFI', group='test-apps')
  264. def test_app_protocol_mqtt_publish_connect(env, extra_data):
  265. """
  266. steps:
  267. 1. join AP
  268. 2. connect to uri specified in the config
  269. 3. send and receive data
  270. """
  271. dut1 = env.get_dut('mqtt_publish_connect_test', 'tools/test_apps/protocols/mqtt/publish_connect_test')
  272. # check and log bin size
  273. binary_file = os.path.join(dut1.app.binary_path, 'mqtt_publish_connect_test.bin')
  274. bin_size = os.path.getsize(binary_file)
  275. ttfw_idf.log_performance('mqtt_publish_connect_test_bin_size', '{}KB'.format(bin_size // 1024))
  276. # Look for test case symbolic names and publish configs
  277. cases = {}
  278. publish_cfg = {}
  279. try:
  280. # Get connection test cases configuration: symbolic names for test cases
  281. for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT',
  282. 'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT',
  283. 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH',
  284. 'CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT',
  285. 'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT',
  286. 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD',
  287. 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT',
  288. 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
  289. cases[case] = dut1.app.get_sdkconfig()[case]
  290. except Exception:
  291. print('ENV_TEST_FAILURE: Some mandatory CONNECTION test case not found in sdkconfig')
  292. raise
  293. dut1.start_app()
  294. esp_ip = dut1.expect(re.compile(r' IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)'), timeout=30)
  295. print('Got IP={}'.format(esp_ip[0]))
  296. if not os.getenv('MQTT_SKIP_CONNECT_TEST'):
  297. connection_tests(dut1,cases)
  298. #
  299. # start publish tests only if enabled in the environment (for weekend tests only)
  300. if not os.getenv('MQTT_PUBLISH_TEST'):
  301. return
  302. # Get publish test configuration
  303. try:
  304. def get_host_port_from_dut(dut1, config_option):
  305. value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()[config_option])
  306. if value is None:
  307. return None, None
  308. return value.group(1), int(value.group(2))
  309. publish_cfg['publish_topic'] = dut1.app.get_sdkconfig()['CONFIG_EXAMPLE_SUBSCIBE_TOPIC'].replace('"','')
  310. publish_cfg['subscribe_topic'] = dut1.app.get_sdkconfig()['CONFIG_EXAMPLE_PUBLISH_TOPIC'].replace('"','')
  311. publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_SSL_URI')
  312. publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_TCP_URI')
  313. publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_WS_URI')
  314. publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_WSS_URI')
  315. except Exception:
  316. print('ENV_TEST_FAILURE: Some mandatory PUBLISH test case not found in sdkconfig')
  317. raise
  318. def start_publish_case(transport, qos, repeat, published, queue):
  319. print('Starting Publish test: transport:{}, qos:{}, nr_of_msgs:{}, msg_size:{}, enqueue:{}'
  320. .format(transport, qos, published, repeat * DEFAULT_MSG_SIZE, queue))
  321. with MqttPublisher(dut1, transport, qos, repeat, published, queue, publish_cfg):
  322. pass
  323. for qos in [0, 1, 2]:
  324. for transport in ['tcp', 'ssl', 'ws', 'wss']:
  325. for q in [0, 1]:
  326. if publish_cfg['broker_host_' + transport] is None:
  327. print('Skipping transport: {}...'.format(transport))
  328. continue
  329. start_publish_case(transport, qos, 0, 5, q)
  330. start_publish_case(transport, qos, 2, 5, q)
  331. start_publish_case(transport, qos, 50, 1, q)
  332. start_publish_case(transport, qos, 10, 20, q)
  333. if __name__ == '__main__':
  334. test_app_protocol_mqtt_publish_connect(dut=ttfw_idf.ESP32QEMUDUT if sys.argv[1:] == ['qemu'] else ttfw_idf.ESP32DUT)