generate.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. #!/usr/bin/env python3
  2. #
  3. # Copyright (c) 2022 Project CHIP Authors
  4. # All rights reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. import argparse
  19. import hashlib
  20. import logging
  21. import subprocess
  22. import sys
  23. from custom import (CertDeclaration, DacCert, DacPKey, Discriminator, HardwareVersion, HardwareVersionStr, IterationCount,
  24. ManufacturingDate, PaiCert, PartNumber, ProductId, ProductLabel, ProductName, ProductURL, Salt, SerialNum,
  25. SetupPasscode, StrArgument, UniqueId, VendorId, VendorName, Verifier)
  26. from default import InputArgument
  27. def set_logger():
  28. stdout_handler = logging.StreamHandler(stream=sys.stdout)
  29. logging.basicConfig(
  30. level=logging.DEBUG,
  31. format='[%(levelname)s] %(message)s',
  32. handlers=[stdout_handler]
  33. )
  34. class Spake2p:
  35. def __init__(self):
  36. pass
  37. def generate(self, args):
  38. params = self._generate_params(args)
  39. args.spake2p_verifier = Verifier(params["Verifier"])
  40. args.salt = Salt(params["Salt"])
  41. args.it = IterationCount(params["Iteration Count"])
  42. def _generate_params(self, args):
  43. cmd = [
  44. args.spake2p_path, "gen-verifier",
  45. "--iteration-count", str(args.it.val),
  46. "--salt", args.salt.encode(),
  47. "--pin-code", str(args.passcode.val),
  48. "--out", "-"
  49. ]
  50. out = subprocess.run(cmd, check=True, stdout=subprocess.PIPE).stdout
  51. out = out.decode("utf-8").splitlines()
  52. return dict(zip(out[0].split(','), out[1].split(',')))
  53. class KlvGenerator:
  54. def __init__(self, args):
  55. self.args = args
  56. self._validate_args()
  57. self.spake2p = Spake2p()
  58. if self.args.spake2p_verifier is None:
  59. self.spake2p.generate(self.args)
  60. self.args.dac_key.generate_private_key(self.args.dac_key_password)
  61. def _validate_args(self):
  62. if self.args.dac_key_password is None:
  63. logging.warning(
  64. "DAC Key password not provided. It means DAC Key is not protected."
  65. )
  66. str_args = [obj for key, obj in vars(self.args).items() if isinstance(obj, StrArgument)]
  67. for str_arg in str_args:
  68. logging.info("key: {} len: {} maxlen: {}".format(str_arg.key(), str_arg.length(), str_arg.max_length()))
  69. assert str_arg.length() <= str_arg.max_length()
  70. def generate(self):
  71. '''Return a list of (K, L, V) tuples.
  72. args is essentially a dict, so the entries are not ordered.
  73. Sort the objects to ensure the same order of KLV data every
  74. time (sorted by key), thus ensuring that SHA256 can be used
  75. correctly to compare two output binaries.
  76. The new list will contain only InputArgument objects, which
  77. generate a (K, L, V) tuple through output() method.
  78. '''
  79. data = list()
  80. data = [obj for key, obj in vars(self.args).items() if isinstance(obj, InputArgument)]
  81. data = [arg.output() for arg in sorted(data, key=lambda x: x.key())]
  82. return data
  83. def to_bin(self, klv, out, aes128_key):
  84. fullContent = bytearray()
  85. with open(out, "wb") as file:
  86. for entry in klv:
  87. fullContent += entry[0].to_bytes(1, "little")
  88. fullContent += entry[1].to_bytes(2, "little")
  89. fullContent += entry[2]
  90. size = len(fullContent)
  91. if (aes128_key is None):
  92. # Calculate 4 bytes of hashing
  93. hashing = hashlib.sha256(fullContent).hexdigest()
  94. hashing = hashing[0:8]
  95. logging.info("4 byte section hash (for integrity check): {}".format(hashing))
  96. # Add 4 bytes of hashing to generated binary to check for integrity
  97. fullContent = bytearray.fromhex(hashing) + fullContent
  98. # Add length of data to binary to know how to calculate SHA on embedded
  99. fullContent = size.to_bytes(4, "little") + fullContent
  100. # Add hash id
  101. hashId = bytearray.fromhex("CE47BA5E")
  102. hashId.reverse()
  103. fullContent = hashId + fullContent
  104. size = len(fullContent)
  105. logging.info("Size of final generated binary is: {} bytes".format(size))
  106. file.write(fullContent)
  107. else:
  108. # In case a aes128_key is given the data will be encrypted
  109. # Always add a padding to be 16 bytes aligned
  110. padding_len = size % 16
  111. padding_len = 16 - padding_len
  112. padding_bytes = bytearray(padding_len)
  113. logging.info("(Before padding) Size of generated binary is: {} bytes".format(size))
  114. fullContent += padding_bytes
  115. size = len(fullContent)
  116. logging.info("(After padding) Size of generated binary is: {} bytes".format(size))
  117. from Crypto.Cipher import AES
  118. cipher = AES.new(bytes.fromhex(aes128_key), AES.MODE_ECB)
  119. fullContentCipher = cipher.encrypt(fullContent)
  120. # Add 4 bytes of hashing to generated binary to check for integrity
  121. hashing = hashlib.sha256(fullContent).hexdigest()
  122. hashing = hashing[0:8]
  123. logging.info("4 byte section hash (for integrity check): {}".format(hashing))
  124. fullContentCipher = bytearray.fromhex(hashing) + fullContentCipher
  125. # Add length of data to binary to know how to calculate SHA on embedded
  126. fullContentCipher = size.to_bytes(4, "little") + fullContentCipher
  127. # Add hash id
  128. hashId = bytearray.fromhex("CE47BA5E")
  129. hashId.reverse()
  130. fullContentCipher = hashId.reverse() + fullContentCipher
  131. size = len(fullContentCipher)
  132. logging.info("Size of final generated binary is: {} bytes".format(size))
  133. file.write(fullContentCipher)
  134. out_hash = hashlib.sha256(fullContent).hexdigest()
  135. logging.info("SHA256 of generated binary: {}".format(out_hash))
  136. def main():
  137. set_logger()
  138. parser = argparse.ArgumentParser(description="NXP Factory Data Generator")
  139. optional = parser
  140. required = parser.add_argument_group("required arguments")
  141. required.add_argument("-i", "--it", required=True, type=IterationCount,
  142. help="[int | hex] Spake2 Iteration Counter")
  143. required.add_argument("-s", "--salt", required=True, type=Salt,
  144. help="[base64 str] Spake2 Salt")
  145. required.add_argument("-p", "--passcode", required=True, type=SetupPasscode,
  146. help="[int | hex] PASE session passcode")
  147. required.add_argument("-d", "--discriminator", required=True, type=Discriminator,
  148. help="[int | hex] BLE Pairing discriminator")
  149. required.add_argument("--vid", required=True, type=VendorId,
  150. help="[int | hex] Vendor Identifier (VID)")
  151. required.add_argument("--pid", required=True, type=ProductId,
  152. help="[int | hex] Product Identifier (PID)")
  153. required.add_argument("--vendor_name", required=True, type=VendorName,
  154. help="[str] Vendor Name")
  155. required.add_argument("--product_name", required=True, type=ProductName,
  156. help="[str] Product Name")
  157. required.add_argument("--hw_version", required=True, type=HardwareVersion,
  158. help="[int | hex] Hardware version as number")
  159. required.add_argument("--hw_version_str", required=True, type=HardwareVersionStr,
  160. help="[str] Hardware version as string")
  161. required.add_argument("--cert_declaration", required=True, type=CertDeclaration,
  162. help="[path] Path to Certification Declaration in DER format")
  163. required.add_argument("--dac_cert", required=True, type=DacCert,
  164. help="[path] Path to DAC certificate in DER format")
  165. required.add_argument("--dac_key", required=True, type=DacPKey,
  166. help="[path] Path to DAC key in DER format")
  167. required.add_argument("--pai_cert", required=True, type=PaiCert,
  168. help="[path] Path to PAI certificate in DER format")
  169. required.add_argument("--spake2p_path", required=True, type=str,
  170. help="[path] Path to spake2p tool")
  171. required.add_argument("--out", required=True, type=str,
  172. help="[path] Path to output binary")
  173. optional.add_argument("--dac_key_password", type=str,
  174. help="[path] Password to decode DAC Key if available")
  175. optional.add_argument("--spake2p_verifier", type=Verifier,
  176. help="[base64 str] Already generated spake2p verifier")
  177. optional.add_argument("--aes128_key",
  178. help="[hex] AES 128 bits key used to encrypt the whole dataset")
  179. optional.add_argument("--date", type=ManufacturingDate,
  180. help="[str] Manufacturing Date (YYYY-MM-DD)")
  181. optional.add_argument("--part_number", type=PartNumber,
  182. help="[str] PartNumber as String")
  183. optional.add_argument("--product_url", type=ProductURL,
  184. help="[str] ProductURL as String")
  185. optional.add_argument("--product_label", type=ProductLabel,
  186. help="[str] ProductLabel as String")
  187. optional.add_argument("--serial_num", type=SerialNum,
  188. help="[str] Serial Number")
  189. optional.add_argument("--unique_id", type=UniqueId,
  190. help="[str] Unique identifier for the device")
  191. args = parser.parse_args()
  192. klv = KlvGenerator(args)
  193. data = klv.generate()
  194. klv.to_bin(data, args.out, args.aes128_key)
  195. if __name__ == "__main__":
  196. main()