arm_nn_mat_mul_kernel_s16.c 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. /*
  2. * SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates
  3. * <open-source-office@arm.com>
  4. *
  5. * SPDX-License-Identifier: Apache-2.0
  6. *
  7. * Licensed under the Apache License, Version 2.0 (the License); you may
  8. * not use this file except in compliance with the License.
  9. * You may obtain a copy of the License at
  10. *
  11. * www.apache.org/licenses/LICENSE-2.0
  12. *
  13. * Unless required by applicable law or agreed to in writing, software
  14. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  15. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. * See the License for the specific language governing permissions and
  17. * limitations under the License.
  18. */
  19. /* ----------------------------------------------------------------------
  20. * Project: CMSIS NN Library
  21. * Title: arm_nn_mat_mult_kernel_s16.c
  22. * Description: Matrix-multiplication function for convolution
  23. *
  24. * $Date: 5 January 2023
  25. * $Revision: V.1.2.0
  26. *
  27. * Target : Arm(R) M-Profile Architecture
  28. * -------------------------------------------------------------------- */
  29. #include "arm_nnfunctions.h"
  30. #include "arm_nnsupportfunctions.h"
  31. /**
  32. * @ingroup groupSupport
  33. */
  34. /**
  35. * @addtogroup supportConvolution
  36. * @{
  37. */
  /*
   * Matrix-multiplication function for convolution with per-channel requantization.
   *
   * Refer to the header file for the full contract. Summary of what this
   * implementation does:
   *
   * Two consecutive vectors of input_b (each num_col_a elements long, the
   * second one starting at input_b + num_col_a) are multiplied with output_ch
   * consecutive rows of input_a (each num_col_a elements long). Every row of
   * input_a therefore yields two accumulators, one per input_b vector.
   * Each accumulator is requantized with the multiplier/shift of its row,
   * clamped to [activation_min, activation_max] and stored as int16_t.
   *
   * @param[in]     input_a         int8 matrix, output_ch rows of num_col_a elements
   * @param[in]     input_b         int16 data, two vectors of num_col_a elements
   * @param[in]     output_ch       number of rows of input_a (outputs per vector)
   * @param[in]     out_shift       per-row requantization shifts, output_ch entries
   * @param[in]     out_mult        per-row requantization multipliers, output_ch entries
   * @param[in]     activation_min  lower clamp applied to every result
   * @param[in]     activation_max  upper clamp applied to every result
   * @param[in]     num_col_a       number of elements per row of input_a
   * @param[in]     output_bias     optional per-row 64-bit bias; may be NULL
   * @param[in,out] out_0           receives output_ch results for the first
   *                                input_b vector, immediately followed by
   *                                output_ch results for the second vector
   *
   * @return out_0 + 2 * output_ch when built with ARM_MATH_DSP and without
   *         ARM_MATH_MVEI; NULL otherwise (no work is done in that case).
   */
  int16_t *arm_nn_mat_mult_kernel_s16(const int8_t *input_a,
                                      const int16_t *input_b,
                                      const int32_t output_ch,
                                      const int32_t *out_shift,
                                      const int32_t *out_mult,
                                      const int16_t activation_min,
                                      const int16_t activation_max,
                                      const int32_t num_col_a,
                                      const int64_t *const output_bias,
                                      int16_t *out_0)
  {
  #if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
      /* Results for the second input_b vector are stored output_ch elements
       * after the results for the first one. */
      int16_t *out_1 = out_0 + output_ch;
      const int64_t *bias = output_bias;
      const int8_t *ip_a0 = input_a;

      /* The counters are 32 bits wide, like the arguments they are derived
       * from. A narrower counter would silently truncate large row or column
       * counts and would turn a negative argument into a huge trip count. */
      int32_t row_count = output_ch / 2;

      /* Process the rows of A two at a time. */
      while (row_count > 0)
      {
          /* Both B vectors are restarted for every pair of rows. */
          const int16_t *ip_b0 = input_b;
          const int16_t *ip_b1 = ip_b0 + num_col_a;

          /* Second row of the pair directly follows the first one. */
          const int8_t *ip_a1 = ip_a0 + num_col_a;

          /* Accumulators: ch_<row in pair>_out_<B vector>. */
          int32_t ch_0_out_0 = 0;
          int32_t ch_0_out_1 = 0;
          int32_t ch_1_out_0 = 0;
          int32_t ch_1_out_1 = 0;

          /* Main loop: four columns per iteration.
           * NOTE(review): read_and_pad presumably widens four int8 values
           * into two packed 16-bit pairs whose order matches the q15x2 reads
           * of B, and SMLAD presumably performs a dual 16-bit multiply
           * accumulate - confirm against the support header. */
          int32_t col_count = num_col_a / 4;
          while (col_count > 0)
          {
              int32_t a01, a02, a11, a12;
              int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
              int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

              ip_a0 = read_and_pad(ip_a0, &a01, &a02);
              ip_a1 = read_and_pad(ip_a1, &a11, &a12);

              ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
              ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);
              ch_1_out_0 = SMLAD(a11, b0, ch_1_out_0);
              ch_1_out_1 = SMLAD(a11, b1, ch_1_out_1);

              b0 = arm_nn_read_q15x2_ia(&ip_b0);
              b1 = arm_nn_read_q15x2_ia(&ip_b1);

              ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
              ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);
              ch_1_out_0 = SMLAD(a12, b0, ch_1_out_0);
              ch_1_out_1 = SMLAD(a12, b1, ch_1_out_1);

              col_count--;
          }

          /* Tail loop: the remaining 0..3 columns, one at a time. */
          col_count = num_col_a & 0x3;
          while (col_count > 0)
          {
              int8_t a0 = *ip_a0++;
              int16_t b0 = *ip_b0++;
              int8_t a1 = *ip_a1++;
              int16_t b1 = *ip_b1++;

              ch_0_out_0 += a0 * b0;
              ch_0_out_1 += a0 * b1;
              ch_1_out_0 += a1 * b0;
              ch_1_out_1 += a1 * b1;

              col_count--;
          }

          /* First row of the pair: add bias (if any) and requantize. */
          if (bias)
          {
              /* With a bias the sum is formed in 64 bits and requantized by
               * the s64 variant, which takes the reduced multiplier. */
              int32_t reduced_multiplier = REDUCE_MULTIPLIER(*out_mult);
              int64_t acc_64 = ch_0_out_0 + *bias;
              ch_0_out_0 = arm_nn_requantize_s64(acc_64, reduced_multiplier, *out_shift);
              /* Same bias for both B vectors; advance after the second use. */
              acc_64 = ch_0_out_1 + *bias++;
              ch_0_out_1 = arm_nn_requantize_s64(acc_64, reduced_multiplier, *out_shift);
              out_mult++;
          }
          else
          {
              ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
              ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
              out_mult++;
          }

          ch_0_out_0 = MAX(ch_0_out_0, activation_min);
          ch_0_out_0 = MIN(ch_0_out_0, activation_max);
          *out_0++ = (int16_t)ch_0_out_0;

          ch_0_out_1 = MAX(ch_0_out_1, activation_min);
          ch_0_out_1 = MIN(ch_0_out_1, activation_max);
          *out_1++ = (int16_t)ch_0_out_1;
          out_shift++;

          /* Second row of the pair: same steps with the next multiplier,
           * shift and bias entries. */
          if (bias)
          {
              int32_t reduced_multiplier = REDUCE_MULTIPLIER(*out_mult);
              int64_t acc_64 = ch_1_out_0 + *bias;
              ch_1_out_0 = arm_nn_requantize_s64(acc_64, reduced_multiplier, *out_shift);
              acc_64 = ch_1_out_1 + *bias++;
              ch_1_out_1 = arm_nn_requantize_s64(acc_64, reduced_multiplier, *out_shift);
              out_mult++;
          }
          else
          {
              ch_1_out_0 = arm_nn_requantize(ch_1_out_0, *out_mult, *out_shift);
              ch_1_out_1 = arm_nn_requantize(ch_1_out_1, *out_mult, *out_shift);
              out_mult++;
          }

          ch_1_out_0 = MAX(ch_1_out_0, activation_min);
          ch_1_out_0 = MIN(ch_1_out_0, activation_max);
          *out_0++ = (int16_t)ch_1_out_0;

          ch_1_out_1 = MAX(ch_1_out_1, activation_min);
          ch_1_out_1 = MIN(ch_1_out_1, activation_max);
          *out_1++ = (int16_t)ch_1_out_1;
          out_shift++;

          /* ip_a0 has walked over the first row of the pair only; skip the
           * second row, which was consumed through ip_a1. */
          ip_a0 += num_col_a;
          row_count--;
      }

      /* Handle the last row when output_ch is odd. */
      if (output_ch & 0x1)
      {
          const int16_t *ip_b0 = input_b;
          const int16_t *ip_b1 = ip_b0 + num_col_a;

          int32_t ch_0_out_0 = 0;
          int32_t ch_0_out_1 = 0;

          int32_t col_count = num_col_a / 4;
          while (col_count > 0)
          {
              int32_t a01, a02;
              int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
              int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

              ip_a0 = read_and_pad(ip_a0, &a01, &a02);

              ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
              ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);

              b0 = arm_nn_read_q15x2_ia(&ip_b0);
              b1 = arm_nn_read_q15x2_ia(&ip_b1);

              ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
              ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);

              col_count--;
          }

          col_count = num_col_a & 0x3;
          while (col_count > 0)
          {
              int8_t a0 = *ip_a0++;
              int16_t b0 = *ip_b0++;
              int16_t b1 = *ip_b1++;

              ch_0_out_0 += a0 * b0;
              ch_0_out_1 += a0 * b1;

              col_count--;
          }

          if (bias)
          {
              int32_t reduced_multiplier = REDUCE_MULTIPLIER(*out_mult);
              int64_t acc_64 = ch_0_out_0 + *bias;
              ch_0_out_0 = arm_nn_requantize_s64(acc_64, reduced_multiplier, *out_shift);
              acc_64 = ch_0_out_1 + *bias++;
              ch_0_out_1 = arm_nn_requantize_s64(acc_64, reduced_multiplier, *out_shift);
          }
          else
          {
              ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
              ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
          }

          ch_0_out_0 = MAX(ch_0_out_0, activation_min);
          ch_0_out_0 = MIN(ch_0_out_0, activation_max);
          *out_0++ = (int16_t)ch_0_out_0;

          ch_0_out_1 = MAX(ch_0_out_1, activation_min);
          ch_0_out_1 = MIN(ch_0_out_1, activation_max);
          *out_1++ = (int16_t)ch_0_out_1;

          out_mult++;
          out_shift++;
      }

      /* out_0 now points at the results of the second B vector; step over
       * them so the returned pointer is the first unwritten output element. */
      out_0 += output_ch;

      return out_0;
  #else
      /* No implementation for this build configuration: silence
       * unused-parameter warnings and report that nothing was computed. */
      (void)input_a;
      (void)input_b;
      (void)output_ch;
      (void)out_shift;
      (void)out_mult;
      (void)activation_min;
      (void)activation_max;
      (void)num_col_a;
      (void)output_bias;
      (void)out_0;
      /* To be completed */
      return NULL;
  #endif
  }
  229. /**
  230. * @} end of Doxygen group
  231. */