Просмотр исходного кода

CMSIS-NN: Add missing NULL bias check for conv_s8 (#1349)

The DSP impmplementation of generic conv procesing (arm_convolve_s8)
always expected the optional bias argument. Added a NULL bias
check. Removed unused MVEI implementation in same file.

Change-Id: I3cfa58123dae1a8606a67917c29c03847cef8403
felix-johnny 4 лет назад
Родитель
Сommit
e513dbdb2b
1 измененных файлов с 21 добавлено и 175 удалено
  1. 21 175
      CMSIS/NN/Source/ConvolutionFunctions/arm_nn_mat_mult_kernel_s8_s16.c

+ 21 - 175
CMSIS/NN/Source/ConvolutionFunctions/arm_nn_mat_mult_kernel_s8_s16.c

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright (C) 2010-2020 Arm Limited or its affiliates. All rights reserved.
+ * Copyright (C) 2010-2021 Arm Limited or its affiliates. All rights reserved.
  *
  *
  * SPDX-License-Identifier: Apache-2.0
  * SPDX-License-Identifier: Apache-2.0
  *
  *
@@ -49,174 +49,7 @@ q7_t *arm_nn_mat_mult_kernel_s8_s16(const q7_t *input_a,
                                     const int32_t *const output_bias,
                                     const int32_t *const output_bias,
                                     q7_t *out_0)
                                     q7_t *out_0)
 {
 {
-#if defined(ARM_MATH_MVEI)
-#define ROW_PER_LOOP (4)
-#define COL_PER_LOOP (8)
-
-    const q7_t *ip_a0_s8 = input_a;
-    q7_t *out_1 = out_0 + output_ch;
-
-    const int32_t *bias = output_bias;
-
-    int32_t row_count = output_ch / ROW_PER_LOOP;
-
-    while (row_count)
-    {
-        const q15_t *ip_b0_s16 = input_b;
-        const q15_t *ip_b1_s16 = input_b + num_col_a;
-
-        const q7_t *ip_a1_s8 = ip_a0_s8 + num_col_a;
-        const q7_t *ip_a2_s8 = ip_a0_s8 + num_col_a * 2;
-        const q7_t *ip_a3_s8 = ip_a0_s8 + num_col_a * 3;
-
-        q31_t ch_0_out_n = bias[0];
-        q31_t ch_1_out_n = bias[1];
-        q31_t ch_2_out_n = bias[2];
-        q31_t ch_3_out_n = bias[3];
-
-        q31_t ch_0_out_n1 = ch_0_out_n;
-        q31_t ch_1_out_n1 = ch_1_out_n;
-        q31_t ch_2_out_n1 = ch_2_out_n;
-        q31_t ch_3_out_n1 = ch_3_out_n;
-        bias += 4;
-
-        int32_t col_count = num_col_a / COL_PER_LOOP;
-
-        while (col_count)
-        {
-            // Load inputs
-            const int16x8_t ip_b0 = vld1q_s16(ip_b0_s16);
-            ip_b0_s16 += COL_PER_LOOP;
-            const int16x8_t ip_b1 = vld1q_s16(ip_b1_s16);
-            ip_b1_s16 += COL_PER_LOOP;
-
-            // Load filters
-            const int16x8_t ip_a0 = vldrbq_s16(ip_a0_s8);
-            ip_a0_s8 += COL_PER_LOOP;
-            const int16x8_t ip_a1 = vldrbq_s16(ip_a1_s8);
-            ip_a1_s8 += COL_PER_LOOP;
-            const int16x8_t ip_a2 = vldrbq_s16(ip_a2_s8);
-            ip_a2_s8 += COL_PER_LOOP;
-            const int16x8_t ip_a3 = vldrbq_s16(ip_a3_s8);
-            ip_a3_s8 += COL_PER_LOOP;
-
-            // MAC
-            ch_0_out_n += vmladavq_s16(ip_b0, ip_a0);
-            ch_1_out_n += vmladavq_s16(ip_b0, ip_a1);
-            ch_2_out_n += vmladavq_s16(ip_b0, ip_a2);
-            ch_3_out_n += vmladavq_s16(ip_b0, ip_a3);
-            ch_0_out_n1 += vmladavq_s16(ip_b1, ip_a0);
-            ch_1_out_n1 += vmladavq_s16(ip_b1, ip_a1);
-            ch_2_out_n1 += vmladavq_s16(ip_b1, ip_a2);
-            ch_3_out_n1 += vmladavq_s16(ip_b1, ip_a3);
-
-            col_count--;
-        }
-
-        /* Handle tail */
-        col_count = (num_col_a & (COL_PER_LOOP - 1)) - 1;
-        while (col_count >= 0)
-        {
-            const int32_t b0 = ip_b0_s16[col_count];
-            const int32_t b1 = ip_b1_s16[col_count];
-
-            ch_0_out_n += b0 * ip_a0_s8[col_count];
-            ch_1_out_n += b0 * ip_a1_s8[col_count];
-            ch_2_out_n += b0 * ip_a2_s8[col_count];
-            ch_3_out_n += b0 * ip_a3_s8[col_count];
-
-            ch_0_out_n1 += b1 * ip_a0_s8[col_count];
-            ch_1_out_n1 += b1 * ip_a1_s8[col_count];
-            ch_2_out_n1 += b1 * ip_a2_s8[col_count];
-            ch_3_out_n1 += b1 * ip_a3_s8[col_count];
-            col_count--;
-        }
-        ip_a0_s8 += (num_col_a & (COL_PER_LOOP - 1));
-
-        int32x4_t out_vec_0;
-        int32x4_t out_vec_1;
-        out_vec_0[0] = ch_0_out_n;
-        out_vec_0[1] = ch_1_out_n;
-        out_vec_0[2] = ch_2_out_n;
-        out_vec_0[3] = ch_3_out_n;
-
-        out_vec_1[0] = ch_0_out_n1;
-        out_vec_1[1] = ch_1_out_n1;
-        out_vec_1[2] = ch_2_out_n1;
-        out_vec_1[3] = ch_3_out_n1;
-
-        int32x4_t mult = vldrwq_s32(out_mult);
-        int32x4_t shift = vldrwq_s32(out_shift);
-        out_mult += ROW_PER_LOOP;
-        out_shift += ROW_PER_LOOP;
-
-        out_vec_0 = arm_requantize_mve_32x4(out_vec_0, mult, shift);
-        out_vec_1 = arm_requantize_mve_32x4(out_vec_1, mult, shift);
-
-        out_vec_0 = vaddq_n_s32(out_vec_0, out_offset);
-        out_vec_0 = vmaxq_s32(out_vec_0, vdupq_n_s32(activation_min));
-        out_vec_0 = vminq_s32(out_vec_0, vdupq_n_s32(activation_max));
-        vstrbq_s32(out_0, out_vec_0);
-        out_0 += ROW_PER_LOOP;
-
-        out_vec_1 = vaddq_n_s32(out_vec_1, out_offset);
-        out_vec_1 = vmaxq_s32(out_vec_1, vdupq_n_s32(activation_min));
-        out_vec_1 = vminq_s32(out_vec_1, vdupq_n_s32(activation_max));
-        vstrbq_s32(out_1, out_vec_1);
-        out_1 += ROW_PER_LOOP;
-        row_count--;
-        ip_a0_s8 += (num_col_a * 3);
-    }
-
-    row_count = output_ch & (ROW_PER_LOOP - 1);
-
-    if (row_count)
-    {
-        ip_a0_s8 = input_a + num_col_a * (output_ch & ~3);
-        const mve_pred16_t p = vctp32q((uint32_t)row_count);
-        int32x4_t out_vec_0 = vdupq_n_s32(0);
-        int32x4_t out_vec_1 = vdupq_n_s32(0);
-        int32x4_t mult_tail;
-        int32x4_t shift_tail;
-
-        for (int i_ch = 0; i_ch < row_count; i_ch++)
-        {
-            int32_t output_0 = bias[i_ch];
-            int32_t output_1 = bias[i_ch];
-            const q15_t *ip_b0_s16 = input_b;
-            const q15_t *ip_b1_s16 = input_b + num_col_a;
-
-            for (int i_idx = 0; i_idx < num_col_a; i_idx++)
-            {
-                output_0 += ip_b0_s16[i_idx] * ip_a0_s8[i_idx];
-                output_1 += ip_b1_s16[i_idx] * ip_a0_s8[i_idx];
-            }
-
-            ip_a0_s8 += num_col_a;
-            out_vec_0[i_ch] = output_0;
-            out_vec_1[i_ch] = output_1;
-            mult_tail[i_ch] = out_mult[i_ch];
-            shift_tail[i_ch] = out_shift[i_ch];
-        }
-        out_vec_0 = arm_requantize_mve_32x4(out_vec_0, mult_tail, shift_tail);
-        out_vec_1 = arm_requantize_mve_32x4(out_vec_1, mult_tail, shift_tail);
-
-        out_vec_0 = vaddq_n_s32(out_vec_0, out_offset);
-        out_vec_0 = vmaxq_s32(out_vec_0, vdupq_n_s32(activation_min));
-        out_vec_0 = vminq_s32(out_vec_0, vdupq_n_s32(activation_max));
-        vstrbq_p_s32(out_0, out_vec_0, p);
-
-        out_vec_1 = vaddq_n_s32(out_vec_1, out_offset);
-        out_vec_1 = vmaxq_s32(out_vec_1, vdupq_n_s32(activation_min));
-        out_vec_1 = vminq_s32(out_vec_1, vdupq_n_s32(activation_max));
-
-        vstrbq_p_s32(out_1, out_vec_1, p);
-        out_1 += row_count;
-    }
-
-    return out_1;
-
-#elif defined(ARM_MATH_DSP)
+#if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
     /* set up the second output pointers */
     /* set up the second output pointers */
     q7_t *out_1 = out_0 + output_ch;
     q7_t *out_1 = out_0 + output_ch;
     const int32_t *bias = output_bias;
     const int32_t *bias = output_bias;
@@ -233,11 +66,18 @@ q7_t *arm_nn_mat_mult_kernel_s8_s16(const q7_t *input_a,
         /* align the second pointer for A */
         /* align the second pointer for A */
         const q7_t *ip_a1 = ip_a0 + num_col_a;
         const q7_t *ip_a1 = ip_a0 + num_col_a;
 
 
+        q31_t ch_0_out_0 = 0;
+        q31_t ch_0_out_1 = 0;
+        q31_t ch_1_out_0 = 0;
+        q31_t ch_1_out_1 = 0;
         /* Init accumulator with bias for channel N and N + 1 */
         /* Init accumulator with bias for channel N and N + 1 */
-        q31_t ch_0_out_0 = *bias;
-        q31_t ch_0_out_1 = *bias++;
-        q31_t ch_1_out_0 = *bias;
-        q31_t ch_1_out_1 = *bias++;
+        if (bias)
+        {
+            ch_0_out_0 = *bias;
+            ch_0_out_1 = *bias++;
+            ch_1_out_0 = *bias;
+            ch_1_out_1 = *bias++;
+        }
 
 
         uint16_t col_count = num_col_a / 4;
         uint16_t col_count = num_col_a / 4;
         /* accumulate over the vector */
         /* accumulate over the vector */
@@ -320,9 +160,15 @@ q7_t *arm_nn_mat_mult_kernel_s8_s16(const q7_t *input_a,
         const q15_t *ip_b0 = input_b;
         const q15_t *ip_b0 = input_b;
         const q15_t *ip_b1 = ip_b0 + num_col_a;
         const q15_t *ip_b1 = ip_b0 + num_col_a;
 
 
+        q31_t ch_0_out_0 = 0;
+        q31_t ch_0_out_1 = 0;
+
         /* load the bias */
         /* load the bias */
-        q31_t ch_0_out_0 = *bias;
-        q31_t ch_0_out_1 = *bias++;
+        if (bias)
+        {
+            ch_0_out_0 = *bias;
+            ch_0_out_1 = *bias++;
+        }
 
 
         uint16_t col_count = num_col_a >> 2;
         uint16_t col_count = num_col_a >> 2;
         while (col_count)
         while (col_count)