فهرست منبع

Complete the update and testing of softmax functions

Liangzhen Lai 8 سال پیش
والد
کامیت
4e515bb60f
2 فایل تغییر یافته به همراه 45 افزوده شده و 28 حذف شده
  1. 30 18
      CMSIS/NN/Source/SoftmaxFunctions/arm_softmax_q15.c
  2. 15 10
      CMSIS/NN/Source/SoftmaxFunctions/arm_softmax_q7.c

+ 30 - 18
CMSIS/NN/Source/SoftmaxFunctions/arm_softmax_q15.c

@@ -21,7 +21,7 @@
  * Title:        arm_softmax_q15.c
  * Description:  Q15 softmax function
  *
- * $Date:        17. January 2018
+ * $Date:        20. February 2018
  * $Revision:    V.1.0.0
  *
  * Target Processor:  Cortex-M cores
@@ -64,41 +64,53 @@ void arm_softmax_q15(const q15_t * vec_in, const uint16_t dim_vec, q15_t * p_out
 {
     q31_t     sum;
     int16_t   i;
-    q31_t     min, max;
-    max = -1 * 0x100000;
-    min = 0x100000;
+    uint8_t   shift;
+    q31_t     base;
+    base = -1 * 0x100000;
     for (i = 0; i < dim_vec; i++)
     {
-        if (vec_in[i] > max)
+        if (vec_in[i] > base)
         {
-            max = vec_in[i];
-        }
-        if (vec_in[i] < min)
-        {
-            min = vec_in[i];
+            base = vec_in[i];
         }
     }
 
     /* we ignore really small values  
      * anyway, they will be 0 after shrinking
-     * to q7_t
+     * to q15_t
      */
-    if (max - min > 16)
-    {
-        min = max - 16;
-    }
+    base = base - 16;
 
     sum = 0;
 
     for (i = 0; i < dim_vec; i++)
     {
-        sum += 0x1 << (vec_in[i] - min);
+        if (vec_in[i] > base)
+        {
+            shift = (uint8_t)__USAT(vec_in[i] - base, 5);
+            sum += 0x1 << shift;
+        }
     }
 
+    /* This is effectively (0x1 << 32) / sum */
+    int64_t div_base = 0x100000000LL;
+    int output_base = (int32_t)(div_base / sum);
+
+    /* Final confidence will be output_base >> ( 17 - (vec_in[i] - base) )
+     * so 32768 (0x1<<15) -> 100% confidence when sum = 0x1 << 16, output_base = 0x1 << 16
+     * and vec_in[i]-base = 16
+     */
     for (i = 0; i < dim_vec; i++)
     {
-        /* we leave 7-bit dynamic range, so that 128 -> 100% confidence */
-        p_out[i] = (q15_t) __SSAT(((0x1 << (vec_in[i] - min + 14)) / sum), 16);
+        if (vec_in[i] > base) 
+        {
+            /* Here minimum value of 17+base-vec_in[i] will be 1 */
+            shift = (uint8_t)__USAT(17+base-vec_in[i], 5);
+            p_out[i] = (q15_t) __SSAT((output_base >> shift), 16);
+        } else
+        {
+            p_out[i] = 0;
+        }
     }
 
 }

+ 15 - 10
CMSIS/NN/Source/SoftmaxFunctions/arm_softmax_q7.c

@@ -21,7 +21,7 @@
  * Title:        arm_softmax_q7.c
  * Description:  Q7 softmax function
  *
- * $Date:        17. January 2018
+ * $Date:        20. February 2018
  * $Revision:    V.1.0.0
  *
  * Target Processor:  Cortex-M cores
@@ -77,7 +77,8 @@ void arm_softmax_q7(const q7_t * vec_in, const uint16_t dim_vec, q7_t * p_out)
         }
     }
 
-    /* So the base is set to max-8, meaning 
+    /* 
+     * So the base is set to max-8, meaning 
      * that we ignore really small values. 
      * anyway, they will be 0 after shrinking to q7_t.
      */
@@ -87,23 +88,27 @@ void arm_softmax_q7(const q7_t * vec_in, const uint16_t dim_vec, q7_t * p_out)
 
     for (i = 0; i < dim_vec; i++)
     {
-        if (vec_in[i] > base) {
-          shift = (uint8_t)__USAT(vec_in[i] - base, 7);
-          sum += 0x1 << shift;
+        if (vec_in[i] > base) 
+        {
+            shift = (uint8_t)__USAT(vec_in[i] - base, 5);
+            sum += 0x1 << shift;
         }
     }
 
     /* This is effectively (0x1 << 20) / sum */
     int output_base = 0x100000 / sum;
 
-    /* Final confidence will be output_base >> ( 13 - (vec_in[i] - min) )
-     * so 128 (0x1<<7) -> 100% confidence when output_base = 0x1<<12 and vec_in[i]-min = 8
+    /* 
+     * Final confidence will be output_base >> ( 13 - (vec_in[i] - base) )
+     * so 128 (0x1<<7) -> 100% confidence when sum = 0x1 << 8, output_base = 0x1 << 12 
+     * and vec_in[i]-base = 8
      */
     for (i = 0; i < dim_vec; i++) 
     {
-        if (vec_in[i] > base) {
-            /* Here minimum value of 13+min-vec_in[i] will be 5 */
-            shift = (uint8_t)__USAT(13+base-vec_in[i], 7);
+        if (vec_in[i] > base) 
+        {
+            /* Here minimum value of 13+base-vec_in[i] will be 5 */
+            shift = (uint8_t)__USAT(13+base-vec_in[i], 5);
             p_out[i] = (q7_t) __SSAT((output_base >> shift), 8);
         } else {
             p_out[i] = 0;