security1.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # Copyright 2018 Espressif Systems (Shanghai) PTE LTD
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # APIs for interpreting and creating protobuf packets for
  16. # protocomm endpoint with security type protocomm_security1
  17. from __future__ import print_function
  18. from future.utils import tobytes
  19. import utils
  20. import proto
  21. from .security import Security
  22. from cryptography.hazmat.backends import default_backend
  23. from cryptography.hazmat.primitives import hashes, serialization
  24. from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
  25. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  26. import session_pb2
  27. # Enum for state of protocomm_security1 FSM
  28. class security_state:
  29. REQUEST1 = 0
  30. RESPONSE1_REQUEST2 = 1
  31. RESPONSE2 = 2
  32. FINISHED = 3
  33. def xor(a, b):
  34. # XOR two inputs of type `bytes`
  35. ret = bytearray()
  36. # Decode the input bytes to strings
  37. a = a.decode('latin-1')
  38. b = b.decode('latin-1')
  39. for i in range(max(len(a), len(b))):
  40. # Convert the characters to corresponding 8-bit ASCII codes
  41. # then XOR them and store in bytearray
  42. ret.append(([0, ord(a[i])][i < len(a)]) ^ ([0, ord(b[i])][i < len(b)]))
  43. # Convert bytearray to bytes
  44. return bytes(ret)
  45. class Security1(Security):
  46. def __init__(self, pop, verbose):
  47. # Initialize state of the security1 FSM
  48. self.session_state = security_state.REQUEST1
  49. self.pop = tobytes(pop)
  50. self.verbose = verbose
  51. Security.__init__(self, self.security1_session)
  52. def security1_session(self, response_data):
  53. # protocomm security1 FSM which interprets/forms
  54. # protobuf packets according to present state of session
  55. if (self.session_state == security_state.REQUEST1):
  56. self.session_state = security_state.RESPONSE1_REQUEST2
  57. return self.setup0_request()
  58. if (self.session_state == security_state.RESPONSE1_REQUEST2):
  59. self.session_state = security_state.RESPONSE2
  60. self.setup0_response(response_data)
  61. return self.setup1_request()
  62. if (self.session_state == security_state.RESPONSE2):
  63. self.session_state = security_state.FINISHED
  64. self.setup1_response(response_data)
  65. return None
  66. else:
  67. print("Unexpected state")
  68. return None
  69. def __generate_key(self):
  70. # Generate private and public key pair for client
  71. self.client_private_key = X25519PrivateKey.generate()
  72. try:
  73. self.client_public_key = self.client_private_key.public_key().public_bytes(
  74. encoding=serialization.Encoding.Raw,
  75. format=serialization.PublicFormat.Raw)
  76. except TypeError:
  77. # backward compatible call for older cryptography library
  78. self.client_public_key = self.client_private_key.public_key().public_bytes()
  79. def _print_verbose(self, data):
  80. if (self.verbose):
  81. print("++++ " + data + " ++++")
  82. def setup0_request(self):
  83. # Form SessionCmd0 request packet using client public key
  84. setup_req = session_pb2.SessionData()
  85. setup_req.sec_ver = session_pb2.SecScheme1
  86. self.__generate_key()
  87. setup_req.sec1.sc0.client_pubkey = self.client_public_key
  88. self._print_verbose("Client Public Key:\t" + utils.str_to_hexstr(self.client_public_key.decode('latin-1')))
  89. return setup_req.SerializeToString().decode('latin-1')
  90. def setup0_response(self, response_data):
  91. # Interpret SessionResp0 response packet
  92. setup_resp = proto.session_pb2.SessionData()
  93. setup_resp.ParseFromString(tobytes(response_data))
  94. self._print_verbose("Security version:\t" + str(setup_resp.sec_ver))
  95. if setup_resp.sec_ver != session_pb2.SecScheme1:
  96. print("Incorrect sec scheme")
  97. exit(1)
  98. self.device_public_key = setup_resp.sec1.sr0.device_pubkey
  99. # Device random is the initialization vector
  100. device_random = setup_resp.sec1.sr0.device_random
  101. self._print_verbose("Device Public Key:\t" + utils.str_to_hexstr(self.device_public_key.decode('latin-1')))
  102. self._print_verbose("Device Random:\t" + utils.str_to_hexstr(device_random.decode('latin-1')))
  103. # Calculate Curve25519 shared key using Client private key and Device public key
  104. sharedK = self.client_private_key.exchange(X25519PublicKey.from_public_bytes(self.device_public_key))
  105. self._print_verbose("Shared Key:\t" + utils.str_to_hexstr(sharedK.decode('latin-1')))
  106. # If PoP is provided, XOR SHA256 of PoP with the previously
  107. # calculated Shared Key to form the actual Shared Key
  108. if len(self.pop) > 0:
  109. # Calculate SHA256 of PoP
  110. h = hashes.Hash(hashes.SHA256(), backend=default_backend())
  111. h.update(self.pop)
  112. digest = h.finalize()
  113. # XOR with and update Shared Key
  114. sharedK = xor(sharedK, digest)
  115. self._print_verbose("New Shared Key XORed with PoP:\t" + utils.str_to_hexstr(sharedK.decode('latin-1')))
  116. # Initialize the encryption engine with Shared Key and initialization vector
  117. cipher = Cipher(algorithms.AES(sharedK), modes.CTR(device_random), backend=default_backend())
  118. self.cipher = cipher.encryptor()
  119. def setup1_request(self):
  120. # Form SessionCmd1 request packet using encrypted device public key
  121. setup_req = proto.session_pb2.SessionData()
  122. setup_req.sec_ver = session_pb2.SecScheme1
  123. setup_req.sec1.msg = proto.sec1_pb2.Session_Command1
  124. # Encrypt device public key and attach to the request packet
  125. client_verify = self.cipher.update(self.device_public_key)
  126. self._print_verbose("Client Verify:\t" + utils.str_to_hexstr(client_verify.decode('latin-1')))
  127. setup_req.sec1.sc1.client_verify_data = client_verify
  128. return setup_req.SerializeToString().decode('latin-1')
  129. def setup1_response(self, response_data):
  130. # Interpret SessionResp1 response packet
  131. setup_resp = proto.session_pb2.SessionData()
  132. setup_resp.ParseFromString(tobytes(response_data))
  133. # Ensure security scheme matches
  134. if setup_resp.sec_ver == session_pb2.SecScheme1:
  135. # Read encrypyed device verify string
  136. device_verify = setup_resp.sec1.sr1.device_verify_data
  137. self._print_verbose("Device verify:\t" + utils.str_to_hexstr(device_verify.decode('latin-1')))
  138. # Decrypt the device verify string
  139. enc_client_pubkey = self.cipher.update(setup_resp.sec1.sr1.device_verify_data)
  140. self._print_verbose("Enc client pubkey:\t " + utils.str_to_hexstr(enc_client_pubkey.decode('latin-1')))
  141. # Match decryped string with client public key
  142. if enc_client_pubkey != self.client_public_key:
  143. print("Mismatch in device verify")
  144. return -2
  145. else:
  146. print("Unsupported security protocol")
  147. return -1
  148. def encrypt_data(self, data):
  149. return self.cipher.update(tobytes(data))
  150. def decrypt_data(self, data):
  151. return self.cipher.update(tobytes(data))