فهرست منبع

CMSIS-NN: Fix incorrect data type usage in Softmax

1. Data type of diff min in input argument is changed to int32 from
   int8.
2. For non-MVE versions, the data type of  variable holding the
   difference with the max element in row is corrected from int8 to
   int32.

Change-Id: I67945a36bd926d1a8b346919fcb83b211ee7800f
Felix Johnny 5 سال پیش
والد
کامیت
de79398a45
2فایلهای تغییر یافته به همراه7 افزوده شده و 7 حذف شده
  1. 3 3
      Include/arm_nnfunctions.h
  2. 4 4
      Source/SoftmaxFunctions/arm_softmax_s8.c

+ 3 - 3
Include/arm_nnfunctions.h

@@ -21,8 +21,8 @@
  * Title:        arm_nnfunctions.h
  * Description:  Public header file for CMSIS NN Library
  *
- * $Date:        April 1, 2020
- * $Revision:    V.1.2.6
+ * $Date:        April 6, 2020
+ * $Revision:    V.2.0.0
  *
  * Target Processor:  Cortex-M cores
  * -------------------------------------------------------------------- */
@@ -1892,7 +1892,7 @@ void arm_softmax_s8(const int8_t *input,
                     const int32_t row_size,
                     const int32_t mult,
                     const int32_t shift,
-                    const int8_t diff_min,
+                    const int32_t diff_min,
                     int8_t *output);
 
   /**

+ 4 - 4
Source/SoftmaxFunctions/arm_softmax_s8.c

@@ -21,8 +21,8 @@
  * Title:        arm_softmax_s8.c
  * Description:  S8 softmax function
  *
- * $Date:        March 31, 2020
- * $Revision:    V.1.5.1
+ * $Date:        April 6, 2020
+ * $Revision:    V.2.0.0
  *
  * Target Processor:  Cortex-M cores
  *
@@ -86,7 +86,7 @@ void arm_softmax_s8(const int8_t *input,
                     const int32_t row_size,
                     const int32_t mult,
                     const int32_t shift,
-                    const int8_t diff_min,
+                    const int32_t diff_min,
                     int8_t *output)
 {
 #ifdef ARM_MATH_MVEI
@@ -217,7 +217,7 @@ void arm_softmax_s8(const int8_t *input,
             max = MAX(max, input[col]);
         }
 
-        int8_t diff = 0;
+        int32_t diff = 0;
         int32_t sum = 0;
 
         for (col = 0; col < row_size; ++col)