فهرست منبع

CMSIS-NN: Add MVEI support for int16 max pooling (#1509)

Change-Id: Ia75af0245dbe471aae629110b53b0d0998afca2b
Måns Nilsson 3 سال پیش
والد
کامیت
2edf28404a
2فایلهای تغییر یافته به همراه36 افزوده شده و 4 حذف شده
  1. 1 1
      README.md
  2. 35 3
      Source/PoolingFunctions/arm_max_pool_s16.c

+ 1 - 1
README.md

@@ -46,7 +46,7 @@ Group | API | Base Operator | Input Constraints | Additional memory required for
 || arm_avgpool_s8() | AVERAGE POOL | None | input_ch * 4<br/>(DSP only) | Yes| Yes| Best case is when channels are multiple of 4 or <br/> at the least >= 4 |
 || arm_avgpool_s16() | AVERAGE POOL | None | input_ch * 4<br/>(DSP only) | Yes| No| Best case is when channels are multiple of 4 or <br/> at the least >= 4 |
 || arm_maxpool_s8() | MAX POOL | None | None | Yes| Yes|  |
-|| arm_maxpool_s16() | MAX POOL | None | None | No| No|  |
+|| arm_maxpool_s16() | MAX POOL | None | None | No| Yes|  |
 |[Softmax](https://arm-software.github.io/CMSIS_5/NN/html/group__Softmax.html)||||| |  ||
 ||arm_softmax_q7()| SOFTMAX | None | None | Yes | No | Not bit exact to TFLu but can be up to 70x faster |
 ||arm_softmax_s8()| SOFTMAX | None | None | No | Yes | Bit exact to TFLu |

+ 35 - 3
Source/PoolingFunctions/arm_max_pool_s16.c

@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2022 Arm Limited or its affiliates.
+ * SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
  *
  * SPDX-License-Identifier: Apache-2.0
  *
@@ -21,8 +21,8 @@
  * Title:        arm_max_pool_s16.c
  * Description:  Pooling function implementations
  *
- * $Date:        19 April 2022
- * $Revision:    V.2.0.0
+ * $Date:        20 June 2022
+ * $Revision:    V.2.1.0
  *
  * Target Processor:  Cortex-M CPUs
  *
@@ -33,6 +33,20 @@
 
 static void compare_and_replace_if_larger(int16_t *base, const int16_t *target, int32_t length)
 {
+#if defined(ARM_MATH_MVEI)
+    int32_t loop_count = (length + 7) / 8;
+    for (int i = 0; i < loop_count; i++)
+    {
+        mve_pred16_t p = vctp16q((uint32_t)length);
+        const int16x8_t op_1 = vldrhq_z_s16(base, p);
+        const int16x8_t op_2 = vldrhq_z_s16(target, p);
+        const int16x8_t max = vmaxq_s16(op_1, op_2);
+        vstrhq_p_s16(base, max, p);
+        base += 8;
+        target += 8;
+        length -= 8;
+    }
+#else
     q15_t *dst = base;
     const q15_t *src = target;
     union arm_nnword ref_max;
@@ -65,10 +79,27 @@ static void compare_and_replace_if_larger(int16_t *base, const int16_t *target,
             *dst = *src;
         }
     }
+#endif
 }
 
 static void clamp_output(int16_t *source, int32_t length, const int16_t act_min, const int16_t act_max)
 {
+#if defined(ARM_MATH_MVEI)
+    const int16x8_t min = vdupq_n_s16((int16_t)act_min);
+    const int16x8_t max = vdupq_n_s16((int16_t)act_max);
+
+    int32_t loop_count = (length + 7) / 8;
+    for (int i = 0; i < loop_count; i++)
+    {
+        mve_pred16_t p = vctp16q((uint32_t)length);
+        length -= 8;
+        const int16x8_t src = vldrhq_z_s16(source, p);
+        int16x8_t res = vmaxq_m_s16(vuninitializedq_s16(), src, min, p);
+        res = vminq_m_s16(vuninitializedq_s16(), res, max, p);
+        vstrhq_p_s16(source, res, p);
+        source += 8;
+    }
+#else
     union arm_nnword in;
     int32_t cnt = length >> 1;
 
@@ -92,6 +123,7 @@ static void clamp_output(int16_t *source, int32_t length, const int16_t act_min,
         comp = MIN(comp, act_max);
         *source = comp;
     }
+#endif
 }
 
 /**