Sfoglia il codice sorgente

CMSIS-NN: Implement bit-accurate S8 Softmax (non-DSP)

Giorgio Arena 6 anni fa
parent
commit
95ad2aa1a7

+ 2 - 2
CMSIS/DSP/Include/arm_math.h

@@ -382,10 +382,10 @@ extern "C"
 #include <float.h>
 #include <limits.h>
 
-#define Q31_MAX   LONG_MAX
+#define Q31_MAX   (0x7FFFFFFFL)
 #define Q15_MAX   SHRT_MAX
 #define Q7_MAX    SCHAR_MAX
-#define Q31_MIN   LONG_MIN
+#define Q31_MIN   (0x80000000L)
 #define Q15_MIN   SHRT_MIN
 #define Q7_MIN    SCHAR_MIN
 

+ 32 - 0
CMSIS/NN/Include/arm_nnfunctions.h

@@ -1614,6 +1614,9 @@ extern    "C"
    * @param[in]       dim_vec     input vector dimension
    * @param[out]      p_out       pointer to output vector
    *
+   * @note This function is an optimized version which is not bit-accurate with
+   *       TensorFlow Lite's kernel
+   *
    */
 
 void arm_softmax_q7(const q7_t * vec_in, const uint16_t dim_vec, q7_t * p_out);
@@ -1626,6 +1629,9 @@ void arm_softmax_q7(const q7_t * vec_in, const uint16_t dim_vec, q7_t * p_out);
    * @param[out]      p_out       pointer to output vector
    * @return none.
    *
+   * @note This function is an optimized version which is not bit-accurate with
+   *       TensorFlow Lite's kernel
+   *
    */
 
 void arm_softmax_with_batch_q7(const q7_t * vec_in, const uint16_t nb_batches,const uint16_t dim_vec, q7_t * p_out );
@@ -1636,10 +1642,36 @@ void arm_softmax_with_batch_q7(const q7_t * vec_in, const uint16_t nb_batches,co
    * @param[out]      p_out       pointer to output vector
    * @return none.
    *
+   * @note This function is an optimized version which is not bit-accurate with
+   *       TensorFlow Lite's kernel
+   *
    */
 
 void arm_softmax_q15(const q15_t * vec_in, const uint16_t dim_vec, q15_t * p_out);
 
+  /**
+   * @brief S8 softmax function
+   * @param[in]  input     Pointer to the input tensor
+   * @param[in]  num_rows  Number of rows in the input tensor
+   * @param[in]  row_size  Number of elements in each input row
+   * @param[in]  mult      Input quantization multiplier
+   * @param[in]  shift     Input quantization shift within the range [0, 31]
+   * @param[in]  diff_min  Minimum difference with max in row. Used to check if
+   *                       the quantized exponential operation can be performed
+   * @param[out] output    Pointer to the output tensor
+   *
+   * @note Supported framework: TensorFlow Lite micro (bit-accurate)
+   *
+   */
+
+void arm_softmax_s8(const int8_t *input,
+                    const int32_t num_rows,
+                    const int32_t row_size,
+                    const int32_t mult,
+                    const int32_t shift,
+                    const int8_t diff_min,
+                    int8_t *output);
+
   /**
    * @brief uint8 depthwise convolution function with asymmetric quantization for even number of channel multiplier
    *        and input channels. Unless specified otherwise, arguments are mandatory.

+ 69 - 0
CMSIS/NN/Include/arm_nnsupportfunctions.h

@@ -40,9 +40,13 @@ extern    "C"
 
 #define LEFT_SHIFT(_shift)  (_shift > 0 ? _shift : 0)
 #define RIGHT_SHIFT(_shift) (_shift > 0 ? 0 : -_shift)
+#define MASK_IF_ZERO(x)     (x) == 0 ? ~0 : 0
+#define MASK_IF_NON_ZERO(x) (x) != 0 ? ~0 : 0
+#define SELECT_USING_MASK(mask, a, b) ((mask) & (a)) ^ (~(mask) & (b))
 
 #define MAX(A,B) ((A) > (B) ? (A) : (B))
 #define MIN(A,B) ((A) < (B) ? (A) : (B))
+#define CLAMP(x, h, l) MAX(MIN((x), (h)), (l))
 
 /**
  * @brief Union for SIMD access of q31/q15/q7 types
@@ -343,6 +347,13 @@ void arm_nn_mult_q7(
     #define NN_ROUND(out_shift) 0
 #endif
 
+// Macros for shortening quantization functions' names and avoid long lines
+#define MUL_SAT(a, b)  arm_nn_sat_doubling_high_mult((a), (b))
+#define MUL_POW2(a, b) arm_nn_mult_by_power_of_two((a), (b))
+#define DIV_POW2(a, b) arm_nn_divide_by_power_of_two((a), (b))
+#define EXP_ON_NEG(x)  arm_nn_exp_on_negative_values((x))
+#define ONE_OVER1(x)   arm_nn_one_over_one_plus_x_for_x_in_0_1((x))
+
 /**
  * @brief           Saturating doubling high multiply. Result matches
  *                  NEON instruction VQRDMULH.
@@ -467,6 +478,64 @@ __STATIC_FORCEINLINE int32x4_t arm_mve_requantize(const int32x4_t val, const q31
 }
 #endif
 
+// @note The following functions are used only for softmax layer, scaled bits = 5 assumed
+
/**
 * @brief  Fixed-point exp(x) for x <= 0, bit-exact with gemmlowp's
 *         exp_on_negative_values<int32_t, 5> as used by TensorFlow Lite.
 *         val is Q5.26 fixed point (5 integer bits, see file note);
 *         the result is Q0.31 in (0, 1].
 */
__STATIC_FORCEINLINE int32_t arm_nn_exp_on_negative_values(int32_t val)
{
    int32_t mask  = 0;
    int32_t shift = 24; /* Bit 24 is 1/4 in Q5.26; first bit tested by the barrel shifter below */

    /* Range-reduce the argument to [-1/4, 0): keep val modulo 1/4, offset by -1/4. */
    const int32_t val_mod_minus_quarter = (val & ((1 << shift) - 1)) - (1 << shift);
    /* Part of val removed by the reduction; reapplied bit-by-bit further down. */
    const int32_t remainder             = val_mod_minus_quarter - val;
    /* Rescale the reduced argument from Q5.26 to Q0.31 and re-center by +1/8 (1 << 28). */
    const int32_t x                     = (val_mod_minus_quarter << 5) + (1 << 28);
    const int32_t x2                    = MUL_SAT(x, x);

    /* 4th-order Taylor expansion of exp() around -1/8, evaluated at x.
       1895147668 = exp(-1/8) in Q0.31; 715827883 = 1/3 in Q0.31. */
    int32_t result = 1895147668 + MUL_SAT(1895147668, x +
        DIV_POW2(MUL_SAT(DIV_POW2(MUL_SAT(x2, x2), 2) + MUL_SAT(x2, x), 715827883) + x2, 1));

/* For each bit set in remainder, multiply result by the matching constant
   exp(-2^k) (Q0.31). NOTE: shift++ inside the macro advances the tested bit
   (24..30), so the invocation order below is load-bearing. */
#define SELECT_IF_NON_ZERO(x)                                     \
{                                                                 \
    mask   = MASK_IF_NON_ZERO(remainder & (1 << shift++));        \
    result = SELECT_USING_MASK(mask, MUL_SAT(result, x), result); \
}

    SELECT_IF_NON_ZERO(1672461947) /* exp(-1/4)  in Q0.31 */
    SELECT_IF_NON_ZERO(1302514674) /* exp(-1/2)  in Q0.31 */
    SELECT_IF_NON_ZERO(790015084)  /* exp(-1)    in Q0.31 */
    SELECT_IF_NON_ZERO(290630308)  /* exp(-2)    in Q0.31 */
    SELECT_IF_NON_ZERO(39332535)   /* exp(-4)    in Q0.31 */
    SELECT_IF_NON_ZERO(720401)     /* exp(-8)    in Q0.31 */
    SELECT_IF_NON_ZERO(242)        /* exp(-16)   in Q0.31 */

#undef SELECT_IF_NON_ZERO

    /* exp(0) == 1 saturates to Q31_MAX in Q0.31. */
    mask = MASK_IF_ZERO(val);
    return SELECT_USING_MASK(mask, Q31_MAX, result);
}
+
+__STATIC_FORCEINLINE q31_t arm_nn_mult_by_power_of_two(const int32_t val, const int32_t exp)
+{
+    const int32_t thresh = ((1 << (31 - exp)) - 1);
+    int32_t result = val << exp;
+    result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val > thresh), Q31_MAX, result);
+    result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val < -thresh), Q31_MIN, result);
+    return result;
+}
+
+__STATIC_FORCEINLINE int32_t arm_nn_one_over_one_plus_x_for_x_in_0_1(int32_t val)
+{
+    const int64_t sum = (int64_t)val + (int64_t)Q31_MAX;
+    const int32_t half_denominator = (int32_t)((sum + (sum >= 0 ? 1 : -1)) / 2L);
+    int32_t x = 1515870810 + MUL_SAT(half_denominator, -1010580540);
+
+    const int32_t shift = (1 << 29);
+    x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
+    x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
+    x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
+
+    return MUL_POW2(x, 1);
+}
+
 #ifdef __cplusplus
 }
 #endif

+ 1 - 1
CMSIS/NN/Source/ReshapeFunctions/arm_reshape_s8.c

@@ -22,7 +22,7 @@
  * Description:  Reshape a s8 vector
  *
  * $Date:        September 2019
- * $Revision:    0.0.1
+ * $Revision:    V.1.0.0
  *
  * Target Processor:  Cortex-M cores
  *

+ 92 - 0
CMSIS/NN/Source/SoftmaxFunctions/arm_softmax_s8.c

@@ -0,0 +1,92 @@
+/*
+ * Copyright (C) 2010-2019 Arm Limited or its affiliates. All rights reserved.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the License); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* ----------------------------------------------------------------------
+ * Project:      CMSIS NN Library
+ * Title:        arm_softmax_s8.c
+ * Description:  S8 softmax function
+ *
+ * $Date:        October 2019
+ * $Revision:    V.1.0.0
+ *
+ * Target Processor:  Cortex-M cores
+ *
+ * -------------------------------------------------------------------- */
+
+#include "arm_nnfunctions.h"
+
+#define ACCUM_BITS 12
+
+/**
+ *  @ingroup groupNN
+ */
+
+/**
+ * @addtogroup Softmax
+ * @{
+ */
+void arm_softmax_s8(const int8_t *input,
+                    const int32_t num_rows,
+                    const int32_t row_size,
+                    const int32_t mult,
+                    const int32_t shift,
+                    const int8_t diff_min,
+                    int8_t *output)
+{
+    const int32_t mask = (1 << shift);
+
+    uint16_t row = 0;
+    uint16_t col = 0;
+    for(row = 0; row < num_rows; ++row)
+    {
+        const int32_t row_idx = row * row_size;
+
+        int8_t max = input[row_idx];
+        for (col = 1; col < row_size; ++col)
+        {
+            max = MAX(max, input[row_idx + col]);
+        }
+
+        int32_t sum = 0;
+        for (col = 0; col < row_size; ++col)
+        {
+            const int8_t diff = input[row_idx + col] - max;
+            if (diff >= diff_min)
+            {
+                sum += DIV_POW2(EXP_ON_NEG(MUL_SAT(diff * mask, mult)), ACCUM_BITS);
+            }
+        }
+
+        const int32_t headroom = __CLZ(sum);
+        const int32_t bits_over_unit = ACCUM_BITS - headroom;
+        const int32_t shifted_scale = ONE_OVER1((sum << headroom) - (1 << 31));
+        for (col = 0; col < row_size; ++col)
+        {
+            const int8_t diff = input[row_idx + col] - max;
+            if (diff >= diff_min)
+            {
+                const int32_t out_val = DIV_POW2(MUL_SAT(shifted_scale, EXP_ON_NEG(MUL_SAT(diff * mask, mult))), bits_over_unit + 23) - 128;
+                output[row_idx + col] = (int8_t) CLAMP(out_val, (int32_t)127, (int32_t)-128);
+            }
+            else
+            {
+                output[row_idx + col] = -128;
+            }
+        }
+    }
+}