Browse Source

CMSIS-NN: Match im2col buffers to number of lanes(MVE) (#1355)

Predication in the load of data from im2col buffers is removed,
and the buffer length is adjusted to match the number of
lanes (8) for the data type used (int16).

Change-Id: I0f58a577408ac449e442a876cb46589891b68f49
felix-johnny 4 years ago
parent
commit
6c8ebf7a3d

+ 1 - 1
CMSIS/NN/README.md

@@ -24,7 +24,7 @@ Group | API | Base Operator | Input Constraints | Additional memory required for
 |:----| :---| :------------ | :---------------- | :--------------------------------------------------------| :-------------| :------------- | :------------- |
 |[Conv](https://arm-software.github.io/CMSIS_5/NN/html/group__NNConv.html)||||| |  ||
 ||arm_convolve_wrapper_s8()|CONV|dilation = 1|n.a.| Yes | Yes |The additional memory required depends on the optimal convolution function called|
-||arm_convolve_s8()|CONV|dilation = 1|4 * ker_x * ker_y * input_ch| Yes | Yes ||
+||arm_convolve_s8()|CONV|dilation = 1|4 * (ker_x * ker_y * input_ch + delta)| Yes | Yes |delta - MVE only|
 ||arm_convolve_1x1_s8_fast() | CONV | dilation = 1 <br/> ker_x = 1, ker_y = 1 <br/> pad = 0<br/> stride = 1<br/> input_ch % 4 = 0| 0 | Yes |Yes ||
 ||arm_convolve_1_n_s8() | CONV | dilation = 1 <br/> output_y % 4 = 0 | No |Yes ||
 || arm_depthwise_conv_3x3_s8() | DEPTHWISE_CONV | dilation = 1 <br/> depth_multiplier = 1 <br/> pad_x <= 1 | No|No|No| Preferred function for 3x3 kernel size for DSP extension. </br> For MVE, use arm_depthwise_conv_s8_opt()||

+ 11 - 4
CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_s8.c

@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2010-2021 Arm Limited or its affiliates. All rights reserved.
+ * Copyright (C) 2010-2021 Arm Limited or its affiliates.
  *
  * SPDX-License-Identifier: Apache-2.0
  *
@@ -21,8 +21,8 @@
  * Title:        arm_convolve_s8.c
  * Description:  s8 version of convolution using symmetric quantization.
  *
- * $Date:        June 23, 2021
- * $Revision:    V.2.0.5
+ * $Date:        October 27, 2021
+ * $Revision:    V.2.0.7
  *
  * Target Processor:  Cortex-M cores
  *
@@ -366,7 +366,14 @@ arm_status arm_convolve_s8(const cmsis_nn_context *ctx,
 
 int32_t arm_convolve_s8_get_buffer_size(const cmsis_nn_dims *input_dims, const cmsis_nn_dims *filter_dims)
 {
-#if defined(ARM_MATH_DSP)
+#if defined(ARM_MATH_MVEI)
+    int32_t col_length = input_dims->c * filter_dims->w * filter_dims->h;
+    // Get number of complete int16 lanes(multiple of 8) for given col_length. This is dependent on
+    // implementation of  arm_nn_mat_mult_s8
+    col_length = (col_length + 7) / 8;
+    // 4 -> number of im2col buffers, 8 -> 8 elements per Q register
+    return 4 * col_length * 8 * (int32_t)sizeof(int8_t);
+#elif defined(ARM_MATH_DSP)
     return (2 * input_dims->c * filter_dims->w * filter_dims->h) * (int32_t)sizeof(int16_t);
 #else
     (void)input_dims;

+ 18 - 18
CMSIS/NN/Source/ConvolutionFunctions/arm_nn_mat_mult_s8.c

@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2010-2020 Arm Limited or its affiliates. All rights reserved.
+ * Copyright (C) 2010-2021 Arm Limited or its affiliates.
  *
  * SPDX-License-Identifier: Apache-2.0
  *
@@ -21,8 +21,8 @@
  * Title:        arm_nn_mat_mult_s8.c
  * Description:  General Matrix-multiplication function
  *
- * $Date:        09. October 2020
- * $Revision:    V.2.0.5
+ * $Date:        27. October 2021
+ * $Revision:    V.2.0.6
  *
  * Target Processor:  Cortex-M cores
  * -------------------------------------------------------------------- */
@@ -76,24 +76,24 @@ q7_t *arm_nn_mat_mult_s8(const q7_t *input_row,
                 const int16x8_t offset = vdupq_m_n_s16(vuninitializedq_s16(), col_offset, p);
                 row_len_tmp -= 8;
 
-                int16x8_t r0 = vldrbq_z_s16(ip_r0, p);
-                ip_r0 += 8;
-
-                int16x8_t c0 = vldrbq_z_s16(ip_c0, p);
+                int16x8_t c0 = vldrbq_s16(ip_c0);
                 ip_c0 += 8;
-                c0 = vaddq_m_s16(vuninitializedq_s16(), c0, offset, p);
+                c0 = vaddq_s16(c0, offset);
 
-                int16x8_t c1 = vldrbq_z_s16(ip_c1, p);
+                int16x8_t c1 = vldrbq_s16(ip_c1);
                 ip_c1 += 8;
-                c1 = vaddq_m_s16(vuninitializedq_s16(), c1, offset, p);
+                c1 = vaddq_s16(c1, offset);
 
-                int16x8_t c2 = vldrbq_z_s16(ip_c2, p);
+                int16x8_t c2 = vldrbq_s16(ip_c2);
                 ip_c2 += 8;
-                c2 = vaddq_m_s16(vuninitializedq_s16(), c2, offset, p);
+                c2 = vaddq_s16(c2, offset);
 
-                int16x8_t c3 = vldrbq_z_s16(ip_c3, p);
+                int16x8_t c3 = vldrbq_s16(ip_c3);
                 ip_c3 += 8;
-                c3 = vaddq_m_s16(vuninitializedq_s16(), c3, offset, p);
+                c3 = vaddq_s16(c3, offset);
+
+                int16x8_t r0 = vldrbq_z_s16(ip_r0, p);
+                ip_r0 += 8;
 
                 acc_0 = vmladavaq_p_s16(acc_0, r0, c0, p);
                 acc_1 = vmladavaq_p_s16(acc_1, r0, c1, p);
@@ -136,12 +136,12 @@ q7_t *arm_nn_mat_mult_s8(const q7_t *input_row,
                     const int16x8_t offset = vdupq_m_n_s16(vuninitializedq_s16(), col_offset, p);
                     row_len_tmp -= 8;
 
-                    int16x8_t r0 = vldrbq_z_s16(ip_r0, p);
-                    ip_r0 += 8;
-                    int16x8_t c0 = vldrbq_z_s16(ip_c0, p);
+                    int16x8_t c0 = vldrbq_s16(ip_c0);
                     ip_c0 += 8;
+                    c0 = vaddq_s16(c0, offset);
 
-                    c0 = vaddq_m_s16(vuninitializedq_s16(), c0, offset, p);
+                    int16x8_t r0 = vldrbq_z_s16(ip_r0, p);
+                    ip_r0 += 8;
                     acc_0 = vmladavaq_p_s16(acc_0, r0, c0, p);
                 }