| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172 |
- # Copyright 2018 Espressif Systems (Shanghai) PTE LTD
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- # APIs for interpreting and creating protobuf packets for
- # protocomm endpoint with security type protocomm_security1
- from __future__ import print_function
- import proto
- import session_pb2
- import utils
- from cryptography.hazmat.backends import default_backend
- from cryptography.hazmat.primitives import hashes, serialization
- from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
- from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
- from future.utils import tobytes
- from .security import Security
- # Enum for state of protocomm_security1 FSM
- class security_state:
- REQUEST1 = 0
- RESPONSE1_REQUEST2 = 1
- RESPONSE2 = 2
- FINISHED = 3
- def xor(a, b):
- # XOR two inputs of type `bytes`
- ret = bytearray()
- # Decode the input bytes to strings
- a = a.decode('latin-1')
- b = b.decode('latin-1')
- for i in range(max(len(a), len(b))):
- # Convert the characters to corresponding 8-bit ASCII codes
- # then XOR them and store in bytearray
- ret.append(([0, ord(a[i])][i < len(a)]) ^ ([0, ord(b[i])][i < len(b)]))
- # Convert bytearray to bytes
- return bytes(ret)
- class Security1(Security):
- def __init__(self, pop, verbose):
- # Initialize state of the security1 FSM
- self.session_state = security_state.REQUEST1
- self.pop = tobytes(pop)
- self.verbose = verbose
- Security.__init__(self, self.security1_session)
- def security1_session(self, response_data):
- # protocomm security1 FSM which interprets/forms
- # protobuf packets according to present state of session
- if (self.session_state == security_state.REQUEST1):
- self.session_state = security_state.RESPONSE1_REQUEST2
- return self.setup0_request()
- if (self.session_state == security_state.RESPONSE1_REQUEST2):
- self.session_state = security_state.RESPONSE2
- self.setup0_response(response_data)
- return self.setup1_request()
- if (self.session_state == security_state.RESPONSE2):
- self.session_state = security_state.FINISHED
- self.setup1_response(response_data)
- return None
- else:
- print('Unexpected state')
- return None
- def __generate_key(self):
- # Generate private and public key pair for client
- self.client_private_key = X25519PrivateKey.generate()
- try:
- self.client_public_key = self.client_private_key.public_key().public_bytes(
- encoding=serialization.Encoding.Raw,
- format=serialization.PublicFormat.Raw)
- except TypeError:
- # backward compatible call for older cryptography library
- self.client_public_key = self.client_private_key.public_key().public_bytes()
- def _print_verbose(self, data):
- if (self.verbose):
- print('++++ ' + data + ' ++++')
- def setup0_request(self):
- # Form SessionCmd0 request packet using client public key
- setup_req = session_pb2.SessionData()
- setup_req.sec_ver = session_pb2.SecScheme1
- self.__generate_key()
- setup_req.sec1.sc0.client_pubkey = self.client_public_key
- self._print_verbose('Client Public Key:\t' + utils.str_to_hexstr(self.client_public_key.decode('latin-1')))
- return setup_req.SerializeToString().decode('latin-1')
- def setup0_response(self, response_data):
- # Interpret SessionResp0 response packet
- setup_resp = proto.session_pb2.SessionData()
- setup_resp.ParseFromString(tobytes(response_data))
- self._print_verbose('Security version:\t' + str(setup_resp.sec_ver))
- if setup_resp.sec_ver != session_pb2.SecScheme1:
- print('Incorrect sec scheme')
- exit(1)
- self.device_public_key = setup_resp.sec1.sr0.device_pubkey
- # Device random is the initialization vector
- device_random = setup_resp.sec1.sr0.device_random
- self._print_verbose('Device Public Key:\t' + utils.str_to_hexstr(self.device_public_key.decode('latin-1')))
- self._print_verbose('Device Random:\t' + utils.str_to_hexstr(device_random.decode('latin-1')))
- # Calculate Curve25519 shared key using Client private key and Device public key
- sharedK = self.client_private_key.exchange(X25519PublicKey.from_public_bytes(self.device_public_key))
- self._print_verbose('Shared Key:\t' + utils.str_to_hexstr(sharedK.decode('latin-1')))
- # If PoP is provided, XOR SHA256 of PoP with the previously
- # calculated Shared Key to form the actual Shared Key
- if len(self.pop) > 0:
- # Calculate SHA256 of PoP
- h = hashes.Hash(hashes.SHA256(), backend=default_backend())
- h.update(self.pop)
- digest = h.finalize()
- # XOR with and update Shared Key
- sharedK = xor(sharedK, digest)
- self._print_verbose('New Shared Key XORed with PoP:\t' + utils.str_to_hexstr(sharedK.decode('latin-1')))
- # Initialize the encryption engine with Shared Key and initialization vector
- cipher = Cipher(algorithms.AES(sharedK), modes.CTR(device_random), backend=default_backend())
- self.cipher = cipher.encryptor()
- def setup1_request(self):
- # Form SessionCmd1 request packet using encrypted device public key
- setup_req = proto.session_pb2.SessionData()
- setup_req.sec_ver = session_pb2.SecScheme1
- setup_req.sec1.msg = proto.sec1_pb2.Session_Command1
- # Encrypt device public key and attach to the request packet
- client_verify = self.cipher.update(self.device_public_key)
- self._print_verbose('Client Verify:\t' + utils.str_to_hexstr(client_verify.decode('latin-1')))
- setup_req.sec1.sc1.client_verify_data = client_verify
- return setup_req.SerializeToString().decode('latin-1')
- def setup1_response(self, response_data):
- # Interpret SessionResp1 response packet
- setup_resp = proto.session_pb2.SessionData()
- setup_resp.ParseFromString(tobytes(response_data))
- # Ensure security scheme matches
- if setup_resp.sec_ver == session_pb2.SecScheme1:
- # Read encrypyed device verify string
- device_verify = setup_resp.sec1.sr1.device_verify_data
- self._print_verbose('Device verify:\t' + utils.str_to_hexstr(device_verify.decode('latin-1')))
- # Decrypt the device verify string
- enc_client_pubkey = self.cipher.update(setup_resp.sec1.sr1.device_verify_data)
- self._print_verbose('Enc client pubkey:\t ' + utils.str_to_hexstr(enc_client_pubkey.decode('latin-1')))
- # Match decryped string with client public key
- if enc_client_pubkey != self.client_public_key:
- print('Mismatch in device verify')
- return -2
- else:
- print('Unsupported security protocol')
- return -1
- def encrypt_data(self, data):
- return self.cipher.update(tobytes(data))
- def decrypt_data(self, data):
- return self.cipher.update(tobytes(data))
|