IDFDUT.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. # Copyright 2015-2017 Espressif Systems (Shanghai) PTE LTD
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """ DUT for IDF applications """
  15. import os
  16. import os.path
  17. import sys
  18. import re
  19. import functools
  20. import tempfile
  21. from serial.tools import list_ports
  22. import DUT
  23. try:
  24. import esptool
  25. except ImportError: # cheat and use IDF's copy of esptool if available
  26. idf_path = os.getenv("IDF_PATH")
  27. if not idf_path or not os.path.exists(idf_path):
  28. raise
  29. sys.path.insert(0, os.path.join(idf_path, "components", "esptool_py", "esptool"))
  30. import esptool
  class IDFToolError(OSError):
      """ Error raised by IDFDUT when an esptool-based operation cannot be carried out """
  33. def _uses_esptool(func):
  34. """ Suspend listener thread, connect with esptool,
  35. call target function with esptool instance,
  36. then resume listening for output
  37. """
  38. @functools.wraps(func)
  39. def handler(self, *args, **kwargs):
  40. self.stop_receive()
  41. settings = self.port_inst.get_settings()
  42. try:
  43. rom = esptool.ESP32ROM(self.port_inst)
  44. rom.connect('hard_reset')
  45. esp = rom.run_stub()
  46. ret = func(self, esp, *args, **kwargs)
  47. # do hard reset after use esptool
  48. esp.hard_reset()
  49. finally:
  50. # always need to restore port settings
  51. self.port_inst.apply_settings(settings)
  52. self.start_receive()
  53. return ret
  54. return handler
  class IDFDUT(DUT.SerialDUT):
      """ IDF DUT, extends serial with esptool methods

      (Becomes aware of IDFApp instance which holds app-specific data)
      """

      # /dev/ttyAMA0 port is listed in Raspberry Pi
      # /dev/tty.Bluetooth-Incoming-Port port is listed in Mac
      INVALID_PORT_PATTERN = re.compile(r"AMA|Bluetooth")
      # if need to erase NVS partition in start app
      ERASE_NVS = True

      def __init__(self, name, port, log_file, app, **kwargs):
          super(IDFDUT, self).__init__(name, port, log_file, app, **kwargs)

      @classmethod
      def get_mac(cls, app, port):
          """
          get MAC address via esptool

          :param app: application instance (to get tool)
          :param port: serial port as string
          :return: MAC address or None
          """
          # `esp` must be bound before the `try`: if creating the esptool
          # instance fails, the `finally` clause still runs and must not
          # touch an unbound name (that would hide the real result/error).
          esp = None
          try:
              esp = esptool.ESP32ROM(port)
              esp.connect()
              return esp.read_mac()
          except RuntimeError:
              return None
          finally:
              if esp is not None:
                  # do hard reset after use esptool
                  esp.hard_reset()
                  esp._port.close()

      @classmethod
      def confirm_dut(cls, port, app, **kwargs):
          """
          check if a DUT is connected to the port, by trying to read its MAC address

          :param port: serial port as string
          :param app: application instance
          :return: True if a MAC address could be read, otherwise False
          """
          return cls.get_mac(app, port) is not None

      @_uses_esptool
      def _try_flash(self, esp, erase_nvs, baud_rate):
          """
          Called by start_app() to try flashing at a particular baud rate.

          Structured this way so @_uses_esptool will reconnect each time

          :param esp: esptool instance, provided by @_uses_esptool
          :param erase_nvs: if True, also overwrite the "nvs" partition with 0xFF
          :param baud_rate: baud rate used for flashing
          :return: None
          """
          # Bound before the `try` and filled one by one, so that the `finally`
          # clause can close whatever has been opened even if a later open() fails.
          flash_files = []
          try:
              # note: opening here prevents us from having to seek back to 0 each time
              for (offs, path) in self.app.flash_files:
                  flash_files.append((offs, open(path, "rb")))

              if erase_nvs:
                  nvs_partition = self.app.partition_table["nvs"]
                  # offset is a string (e.g. "0x9000"), base 0 lets int() detect the prefix
                  address = int(nvs_partition["offset"], 0)
                  size = nvs_partition["size"]
                  nvs_file = tempfile.TemporaryFile()
                  # register the file first, so it gets closed if writing fails
                  flash_files.append((address, nvs_file))
                  nvs_file.write(b'\xff' * size)
                  nvs_file.seek(0)

              # fake flasher args object, this is a hack until
              # esptool Python API is improved
              class FlashArgs(object):
                  def __init__(self, attributes):
                      for key, value in attributes.items():
                          setattr(self, key, value)

              flash_args = FlashArgs({
                  'flash_size': self.app.flash_settings["flash_size"],
                  'flash_mode': self.app.flash_settings["flash_mode"],
                  'flash_freq': self.app.flash_settings["flash_freq"],
                  'addr_filename': flash_files,
                  'no_stub': False,
                  'compress': True,
                  'verify': False,
                  'encrypt': False,
                  'erase_all': False,
              })

              esp.change_baud(baud_rate)
              esptool.detect_flash_size(esp, flash_args)
              esptool.write_flash(esp, flash_args)
          finally:
              for (_, f) in flash_files:
                  f.close()

      def start_app(self, erase_nvs=ERASE_NVS):
          """
          download and start app.

          :param: erase_nvs: whether erase NVS partition during flash
          :return: None
          :raise IDFToolError: if flashing failed with RuntimeError at every baud rate
          """
          last_error = None
          # try the fast baud rate first, fall back to the slow one
          for baud_rate in [921600, 115200]:
              try:
                  self._try_flash(erase_nvs, baud_rate)
                  break
              except RuntimeError as e:
                  # keep the reason, so it can be reported if all attempts fail
                  last_error = e
          else:
              raise IDFToolError("Failed to flash app at any baud rate: {}".format(last_error))

      @_uses_esptool
      def reset(self, esp):
          """
          hard reset DUT

          :return: None
          """
          # decorator `_uses_esptool` will do reset
          # so we don't need to do anything in this method
          pass

      @_uses_esptool
      def erase_partition(self, esp, partition):
          """
          Not implemented yet, always raises NotImplementedError.

          Note that @_uses_esptool still connects to the chip before this body runs.

          :param partition: partition name to erase
          :return: None
          :raise NotImplementedError: always
          """
          # TODO: implement this (esp.erase_region() with the partition's
          # offset and size from self.app.partition_table is probably enough)
          raise NotImplementedError()

      @_uses_esptool
      def dump_flush(self, esp, output_file, **kwargs):
          """
          dump flash

          :param output_file: output file name, if relative path, will use the app log folder as base path.
          :keyword partition: partition name, dump the partition.
                              ``partition`` is preferred than using ``address`` and ``size``.
          :keyword address: dump from address (need to be used with size)
          :keyword size: dump size (need to be used with address)
          :return: None
          :raise IDFToolError: if neither ``partition`` nor (``address`` and ``size``) is given
          """
          if not os.path.isabs(output_file):
              # anchor relative names in the log folder; os.path.relpath() would
              # instead rewrite the name relative to that folder while open()
              # still resolves it against the current working directory
              output_file = os.path.join(self.app.get_log_folder(), output_file)

          if "partition" in kwargs:
              partition = self.app.partition_table[kwargs["partition"]]
              _address = partition["offset"]
              _size = partition["size"]
              # partition offsets are stored as strings (see _try_flash),
              # esptool needs a number
              try:
                  _address = int(_address, 0)
              except TypeError:
                  # already a number: int() rejects non-strings with explicit base
                  pass
          elif "address" in kwargs and "size" in kwargs:
              _address = kwargs["address"]
              _size = kwargs["size"]
          else:
              raise IDFToolError("You must specify 'partition' or ('address' and 'size') to dump flash")

          content = esp.read_flash(_address, _size)
          with open(output_file, "wb") as f:
              f.write(content)

      @classmethod
      def list_available_ports(cls):
          """
          list serial ports which could have a DUT connected

          :return: list of port names. If environment variable ``ESPPORT`` is not set,
                   ports matching ``INVALID_PORT_PATTERN`` are filtered out.
                   Otherwise all ports are returned, with the ``ESPPORT`` port first
                   if it is one of them.
          """
          ports = [x.device for x in list_ports.comports()]
          espport = os.getenv('ESPPORT')
          if not espport:
              # It's a little hard filter out invalid port with `serial.tools.list_ports.grep()`:
              # The check condition in `grep` is: `if r.search(port) or r.search(desc) or r.search(hwid)`.
              # This means we need to make all 3 conditions fail, to filter out the port.
              # So some part of the filters will not be straight forward to users.
              # And negative regular expression (`^((?!aa|bb|cc).)*$`) is not easy to understand.
              # Filter out invalid port by our own will be much simpler.
              return [x for x in ports if not cls.INVALID_PORT_PATTERN.search(x)]

          # the value may already be text (e.g. python3.6 on macOS), decode it only if it is not
          if isinstance(espport, type(u'')):
              port_hint = espport
          else:
              port_hint = espport.decode('utf8')

          # If $ESPPORT is a valid port, make it appear first in the list
          candidates = [port_hint]
          # On macOS, user may set ESPPORT to /dev/tty.xxx while
          # pySerial lists only the corresponding /dev/cu.xxx port
          if sys.platform == 'darwin' and 'tty.' in port_hint:
              candidates.append(port_hint.replace('tty.', 'cu.'))

          for candidate in candidates:
              if candidate in ports:
                  ports.remove(candidate)
                  return [candidate] + ports

          return ports