Browse Source

spi_flash: support to verify written encrypted data

Also add unit test for encrypted_read
Michael (XIAO Xufeng) 6 năm trước cách đây
mục cha
commit
2660cb82ae

+ 131 - 50
components/spi_flash/flash_ops.c

@@ -41,6 +41,8 @@
 #include "esp_flash_partitions.h"
 #include "cache_utils.h"
 #include "esp_flash.h"
+#include "esp_attr.h"
+
 
 /* bytes erased by SPIEraseBlock() ROM function */
 #define BLOCK_ERASE_SIZE 65536
@@ -270,10 +272,10 @@ static IRAM_ATTR esp_rom_spiflash_result_t spi_flash_write_inner(uint32_t target
 
     uint32_t before_buf[ESP_ROM_SPIFLASH_BUFF_BYTE_READ_NUM / sizeof(uint32_t)];
     uint32_t after_buf[ESP_ROM_SPIFLASH_BUFF_BYTE_READ_NUM / sizeof(uint32_t)];
+    uint32_t *expected_buf = before_buf;
     int32_t remaining = len;
     for(int i = 0; i < len; i += sizeof(before_buf)) {
         int i_w = i / sizeof(uint32_t); // index in words (i is an index in bytes)
-
         int32_t read_len = MIN(sizeof(before_buf), remaining);
 
         // Read "before" contents from flash
@@ -282,20 +284,22 @@ static IRAM_ATTR esp_rom_spiflash_result_t spi_flash_write_inner(uint32_t target
             break;
         }
 
-#ifdef CONFIG_SPI_FLASH_WARN_SETTING_ZERO_TO_ONE
         for (int r = 0; r < read_len; r += sizeof(uint32_t)) {
             int r_w = r / sizeof(uint32_t); // index in words (r is index in bytes)
 
             uint32_t write = src_addr[i_w + r_w];
             uint32_t before = before_buf[r_w];
+            uint32_t expected = write & before;
+#ifdef CONFIG_SPI_FLASH_WARN_SETTING_ZERO_TO_ONE
             if ((before & write) != write) {
                 spi_flash_guard_end();
                 ESP_LOGW(TAG, "Write at offset 0x%x requests 0x%08x but will write 0x%08x -> 0x%08x",
                          target + i + r, write, before, before & write);
                 spi_flash_guard_start();
             }
-        }
 #endif
+            expected_buf[r_w] = expected;
+        }
 
         res = esp_rom_spiflash_write(target + i, &src_addr[i_w], read_len);
         if (res != ESP_ROM_SPIFLASH_RESULT_OK) {
@@ -310,7 +314,7 @@ static IRAM_ATTR esp_rom_spiflash_result_t spi_flash_write_inner(uint32_t target
         for (int r = 0; r < read_len; r += sizeof(uint32_t)) {
             int r_w = r / sizeof(uint32_t); // index in words (r is index in bytes)
 
-            uint32_t expected = src_addr[i_w + r_w] & before_buf[r_w];
+            uint32_t expected = expected_buf[r_w];
             uint32_t actual = after_buf[r_w];
             if (expected != actual) {
 #ifdef CONFIG_SPI_FLASH_LOG_FAILED_WRITE
@@ -427,10 +431,63 @@ out:
 }
 #endif // CONFIG_SPI_FLASH_USE_LEGACY_IMPL
 
+
+static IRAM_ATTR esp_err_t spi_flash_write_encrypted_in_rows(size_t dest_addr, const uint8_t *src, size_t size)
+{
+    assert((dest_addr % 16) == 0);
+    assert((size % 16) == 0);
+
+    /* esp_rom_spiflash_write_encrypted encrypts data in RAM as it writes,
+    so copy to a temporary buffer - 32 bytes at a time.
+
+    Each call to esp_rom_spiflash_write_encrypted takes a 32 byte "row" of
+    data to encrypt, and each row is two 16 byte AES blocks
+    that share a key (as derived from flash address).
+    */
+
+    esp_rom_spiflash_result_t rc = ESP_ROM_SPIFLASH_RESULT_OK;
+    WORD_ALIGNED_ATTR uint8_t encrypt_buf[32];
+    uint32_t row_size;
+    for (size_t i = 0; i < size; i += row_size) {
+        uint32_t row_addr = dest_addr + i;
+
+        if (i == 0 && (row_addr % 32) != 0) {
+            /* writing to second block of a 32 byte row */
+            row_size = 16;
+            row_addr -= 16;
+            /* copy to second block in buffer */
+            memcpy(encrypt_buf + 16, src + i, 16);
+            /* decrypt the first block from flash, will reencrypt to same bytes */
+            spi_flash_read_encrypted(row_addr, encrypt_buf, 16);
+        } else if (size - i == 16) {
+            /* 16 bytes left, is first block of a 32 byte row */
+            row_size = 16;
+            /* copy to first block in buffer */
+            memcpy(encrypt_buf, src + i, 16);
+            /* decrypt the second block from flash, will reencrypt to same bytes */
+            spi_flash_read_encrypted(row_addr + 16, encrypt_buf + 16, 16);
+        } else {
+            /* Writing a full 32 byte row (2 blocks) */
+            row_size = 32;
+            memcpy(encrypt_buf, src + i, 32);
+        }
+
+        spi_flash_guard_start();
+        rc = esp_rom_spiflash_write_encrypted(row_addr, (uint32_t *)encrypt_buf, 32);
+        spi_flash_guard_end();
+        if (rc != ESP_ROM_SPIFLASH_RESULT_OK) {
+            break;
+        }
+    }
+    bzero(encrypt_buf, sizeof(encrypt_buf));
+    return spi_flash_translate_rc(rc);
+}
+
+
 esp_err_t IRAM_ATTR spi_flash_write_encrypted(size_t dest_addr, const void *src, size_t size)
 {
+    esp_err_t err = ESP_OK;
     CHECK_WRITE_ADDRESS(dest_addr, size);
-    const uint8_t *ssrc = (const uint8_t *)src;
     if ((dest_addr % 16) != 0) {
         return ESP_ERR_INVALID_ARG;
     }
@@ -439,60 +496,84 @@ esp_err_t IRAM_ATTR spi_flash_write_encrypted(size_t dest_addr, const void *src,
     }
 
     COUNTER_START();
-    esp_rom_spiflash_result_t rc;
-    rc = spi_flash_unlock();
-    if (rc == ESP_ROM_SPIFLASH_RESULT_OK) {
-        /* esp_rom_spiflash_write_encrypted encrypts data in RAM as it writes,
-           so copy to a temporary buffer - 32 bytes at a time.
-
-           Each call to esp_rom_spiflash_write_encrypted takes a 32 byte "row" of
-           data to encrypt, and each row is two 16 byte AES blocks
-           that share a key (as derived from flash address).
-        */
-        uint8_t encrypt_buf[32] __attribute__((aligned(4)));
-        uint32_t row_size;
-        for (size_t i = 0; i < size; i += row_size) {
-            uint32_t row_addr = dest_addr + i;
-            if (i == 0 && (row_addr % 32) != 0) {
-                /* writing to second block of a 32 byte row */
-                row_size = 16;
-                row_addr -= 16;
-                /* copy to second block in buffer */
-                memcpy(encrypt_buf + 16, ssrc + i, 16);
-                /* decrypt the first block from flash, will reencrypt to same bytes */
-                spi_flash_read_encrypted(row_addr, encrypt_buf, 16);
-            } else if (size - i == 16) {
-                /* 16 bytes left, is first block of a 32 byte row */
-                row_size = 16;
-                /* copy to first block in buffer */
-                memcpy(encrypt_buf, ssrc + i, 16);
-                /* decrypt the second block from flash, will reencrypt to same bytes */
-                spi_flash_read_encrypted(row_addr + 16, encrypt_buf + 16, 16);
-            } else {
-                /* Writing a full 32 byte row (2 blocks) */
-                row_size = 32;
-                memcpy(encrypt_buf, ssrc + i, 32);
+    esp_rom_spiflash_result_t rc = spi_flash_unlock();
+    err = spi_flash_translate_rc(rc);
+    if (err != ESP_OK) {
+        goto fail;
+    }
+
+#ifndef CONFIG_SPI_FLASH_VERIFY_WRITE
+    err = spi_flash_write_encrypted_in_rows(dest_addr, (const uint8_t*)src, size);
+    COUNTER_ADD_BYTES(write, size);
+    spi_flash_guard_start();
+    spi_flash_check_and_flush_cache(dest_addr, size);
+    spi_flash_guard_end();
+#else
+    const uint32_t* src_w = (const uint32_t*)src;
+    uint32_t read_buf[ESP_ROM_SPIFLASH_BUFF_BYTE_READ_NUM / sizeof(uint32_t)];
+    int32_t remaining = size;
+    for(int i = 0; i < size; i += sizeof(read_buf)) {
+        int i_w = i / sizeof(uint32_t); // index in words (i is an index in bytes)
+        int32_t read_len = MIN(sizeof(read_buf), remaining);
+
+        // Read "before" contents from flash
+        esp_err_t err = spi_flash_read(dest_addr + i, read_buf, read_len);
+        if (err != ESP_OK) {
+            break;
+        }
+
+#ifdef CONFIG_SPI_FLASH_WARN_SETTING_ZERO_TO_ONE
+        //The written data cannot be predicted, so warning is shown if any of the bits is not 1.
+        for (int r = 0; r < read_len; r += sizeof(uint32_t)) {
+            uint32_t before = read_buf[r / sizeof(uint32_t)];
+            if (before != 0xFFFFFFFF) {
+                ESP_LOGW(TAG, "Encrypted write at offset 0x%x but not erased (0x%08x)",
+                         dest_addr + i + r, before);
             }
+        }
+#endif
 
-            spi_flash_guard_start();
-            rc = esp_rom_spiflash_write_encrypted(row_addr, (uint32_t *)encrypt_buf, 32);
-            spi_flash_guard_end();
-            if (rc != ESP_ROM_SPIFLASH_RESULT_OK) {
-                break;
+        err = spi_flash_write_encrypted_in_rows(dest_addr + i, src + i, read_len);
+        if (err != ESP_OK) {
+            break;
+        }
+        COUNTER_ADD_BYTES(write, size);
+
+        spi_flash_guard_start();
+        spi_flash_check_and_flush_cache(dest_addr, size);
+        spi_flash_guard_end();
+
+        err = spi_flash_read_encrypted(dest_addr + i, read_buf, read_len);
+        if (err != ESP_OK) {
+            break;
+        }
+
+        for (int r = 0; r < read_len; r += sizeof(uint32_t)) {
+            int r_w = r / sizeof(uint32_t); // index in words (r is index in bytes)
+
+            uint32_t expected = src_w[i_w + r_w];
+            uint32_t actual = read_buf[r_w];
+            if (expected != actual) {
+#ifdef CONFIG_SPI_FLASH_LOG_FAILED_WRITE
+                ESP_LOGE(TAG, "Bad write at offset 0x%x expected 0x%08x readback 0x%08x", dest_addr + i + r, expected, actual);
+#endif
+                err = ESP_FAIL;
             }
         }
-        bzero(encrypt_buf, sizeof(encrypt_buf));
+        if (err != ESP_OK) {
+            break;
+        }
+        remaining -= read_len;
     }
-    COUNTER_ADD_BYTES(write, size);
-    COUNTER_STOP(write);
+#endif // CONFIG_SPI_FLASH_VERIFY_WRITE
 
-    spi_flash_guard_start();
-    spi_flash_check_and_flush_cache(dest_addr, size);
-    spi_flash_guard_end();
+fail:
 
-    return spi_flash_translate_rc(rc);
+    COUNTER_STOP(write);
+    return err;
 }
 
+
 #ifdef CONFIG_SPI_FLASH_USE_LEGACY_IMPL
 esp_err_t IRAM_ATTR spi_flash_read(size_t src, void *dstv, size_t size)
 {

+ 68 - 0
components/spi_flash/test/test_flash_encryption.c

@@ -8,6 +8,8 @@
 #include <esp_spi_flash.h>
 #include <esp_attr.h>
 #include <esp_flash_encrypt.h>
+#include <string.h>
+
 
 #ifdef CONFIG_SECURE_FLASH_ENC_ENABLED
 
@@ -161,4 +163,70 @@ static void verify_erased_flash(size_t offset, size_t length)
     }
 }
 
+TEST_CASE("test read & write random encrypted data", "[flash_encryption][test_env=UT_T1_FlashEncryption]")
+{
+    const int MAX_LEN = 192;
+    //buffer to hold the read data
+    WORD_ALIGNED_ATTR uint8_t buffer_to_write[MAX_LEN+4];
+    //test with unaligned buffer
+    uint8_t* data_buf = &buffer_to_write[3];
+
+    setup_tests();
+
+    esp_err_t err = spi_flash_erase_sector(start / SPI_FLASH_SEC_SIZE);
+    TEST_ESP_OK(err);
+
+    //initialize the buffer to compare
+    uint8_t *cmp_buf = heap_caps_malloc(SPI_FLASH_SEC_SIZE, MALLOC_CAP_32BIT | MALLOC_CAP_8BIT | MALLOC_CAP_INTERNAL);
+    assert(((intptr_t)cmp_buf % 4) == 0);
+    err = spi_flash_read_encrypted(start, cmp_buf, SPI_FLASH_SEC_SIZE);
+    TEST_ESP_OK(err);
+
+    srand(789);
+
+    uint32_t offset = 0;
+    do {
+        //the encrypted write only works at 16-byte boundary
+        int skip = (rand() % 4) * 16;
+        int len = ((rand() % (MAX_LEN/16)) + 1) * 16;
+
+        for (int i = 0; i < MAX_LEN; i++) {
+            data_buf[i] = rand();
+        }
+
+        offset += skip;
+        if (offset + len > SPI_FLASH_SEC_SIZE) {
+            if (offset > SPI_FLASH_SEC_SIZE) {
+                break;
+            }
+            len = SPI_FLASH_SEC_SIZE - offset;
+        }
+
+        printf("write %d bytes to 0x%08x...\n", len, start + offset);
+        err = spi_flash_write_encrypted(start + offset, data_buf, len);
+        TEST_ESP_OK(err);
+
+        memcpy(cmp_buf + offset, data_buf, len);
+        offset += len;
+    } while (offset < SPI_FLASH_SEC_SIZE);
+
+    offset = 0;
+    do {
+        int len = ((rand() % (MAX_LEN/16)) + 1) * 16;
+        if (offset + len > SPI_FLASH_SEC_SIZE) {
+            len = SPI_FLASH_SEC_SIZE - offset;
+        }
+
+        err = spi_flash_read_encrypted(start + offset, data_buf, len);
+        TEST_ESP_OK(err);
+
+        printf("compare %d bytes at 0x%08x...\n", len, start + offset);
+
+        TEST_ASSERT_EQUAL_HEX8_ARRAY(cmp_buf + offset, data_buf, len);
+        offset += len;
+    } while (offset < SPI_FLASH_SEC_SIZE);
+
+    free(cmp_buf);
+}
+
 #endif // CONFIG_SECURE_FLASH_ENC_ENABLED