Просмотр исходного кода

Fix shifting issue in this file, waiting for validation.

Liangzhen Lai 8 лет назад
Родитель
Коммит
8d03d5dec2
Изменён 1 файл: 29 добавлений и 21 удаление
  1. 29 21
      CMSIS/NN/Source/SoftmaxFunctions/arm_softmax_q7.c

+ 29 - 21
CMSIS/NN/Source/SoftmaxFunctions/arm_softmax_q7.c

@@ -64,43 +64,51 @@ void arm_softmax_q7(const q7_t * vec_in, const uint16_t dim_vec, q7_t * p_out)
 {
     q31_t     sum;
     int16_t   i;
-    q15_t     min, max;
-    max = -257;
-    min = 257;
+    uint8_t   shift;
+    q15_t     base;
+    base = -257;
+
+    /* We first search for the maximum */
     for (i = 0; i < dim_vec; i++)
     {
-        if (vec_in[i] > max)
-        {
-            max = vec_in[i];
-        }
-        if (vec_in[i] < min)
+        if (vec_in[i] > base)
         {
-            min = vec_in[i];
+            base = vec_in[i];
         }
     }
 
-    /* we ignore really small values  
-     * anyway, they will be 0 after shrinking
-     * to q7_t
+    /* So the base is set to max-8, meaning 
+     * that we ignore really small values. 
+     * anyway, they will be 0 after shrinking to q7_t.
      */
-    if (max - min > 8)
-    {
-        min = max - 8;
-    }
+    base = base - 8;
 
     sum = 0;
 
     for (i = 0; i < dim_vec; i++)
     {
-        sum += 0x1 << (vec_in[i] - min);
+        if (vec_in[i] > min) {
+          shift = (uint8_t)__USAT(vec_in[i] - min, 7);
+          sum += 0x1 << shift;
+        }
     }
 
-    for (i = 0; i < dim_vec; i++)
+    /* This is effectively (0x1 << 20) / sum */
+    int output_base = 0x100000 / sum;
+
+    /* Final confidence will be output_base >> ( 13 - (vec_in[i] - min) )
+     * so 128 (0x1<<7) -> 100% confidence when output_base = 0x1<<12 and vec_in[i]-min = 8
+     */
+    for (i = 0; i < dim_vec; i++) 
     {
-        /* we leave 7-bit dynamic range, so that 128 -> 100% confidence */
-        p_out[i] = (q7_t) __SSAT(((0x1 << (vec_in[i] - min + 20)) / sum) >> 13, 8);
+        if (vec_in[i] > min) {
+            /* Here minimum value of 13+min-vec_in[i] will be 5 */
+            shift = (uint8_t)__USAT(13+min-vec_in[i], 7);
+            p_out[i] = (q7_t) __SSAT((output_base >> shift), 8);
+        } else {
+            p_out[i] = 0;
+        }
     }
-
 }
 
 /**