|
|
@@ -281,31 +281,37 @@ int hmac_sha1(const u8 *key, size_t key_len, const u8 *data, size_t data_len,
|
|
|
return hmac_sha1_vector(key, key_len, 1, &data, &data_len, mac);
|
|
|
}
|
|
|
|
|
|
-void *aes_crypt_init(const u8 *key, size_t len)
|
|
|
+static void *aes_crypt_init(int mode, const u8 *key, size_t len)
|
|
|
{
|
|
|
+ int ret = -1;
|
|
|
mbedtls_aes_context *aes = os_malloc(sizeof(*aes));
|
|
|
if (!aes) {
|
|
|
return NULL;
|
|
|
}
|
|
|
mbedtls_aes_init(aes);
|
|
|
|
|
|
- if (mbedtls_aes_setkey_enc(aes, key, len * 8) < 0) {
|
|
|
+ if (mode == MBEDTLS_AES_ENCRYPT) {
|
|
|
+ ret = mbedtls_aes_setkey_enc(aes, key, len * 8);
|
|
|
+ } else if (mode == MBEDTLS_AES_DECRYPT){
|
|
|
+ ret = mbedtls_aes_setkey_dec(aes, key, len * 8);
|
|
|
+ }
|
|
|
+ if (ret < 0) {
|
|
|
mbedtls_aes_free(aes);
|
|
|
os_free(aes);
|
|
|
- wpa_printf(MSG_ERROR, "%s: mbedtls_aes_setkey_enc failed", __func__);
|
|
|
+ wpa_printf(MSG_ERROR, "%s: mbedtls_aes_setkey_enc/mbedtls_aes_setkey_dec failed", __func__);
|
|
|
return NULL;
|
|
|
}
|
|
|
|
|
|
return (void *) aes;
|
|
|
}
|
|
|
|
|
|
-int aes_crypt(void *ctx, int mode, const u8 *in, u8 *out)
|
|
|
+static int aes_crypt(void *ctx, int mode, const u8 *in, u8 *out)
|
|
|
{
|
|
|
return mbedtls_aes_crypt_ecb((mbedtls_aes_context *)ctx,
|
|
|
mode, in, out);
|
|
|
}
|
|
|
|
|
|
-void aes_crypt_deinit(void *ctx)
|
|
|
+static void aes_crypt_deinit(void *ctx)
|
|
|
{
|
|
|
mbedtls_aes_free((mbedtls_aes_context *)ctx);
|
|
|
os_free(ctx);
|
|
|
@@ -313,7 +319,7 @@ void aes_crypt_deinit(void *ctx)
|
|
|
|
|
|
void *aes_encrypt_init(const u8 *key, size_t len)
|
|
|
{
|
|
|
- return aes_crypt_init(key, len);
|
|
|
+ return aes_crypt_init(MBEDTLS_AES_ENCRYPT, key, len);
|
|
|
}
|
|
|
|
|
|
int aes_encrypt(void *ctx, const u8 *plain, u8 *crypt)
|
|
|
@@ -328,7 +334,7 @@ void aes_encrypt_deinit(void *ctx)
|
|
|
|
|
|
void * aes_decrypt_init(const u8 *key, size_t len)
|
|
|
{
|
|
|
- return aes_crypt_init(key, len);
|
|
|
+ return aes_crypt_init(MBEDTLS_AES_DECRYPT, key, len);
|
|
|
}
|
|
|
|
|
|
int aes_decrypt(void *ctx, const u8 *crypt, u8 *plain)
|