IDFDUT.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. # Copyright 2015-2017 Espressif Systems (Shanghai) PTE LTD
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """ DUT for IDF applications """
  15. import os
  16. import os.path
  17. import sys
  18. import re
  19. import functools
  20. import tempfile
  21. from serial.tools import list_ports
  22. import DUT
  23. try:
  24. import esptool
  25. except ImportError: # cheat and use IDF's copy of esptool if available
  26. idf_path = os.getenv("IDF_PATH")
  27. if not idf_path or not os.path.exists(idf_path):
  28. raise
  29. sys.path.insert(0, os.path.join(idf_path, "components", "esptool_py", "esptool"))
  30. import esptool
  31. class IDFToolError(OSError):
  32. pass
  33. def _uses_esptool(func):
  34. """ Suspend listener thread, connect with esptool,
  35. call target function with esptool instance,
  36. then resume listening for output
  37. """
  38. @functools.wraps(func)
  39. def handler(self, *args, **kwargs):
  40. self.stop_receive()
  41. settings = self.port_inst.get_settings()
  42. try:
  43. rom = esptool.ESP32ROM(self.port_inst)
  44. rom.connect('hard_reset')
  45. esp = rom.run_stub()
  46. ret = func(self, esp, *args, **kwargs)
  47. # do hard reset after use esptool
  48. esp.hard_reset()
  49. finally:
  50. # always need to restore port settings
  51. self.port_inst.apply_settings(settings)
  52. self.start_receive()
  53. return ret
  54. return handler
  55. class IDFDUT(DUT.SerialDUT):
  56. """ IDF DUT, extends serial with esptool methods
  57. (Becomes aware of IDFApp instance which holds app-specific data)
  58. """
  59. # /dev/ttyAMA0 port is listed in Raspberry Pi
  60. # /dev/tty.Bluetooth-Incoming-Port port is listed in Mac
  61. INVALID_PORT_PATTERN = re.compile(r"AMA|Bluetooth")
  62. # if need to erase NVS partition in start app
  63. ERASE_NVS = True
  64. def __init__(self, name, port, log_file, app, **kwargs):
  65. super(IDFDUT, self).__init__(name, port, log_file, app, **kwargs)
  66. @classmethod
  67. def get_mac(cls, app, port):
  68. """
  69. get MAC address via esptool
  70. :param app: application instance (to get tool)
  71. :param port: serial port as string
  72. :return: MAC address or None
  73. """
  74. try:
  75. esp = esptool.ESP32ROM(port)
  76. esp.connect()
  77. return esp.read_mac()
  78. except RuntimeError:
  79. return None
  80. finally:
  81. # do hard reset after use esptool
  82. esp.hard_reset()
  83. esp._port.close()
  84. @classmethod
  85. def confirm_dut(cls, port, app, **kwargs):
  86. return cls.get_mac(app, port) is not None
  87. @_uses_esptool
  88. def _try_flash(self, esp, erase_nvs, baud_rate):
  89. """
  90. Called by start_app() to try flashing at a particular baud rate.
  91. Structured this way so @_uses_esptool will reconnect each time
  92. """
  93. try:
  94. # note: opening here prevents us from having to seek back to 0 each time
  95. flash_files = [(offs, open(path, "rb")) for (offs, path) in self.app.flash_files]
  96. if erase_nvs:
  97. address = self.app.partition_table["nvs"]["offset"]
  98. size = self.app.partition_table["nvs"]["size"]
  99. nvs_file = tempfile.TemporaryFile()
  100. nvs_file.write(b'\xff' * size)
  101. nvs_file.seek(0)
  102. flash_files.append((int(address, 0), nvs_file))
  103. # fake flasher args object, this is a hack until
  104. # esptool Python API is improved
  105. class FlashArgs(object):
  106. def __init__(self, attributes):
  107. for key, value in attributes.items():
  108. self.__setattr__(key, value)
  109. flash_args = FlashArgs({
  110. 'flash_size': self.app.flash_settings["flash_size"],
  111. 'flash_mode': self.app.flash_settings["flash_mode"],
  112. 'flash_freq': self.app.flash_settings["flash_freq"],
  113. 'addr_filename': flash_files,
  114. 'no_stub': False,
  115. 'compress': True,
  116. 'verify': False,
  117. 'encrypt': False,
  118. })
  119. esp.change_baud(baud_rate)
  120. esptool.detect_flash_size(esp, flash_args)
  121. esptool.write_flash(esp, flash_args)
  122. finally:
  123. for (_, f) in flash_files:
  124. f.close()
  125. def start_app(self, erase_nvs=ERASE_NVS):
  126. """
  127. download and start app.
  128. :param: erase_nvs: whether erase NVS partition during flash
  129. :return: None
  130. """
  131. for baud_rate in [921600, 115200]:
  132. try:
  133. self._try_flash(erase_nvs, baud_rate)
  134. break
  135. except RuntimeError:
  136. continue
  137. else:
  138. raise IDFToolError()
  139. @_uses_esptool
  140. def reset(self, esp):
  141. """
  142. hard reset DUT
  143. :return: None
  144. """
  145. # decorator `_use_esptool` will do reset
  146. # so we don't need to do anything in this method
  147. pass
  148. @_uses_esptool
  149. def erase_partition(self, esp, partition):
  150. """
  151. :param partition: partition name to erase
  152. :return: None
  153. """
  154. raise NotImplementedError() # TODO: implement this
  155. # address = self.app.partition_table[partition]["offset"]
  156. size = self.app.partition_table[partition]["size"]
  157. # TODO can use esp.erase_region() instead of this, I think
  158. with open(".erase_partition.tmp", "wb") as f:
  159. f.write(chr(0xFF) * size)
  160. @_uses_esptool
  161. def dump_flush(self, esp, output_file, **kwargs):
  162. """
  163. dump flush
  164. :param output_file: output file name, if relative path, will use sdk path as base path.
  165. :keyword partition: partition name, dump the partition.
  166. ``partition`` is preferred than using ``address`` and ``size``.
  167. :keyword address: dump from address (need to be used with size)
  168. :keyword size: dump size (need to be used with address)
  169. :return: None
  170. """
  171. if os.path.isabs(output_file) is False:
  172. output_file = os.path.relpath(output_file, self.app.get_log_folder())
  173. if "partition" in kwargs:
  174. partition = self.app.partition_table[kwargs["partition"]]
  175. _address = partition["offset"]
  176. _size = partition["size"]
  177. elif "address" in kwargs and "size" in kwargs:
  178. _address = kwargs["address"]
  179. _size = kwargs["size"]
  180. else:
  181. raise IDFToolError("You must specify 'partition' or ('address' and 'size') to dump flash")
  182. content = esp.read_flash(_address, _size)
  183. with open(output_file, "wb") as f:
  184. f.write(content)
  185. @classmethod
  186. def list_available_ports(cls):
  187. ports = [x.device for x in list_ports.comports()]
  188. espport = os.getenv('ESPPORT')
  189. if not espport:
  190. # It's a little hard filter out invalid port with `serial.tools.list_ports.grep()`:
  191. # The check condition in `grep` is: `if r.search(port) or r.search(desc) or r.search(hwid)`.
  192. # This means we need to make all 3 conditions fail, to filter out the port.
  193. # So some part of the filters will not be straight forward to users.
  194. # And negative regular expression (`^((?!aa|bb|cc).)*$`) is not easy to understand.
  195. # Filter out invalid port by our own will be much simpler.
  196. return [x for x in ports if not cls.INVALID_PORT_PATTERN.search(x)]
  197. # On MacOs with python3.6: type of espport is already utf8
  198. if isinstance(espport, type(u'')):
  199. port_hint = espport
  200. else:
  201. port_hint = espport.decode('utf8')
  202. # If $ESPPORT is a valid port, make it appear first in the list
  203. if port_hint in ports:
  204. ports.remove(port_hint)
  205. return [port_hint] + ports
  206. # On macOS, user may set ESPPORT to /dev/tty.xxx while
  207. # pySerial lists only the corresponding /dev/cu.xxx port
  208. if sys.platform == 'darwin' and 'tty.' in port_hint:
  209. port_hint = port_hint.replace('tty.', 'cu.')
  210. if port_hint in ports:
  211. ports.remove(port_hint)
  212. return [port_hint] + ports
  213. return ports