فهرست منبع

Merge branch 'feature/mbedtls_add_faster_modexp' into 'master'

mbedtls: Add a new (X^Y) mod M implementation (HAC 14.94)

Closes IDF-965

See merge request espressif/esp-idf!6418
Angus Gratton 6 سال پیش
والد
کامیت
5b33d6cf94

+ 4 - 0
components/idf_test/include/idf_performance.h

@@ -33,3 +33,7 @@
 // SHA256 hardware throughput at 240MHz, threshold set lower than worst case
 #define IDF_PERFORMANCE_MIN_SHA256_THROUGHPUT_MBSEC                             9.0
 
+#define IDF_PERFORMANCE_MAX_RSA_2048KEY_PUBLIC_OP  19000
+#define IDF_PERFORMANCE_MAX_RSA_2048KEY_PRIVATE_OP 180000
+#define IDF_PERFORMANCE_MAX_RSA_4096KEY_PUBLIC_OP  65000
+#define IDF_PERFORMANCE_MAX_RSA_4096KEY_PRIVATE_OP 850000

+ 1 - 11
components/mbedtls/Kconfig

@@ -142,7 +142,7 @@ menu "mbedTLS"
 
     config MBEDTLS_HARDWARE_MPI
         bool "Enable hardware MPI (bignum) acceleration"
-        default n
+        default y
         help
             Enable hardware accelerated multiple precision integer operations.
 
@@ -151,16 +151,6 @@ menu "mbedTLS"
 
             These operations are used by RSA.
 
-    config MBEDTLS_MPI_USE_INTERRUPT
-        bool "Use interrupt for MPI operations"
-        depends on MBEDTLS_HARDWARE_MPI
-        default n
-        help
-            Use an interrupt to coordinate MPI operations.
-
-            This allows other code to run on the CPU while an MPI operation is pending.
-            Otherwise the CPU busy-waits.
-
     config MBEDTLS_HARDWARE_SHA
         bool "Enable hardware SHA acceleration"
         default y

+ 119 - 61
components/mbedtls/port/esp32/esp_bignum.c

@@ -60,29 +60,6 @@ static const __attribute__((unused)) char *TAG = "bignum";
 #define ciL    (sizeof(mbedtls_mpi_uint))         /* chars in limb  */
 #define biL    (ciL << 3)                         /* bits  in limb  */
 
-#if defined(CONFIG_MBEDTLS_MPI_USE_INTERRUPT)
-static SemaphoreHandle_t op_complete_sem;
-
-static IRAM_ATTR void rsa_complete_isr(void *arg)
-{
-    BaseType_t higher_woken;
-    DPORT_REG_WRITE(RSA_INTERRUPT_REG, 1);
-    xSemaphoreGiveFromISR(op_complete_sem, &higher_woken);
-    if (higher_woken) {
-        portYIELD_FROM_ISR();
-    }
-}
-
-static void rsa_isr_initialise(void)
-{
-    if (op_complete_sem == NULL) {
-        op_complete_sem = xSemaphoreCreateBinary();
-        esp_intr_alloc(ETS_RSA_INTR_SOURCE, 0, rsa_complete_isr, NULL, NULL);
-    }
-}
-
-#endif /* CONFIG_MBEDTLS_MPI_USE_INTERRUPT */
-
 static _lock_t mpi_lock;
 
 void esp_mpi_acquire_hardware( void )
@@ -96,10 +73,6 @@ void esp_mpi_acquire_hardware( void )
 
     while(DPORT_REG_READ(RSA_CLEAN_REG) != 1);
     // Note: from enabling RSA clock to here takes about 1.3us
-
-#ifdef CONFIG_MBEDTLS_MPI_USE_INTERRUPT
-    rsa_isr_initialise();
-#endif
 }
 
 void esp_mpi_release_hardware( void )
@@ -264,20 +237,11 @@ static inline void start_op(uint32_t op_reg)
 */
 static inline void wait_op_complete(uint32_t op_reg)
 {
-#ifdef CONFIG_MBEDTLS_MPI_USE_INTERRUPT
-    if (!xSemaphoreTake(op_complete_sem, 2000 / portTICK_PERIOD_MS)) {
-        ESP_LOGE(TAG, "Timed out waiting for RSA operation (op_reg 0x%x int_reg 0x%x)",
-                 op_reg, DPORT_REG_READ(RSA_INTERRUPT_REG));
-        abort(); /* indicates a fundamental problem with driver */
-    }
-#else
     while(DPORT_REG_READ(RSA_INTERRUPT_REG) != 1)
        { }
 
     /* clear the interrupt */
     DPORT_REG_WRITE(RSA_INTERRUPT_REG, 1);
-#endif
-
 }
 
 /* Sub-stages of modulo multiplication/exponentiation operations */
@@ -335,8 +299,124 @@ int esp_mpi_mul_mpi_mod(mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi
 
 #if defined(MBEDTLS_MPI_EXP_MOD_ALT)
 
+static int mont(mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi* Y, const mbedtls_mpi* M,
+                mbedtls_mpi_uint Mprime,
+                size_t hw_words,
+                bool again)
+{
+    /* One hardware Montgomery multiplication: X and Y are loaded into the RSA
+     * peripheral, RSA_MULT_START_REG is started, and the product is read back
+     * into Z, with M subtracted once if the raw result is >= M.
+     *
+     * M, Mprime and the mode register are written only when `again` is false;
+     * with `again` true whatever an earlier call loaded is reused, so the
+     * caller is expected to keep M, Mprime and hw_words unchanged between calls.
+     *
+     * Returns 0 on success, otherwise the error code from mbedtls_mpi_grow()
+     * or mbedtls_mpi_sub_mpi() (MBEDTLS_MPI_CHK jumps to `cleanup` on failure).
+     *
+     * Note Z may be the same pointer as X or Y: both operands are copied to
+     * the peripheral before Z is written. */
+    int ret = 0;
+
+    // One-time setup: load the modulus M, M' and the operand-length mode (hw_words / 16 - 1)
+    if (again == false) {
+        mpi_to_mem_block(RSA_MEM_M_BLOCK_BASE, M, hw_words);
+        DPORT_REG_WRITE(RSA_M_DASH_REG, Mprime);
+        DPORT_REG_WRITE(RSA_MULT_MODE_REG, hw_words / 16 - 1);
+    }
+
+    mpi_to_mem_block(RSA_MEM_X_BLOCK_BASE, X, hw_words);
+    mpi_to_mem_block(RSA_MEM_RB_BLOCK_BASE, Y, hw_words);
+
+    start_op(RSA_MULT_START_REG);
+
+    MBEDTLS_MPI_CHK( mbedtls_mpi_grow(Z, hw_words) ); // done while the operation runs; NOTE(review): on failure we return without waiting for or clearing the started operation
+
+    wait_op_complete(RSA_MULT_START_REG);
+
+    /* Read back the result (hw_words words from the Z memory block) */
+    mem_block_to_mpi(Z, RSA_MEM_Z_BLOCK_BASE, hw_words);
+
+    /* Final conditional subtraction, from HAC 14.36 step 3: if Z >= M then Z = Z - M */
+    if (mbedtls_mpi_cmp_mpi(Z, M) >= 0) {
+        MBEDTLS_MPI_CHK(mbedtls_mpi_sub_mpi(Z, Z, M));
+    }
+ cleanup:
+    return ret;
+}
+
 /*
- * Sliding-window exponentiation: Z = X^Y mod M  (HAC 14.85)
+ * Return the zero-based index of the most significant set bit of X (0 if X is NULL, empty or zero).
+ */
+static size_t mbedtls_mpi_msb( const mbedtls_mpi* X )
+{
+    /* Returns the zero-based index of the highest set bit of X, scanning the
+     * limbs from most to least significant. A NULL pointer, a zero-length MPI
+     * and an all-zero value all yield 0, so the result cannot distinguish
+     * X == 0 from X == 1. */
+    if (X == NULL || X->n == 0) {
+        return 0;
+    }
+
+    for (size_t i = X->n; i-- > 0; ) {
+        if (X->p[i] == 0) {
+            continue;
+        }
+        for (size_t j = biL; j-- > 0; ) {
+            /* The mask must have the limb type: a plain int `1 << 31` is
+             * undefined behaviour and could never reach bits above 31 if
+             * limbs are wider than int. */
+            if ((X->p[i] & ((mbedtls_mpi_uint)1 << j)) != 0) {
+                return (i * biL) + j;
+            }
+        }
+    }
+    return 0;
+}
+
+/*
+ * Montgomery exponentiation: Z = X ^ Y mod M  (HAC 14.94)
+ *
+ * Implemented as a left-to-right square-and-multiply over the bits of Y, with
+ * every multiplication performed by mont() on the RSA peripheral.
+ *
+ * Z        receives the result.
+ * X, Y, M  base, exponent and modulus.
+ * Rinv     passed to mont() as the multiplier for the domain conversion; per
+ *          the step comments below it is expected to hold R^2 mod M.
+ * hw_words operand length in words handed to mont(); also used to size the
+ *          local constant `one`.
+ * Mprime   forwarded unchanged to mont().
+ *
+ * Returns 0 on success or the first non-zero mbedtls error code encountered.
+ *
+ * The hardware is acquired only after the constant 1 has been built and is
+ * released on the `cleanup` path; `cleanup2` is the early exit taken before
+ * the hardware was acquired, so it must not release it.
+ */
+static int mpi_montgomery_exp_calc( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi* Y, const mbedtls_mpi* M,
+                            mbedtls_mpi* Rinv,
+                            size_t hw_words,
+                            mbedtls_mpi_uint Mprime )
+{
+    int ret = 0;
+    mbedtls_mpi X_, one;
+
+    mbedtls_mpi_init(&X_);
+    mbedtls_mpi_init(&one);
+    if( ( ( ret = mbedtls_mpi_grow(&one, hw_words) ) != 0 ) ||
+        ( ( ret = mbedtls_mpi_set_bit(&one, 0, 1) )  != 0 ) ) {
+        goto cleanup2;
+    }
+
+    // Algorithm from HAC 14.94 (step numbers below follow that algorithm)
+    {
+        // 0 determine t (index of the highest bit set in y)
+        int t = mbedtls_mpi_msb(Y);
+
+        esp_mpi_acquire_hardware();
+
+        // 1.1 x_ = mont(x, R^2 mod m)
+        //        = mont(x, rb); this first call (again == false) also loads M and Mprime
+        MBEDTLS_MPI_CHK( mont(&X_, X, Rinv, M, Mprime, hw_words, false) );
+
+        // 1.2 z = R mod m
+        // now z = R mod m = Mont (R^2 mod m, 1) mod M (as Mont(x, y) = x * y * R^-1 mod M)
+        MBEDTLS_MPI_CHK( mont(Z, Rinv, &one, M, Mprime, hw_words, true) );
+
+        // 2 for i from t down to 0
+        for (int i = t; i >= 0; i--) {
+            // 2.1 z = mont(z,z)
+            if (i != t) { // skip on the first iteration: z is still R mod m (Montgomery form of 1), so squaring would not change it
+                MBEDTLS_MPI_CHK( mont(Z, Z, Z, M, Mprime, hw_words, true) );
+            }
+
+            // 2.2 if y[i] = 1 then z = mont(z, x_)
+            if (mbedtls_mpi_get_bit(Y, i)) {
+                MBEDTLS_MPI_CHK( mont(Z, Z, &X_, M, Mprime, hw_words, true) );
+            }
+        }
+
+        // 3 z = Mont(z, 1), which converts z back out of the Montgomery domain
+        MBEDTLS_MPI_CHK( mont(Z, Z, &one, M, Mprime, hw_words, true) );
+    }
+
+ cleanup:
+    mbedtls_mpi_free(&X_);
+    mbedtls_mpi_free(&one);
+    esp_mpi_release_hardware();
+    return ret;
+
+ cleanup2:
+    mbedtls_mpi_free(&one);
+    return ret;
+}
+
+/*
+ * Z = X ^ Y mod M
  *
  * _Rinv is optional pre-calculated version of Rinv (via calculate_rinv()).
  *
@@ -389,30 +469,8 @@ int mbedtls_mpi_exp_mod( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi
 
     Mprime = modular_inverse(M);
 
-    esp_mpi_acquire_hardware();
-
-    /* "mode" register loaded with number of 512-bit blocks, minus 1 */
-    DPORT_REG_WRITE(RSA_MODEXP_MODE_REG, (hw_words / 16) - 1);
-
-    /* Load M, X, Rinv, M-prime (M-prime is mod 2^32) */
-    mpi_to_mem_block(RSA_MEM_X_BLOCK_BASE, X, hw_words);
-    mpi_to_mem_block(RSA_MEM_Y_BLOCK_BASE, Y, hw_words);
-    mpi_to_mem_block(RSA_MEM_M_BLOCK_BASE, M, hw_words);
-    mpi_to_mem_block(RSA_MEM_RB_BLOCK_BASE, Rinv, hw_words);
-    DPORT_REG_WRITE(RSA_M_DASH_REG, Mprime);
-
-    start_op(RSA_START_MODEXP_REG);
-
-    /* X ^ Y may actually be shorter than M, but unlikely when used for crypto */
-    if ((ret = mbedtls_mpi_grow(Z, m_words)) != 0) {
-        esp_mpi_release_hardware();
-        goto cleanup;
-    }
-
-    wait_op_complete(RSA_START_MODEXP_REG);
-
-    mem_block_to_mpi(Z, RSA_MEM_Z_BLOCK_BASE, m_words);
-    esp_mpi_release_hardware();
+    // Montgomery exponentiation: Z = X ^ Y mod M  (HAC 14.94)
+    MBEDTLS_MPI_CHK( mpi_montgomery_exp_calc(Z, X, Y, M, Rinv, hw_words, Mprime) );
 
     // Compensate for negative X
     if (X->s == -1 && (Y->p[0] & 1) != 0) {

+ 55 - 0
components/mbedtls/test/test_rsa.c

@@ -11,11 +11,13 @@
 #include "mbedtls/rsa.h"
 #include "mbedtls/pk.h"
 #include "mbedtls/x509_crt.h"
+#include "mbedtls/entropy_poll.h"
 #include "freertos/FreeRTOS.h"
 #include "freertos/task.h"
 #include "freertos/semphr.h"
 #include "unity.h"
 #include "sdkconfig.h"
+#include "test_utils.h"
 
 /* Taken from openssl s_client -connect api.gigafive.com:443 -showcerts
  */
@@ -238,3 +240,56 @@ static void test_cert(const char *cert, const uint8_t *expected_output, size_t o
 
     mbedtls_x509_crt_free(&crt);
 }
+
+static int myrand(void *rng_state, unsigned char *output, size_t len)
+{
+    /* RNG callback handed to mbedtls_rsa_gen_key(): forwards the request to
+     * mbedtls_hardware_poll() and returns its status. The byte count the poll
+     * reports is collected only because the API requires an out-parameter. */
+    size_t produced;
+    return mbedtls_hardware_poll(rng_state, output, len, &produced);
+}
+
+#ifdef CONFIG_MBEDTLS_HARDWARE_MPI
+
+TEST_CASE("test performance RSA key operations", "[bignum][ignore]")
+{
+    mbedtls_rsa_context rsa;
+    unsigned char orig_buf[4096 / 8]; // buffers are sized for the largest key tested (4096 bits)
+    unsigned char encrypted_buf[4096 / 8];
+    unsigned char decrypted_buf[4096 / 8];
+    int64_t start;
+    int public_perf, private_perf; // elapsed time of each operation, in microseconds
+
+    printf("First, orig_buf is encrypted by the public key, and then decrypted by the private key\n");
+
+    for (int keysize = 2048; keysize <= 4096; keysize += 2048) { // runs twice: 2048-bit, then 4096-bit
+        memset(orig_buf, 0xAA, sizeof(orig_buf));
+        orig_buf[0] = 0; // Ensure that orig_buf is smaller than rsa.N (presumably byte 0 is the most significant - confirm against mbedtls_rsa_public)
+
+        mbedtls_rsa_init(&rsa, MBEDTLS_RSA_PRIVATE, 0);
+        TEST_ASSERT_EQUAL(0, mbedtls_rsa_gen_key(&rsa, myrand, NULL, keysize, 65537));
+
+        TEST_ASSERT_EQUAL(keysize, (int)rsa.len * 8);
+        TEST_ASSERT_EQUAL(keysize, (int)rsa.D.n * sizeof(mbedtls_mpi_uint) * 8); // The private exponent D must occupy exactly keysize bits worth of limbs
+
+        start = esp_timer_get_time();
+        TEST_ASSERT_EQUAL(0, mbedtls_rsa_public(&rsa, orig_buf, encrypted_buf));
+        public_perf = esp_timer_get_time() - start;
+
+        start = esp_timer_get_time();
+        TEST_ASSERT_EQUAL(0, mbedtls_rsa_private(&rsa, NULL, NULL, encrypted_buf, decrypted_buf));
+        private_perf = esp_timer_get_time() - start;
+
+        if (keysize == 2048) {
+            TEST_PERFORMANCE_LESS_THAN(RSA_2048KEY_PUBLIC_OP, "public operations %d us", public_perf);
+            TEST_PERFORMANCE_LESS_THAN(RSA_2048KEY_PRIVATE_OP, "private operations %d us", private_perf);
+        } else {
+            TEST_PERFORMANCE_LESS_THAN(RSA_4096KEY_PUBLIC_OP, "public operations %d us", public_perf);
+            TEST_PERFORMANCE_LESS_THAN(RSA_4096KEY_PRIVATE_OP, "private operations %d us", private_perf);
+        }
+
+        TEST_ASSERT_EQUAL_MEMORY_MESSAGE(orig_buf, decrypted_buf, keysize / 8, "RSA operation"); // round trip must reproduce the first keysize / 8 bytes
+
+        mbedtls_rsa_free(&rsa);
+    }
+}
+
+#endif // CONFIG_MBEDTLS_HARDWARE_MPI

+ 0 - 1
tools/ldgen/samples/sdkconfig

@@ -404,7 +404,6 @@ CONFIG_MBEDTLS_SSL_MAX_CONTENT_LEN=16384
 CONFIG_MBEDTLS_DEBUG=
 CONFIG_MBEDTLS_HARDWARE_AES=y
 CONFIG_MBEDTLS_HARDWARE_MPI=y
-CONFIG_MBEDTLS_MPI_USE_INTERRUPT=y
 CONFIG_MBEDTLS_HARDWARE_SHA=
 CONFIG_MBEDTLS_HAVE_TIME=y
 CONFIG_MBEDTLS_HAVE_TIME_DATE=

+ 0 - 2
tools/unit-test-app/sdkconfig.defaults

@@ -11,8 +11,6 @@ CONFIG_FREERTOS_WATCHPOINT_END_OF_STACK=y
 CONFIG_FREERTOS_THREAD_LOCAL_STORAGE_POINTERS=3
 CONFIG_FREERTOS_USE_TRACE_FACILITY=y
 CONFIG_HEAP_POISONING_COMPREHENSIVE=y
-CONFIG_MBEDTLS_HARDWARE_MPI=y
-CONFIG_MBEDTLS_MPI_USE_INTERRUPT=y
 CONFIG_MBEDTLS_HARDWARE_SHA=y
 CONFIG_SPI_FLASH_ENABLE_COUNTERS=y
 CONFIG_ESP_TASK_WDT=n