Эх сурвалжийг харах

CMSIS-NN: Fix typecast error in Fully connected s8

Removed traling spaces and unified naming for input arguments.
Felix Johnny 6 жил өмнө
parent
commit
2cf84a33e9

+ 19 - 19
CMSIS/NN/Source/FullyConnectedFunctions/arm_fully_connected_s8.c

@@ -40,7 +40,7 @@
  * @{
  * @{
  */
  */
 
 
-  /*
+/*
    * S8 basic fully-connected and matrix multiplication layer function for TensorFlow Lite
    * S8 basic fully-connected and matrix multiplication layer function for TensorFlow Lite
    *
    *
    * Refer header file for details.
    * Refer header file for details.
@@ -48,8 +48,8 @@
    */
    */
 
 
 arm_status
 arm_status
-arm_fully_connected_s8(const int8_t *pInput,
-                       const int8_t *pWeight,
+arm_fully_connected_s8(const int8_t *input,
+                       const int8_t *kernel,
                        const uint16_t col_dim,
                        const uint16_t col_dim,
                        const uint16_t row_dim,
                        const uint16_t row_dim,
                        const uint16_t nb_batches,
                        const uint16_t nb_batches,
@@ -58,8 +58,8 @@ arm_fully_connected_s8(const int8_t *pInput,
                        const int32_t out_mult,
                        const int32_t out_mult,
                        const int32_t out_shift,
                        const int32_t out_shift,
                        const int32_t output_offset,
                        const int32_t output_offset,
-                       const int8_t *pBias,
-                       int8_t *pOut,
+                       const int8_t *bias,
+                       int8_t *output,
                        const int32_t output_activation_min,
                        const int32_t output_activation_min,
                        const int32_t output_activation_max,
                        const int32_t output_activation_max,
                        q15_t *vec_buffer)
                        q15_t *vec_buffer)
@@ -76,10 +76,10 @@ arm_fully_connected_s8(const int8_t *pInput,
        which are used in this implementation.
        which are used in this implementation.
 
 
     */
     */
-    const int8_t *pBiasTmp = pBias;
-    const q7_t *pB = pWeight;
+    const int8_t *pBiasTmp = bias;
+    const q7_t *pB = kernel;
     const q7_t *pB2;
     const q7_t *pB2;
-    q7_t *pO = pOut;
+    q7_t *pO = output;
     q15_t *pA;
     q15_t *pA;
     q31_t ioffset;
     q31_t ioffset;
     q31_t foffset;
     q31_t foffset;
@@ -89,9 +89,9 @@ arm_fully_connected_s8(const int8_t *pInput,
 
 
     while (batchCnt)
     while (batchCnt)
     {
     {
-        pBiasTmp = pBias;
-        pB = pWeight;
-        arm_q7_to_q15_reordered_no_shift(pInput, vec_buffer, col_dim);
+        pBiasTmp = bias;
+        pB = kernel;
+        arm_q7_to_q15_reordered_no_shift(input, vec_buffer, col_dim);
         uint16_t rowCnt = row_dim >> 1;
         uint16_t rowCnt = row_dim >> 1;
         /* Unroll on the rows */
         /* Unroll on the rows */
         while (rowCnt)
         while (rowCnt)
@@ -207,26 +207,26 @@ arm_fully_connected_s8(const int8_t *pInput,
 
 
             rowCnt--;
             rowCnt--;
         }
         }
-        pInput += col_dim;
+        input += col_dim;
         batchCnt--;
         batchCnt--;
     }
     }
     return (ARM_MATH_SUCCESS);
     return (ARM_MATH_SUCCESS);
 
 
 #else
 #else
     const int8_t *pInputA;
     const int8_t *pInputA;
-    const int8_t *pBiasTmp = pBias;
-    const int8_t *pWeightTmp = pWeight;
+    const int8_t *pBiasTmp = bias;
+    const int8_t *pWeightTmp = kernel;
     uint16_t batchCnt = nb_batches;
     uint16_t batchCnt = nb_batches;
 
 
     while (batchCnt)
     while (batchCnt)
     {
     {
-        pBiasTmp = pBias;
-        pWeightTmp = pWeight;
+        pBiasTmp = bias;
+        pWeightTmp = kernel;
         for (int out_c = 0; out_c < row_dim; out_c++)
         for (int out_c = 0; out_c < row_dim; out_c++)
         {
         {
 
 
             int32_t acc = *pBiasTmp++;
             int32_t acc = *pBiasTmp++;
-            pInputA = pInput;
+            pInputA = input;
             for (int d = 0; d < col_dim; d++)
             for (int d = 0; d < col_dim; d++)
             {
             {
 
 
@@ -245,9 +245,9 @@ arm_fully_connected_s8(const int8_t *pInput,
             acc = MAX(acc, output_activation_min);
             acc = MAX(acc, output_activation_min);
             acc = MIN(acc, output_activation_max);
             acc = MIN(acc, output_activation_max);
 
 
-            *pOut++ = (uint8_t)(acc);
+            *output++ = (int8_t)(acc);
         }
         }
-        pInput += col_dim;
+        input += col_dim;
         batchCnt--;
         batchCnt--;
     }
     }
     return (ARM_MATH_SUCCESS);
     return (ARM_MATH_SUCCESS);