arm_nn_mat_mult_kernel_s8_s16.c 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. /*
  2. * SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
  3. *
  4. * SPDX-License-Identifier: Apache-2.0
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the License); you may
  7. * not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. /* ----------------------------------------------------------------------
  19. * Project: CMSIS NN Library
  20. * Title: arm_nn_mat_mult_kernel_s8_s16.c
  21. * Description: Matrix-multiplication function for convolution
  22. *
  23. * $Date: 29 May 2023
  24. * $Revision: V.2.0.0
  25. *
  26. * Target : Arm(R) M-Profile Architecture
  27. * -------------------------------------------------------------------- */
  28. #include "arm_nnfunctions.h"
  29. #include "arm_nnsupportfunctions.h"
  #if !defined(ARM_MATH_MVEI)
  /*
   * Turn one 32-bit accumulator into an output value: requantize it with the
   * given multiplier/shift, add the output offset, clamp it to
   * [activation_min, activation_max] and narrow the result to int8.
   */
  static inline int8_t mat_mult_s8_s16_requant_clamp(int32_t acc,
                                                     const int32_t mult,
                                                     const int32_t shift,
                                                     const int32_t out_offset,
                                                     const int16_t activation_min,
                                                     const int16_t activation_max)
  {
      acc = arm_nn_requantize(acc, mult, shift);
      acc += out_offset;
      acc = MAX(acc, activation_min);
      acc = MIN(acc, activation_max);
      return (int8_t)acc;
  }
  #endif

  /*
   * Matrix-multiplication kernel for convolution with per-channel requantization.
   *
   * Refer to the header file for the full contract. As implemented here:
   *  - input_a is read as output_ch consecutive rows of num_col_a int8 values.
   *  - input_b supplies two int16 vectors: one starting at input_b and one
   *    starting at input_b + aligned_num_col_a.
   *  - Every row of input_a is accumulated against both vectors, starting from
   *    output_bias[row] when output_bias is non-NULL and from 0 otherwise.
   *  - Each accumulator is requantized with out_mult[row] / out_shift[row],
   *    offset by out_offset, clamped to [activation_min, activation_max] and
   *    stored as int8.
   *  - Results for the first vector are written to out_0[0 .. output_ch - 1],
   *    results for the second vector to out_0[output_ch .. 2 * output_ch - 1].
   *
   * Returns out_0 advanced by 2 * output_ch, or NULL when ARM_MATH_MVEI is
   * defined (that configuration has no implementation here yet).
   */
  int8_t *arm_nn_mat_mult_kernel_s8_s16(const int8_t *input_a,
                                        const int16_t *input_b,
                                        const uint16_t output_ch,
                                        const int32_t *out_shift,
                                        const int32_t *out_mult,
                                        const int32_t out_offset,
                                        const int16_t activation_min,
                                        const int16_t activation_max,
                                        const int32_t num_col_a,
                                        const int32_t aligned_num_col_a,
                                        const int32_t *const output_bias,
                                        int8_t *out_0)
  {
  #if defined(ARM_MATH_MVEI)
      (void)input_a;
      (void)input_b;
      (void)output_ch;
      (void)out_shift;
      (void)out_mult;
      (void)out_offset;
      (void)activation_min;
      (void)activation_max;
      (void)aligned_num_col_a;
      (void)num_col_a;
      (void)output_bias;
      (void)out_0;
      /* To be completed */
      return NULL;
  #else
      const int32_t *bias_ptr = output_bias;
      const int8_t *row_a = input_a;
      /* Results for the second vector of B are stored output_ch bytes after the first */
      int8_t *out_1 = out_0 + output_ch;

      /* Rows of A are consumed two at a time */
      for (uint16_t pairs_left = output_ch / 2; pairs_left != 0; pairs_left--)
      {
          /* Both vectors of B are re-read from their start for every row pair */
          const int16_t *vec_b0 = input_b;
          const int16_t *vec_b1 = input_b + aligned_num_col_a;

          /* The second row of the pair directly follows the first one */
          const int8_t *row_a_next = row_a + num_col_a;

          /* Accumulators start from the bias of channel N and N + 1, or from zero */
          int32_t init_r0 = 0;
          int32_t init_r1 = 0;
          if (bias_ptr)
          {
              init_r0 = bias_ptr[0];
              init_r1 = bias_ptr[1];
              bias_ptr += 2;
          }
          int32_t acc_r0_v0 = init_r0;
          int32_t acc_r0_v1 = init_r0;
          int32_t acc_r1_v0 = init_r1;
          int32_t acc_r1_v1 = init_r1;

  #if defined(ARM_MATH_DSP)
          /*
           * Four columns per iteration: four int8 values of each row are expanded
           * into two packed words, and each packed word is combined with a packed
           * int16 pair of B through SMLAD.
           * NOTE(review): the helper name suggests the packed words come out in a
           * reordered layout, so input_b is presumably stored in the matching
           * order - confirm against the caller.
           */
          int32_t cols_left = num_col_a / 4;
          for (; cols_left != 0; cols_left--)
          {
              int32_t r0_pack0, r0_pack1, r1_pack0, r1_pack1;
              int32_t b0_pack = arm_nn_read_q15x2_ia(&vec_b0);
              int32_t b1_pack = arm_nn_read_q15x2_ia(&vec_b1);

              row_a = read_and_pad_reordered(row_a, &r0_pack0, &r0_pack1);
              row_a_next = read_and_pad_reordered(row_a_next, &r1_pack0, &r1_pack1);

              acc_r0_v0 = SMLAD(r0_pack0, b0_pack, acc_r0_v0);
              acc_r0_v1 = SMLAD(r0_pack0, b1_pack, acc_r0_v1);
              acc_r1_v0 = SMLAD(r1_pack0, b0_pack, acc_r1_v0);
              acc_r1_v1 = SMLAD(r1_pack0, b1_pack, acc_r1_v1);

              b0_pack = arm_nn_read_q15x2_ia(&vec_b0);
              b1_pack = arm_nn_read_q15x2_ia(&vec_b1);

              acc_r0_v0 = SMLAD(r0_pack1, b0_pack, acc_r0_v0);
              acc_r0_v1 = SMLAD(r0_pack1, b1_pack, acc_r0_v1);
              acc_r1_v0 = SMLAD(r1_pack1, b0_pack, acc_r1_v0);
              acc_r1_v1 = SMLAD(r1_pack1, b1_pack, acc_r1_v1);
          }
          /* Columns that did not fill a group of four are handled one by one below */
          cols_left = num_col_a & 0x3;
  #else
          int32_t cols_left = num_col_a;
  #endif
          for (; cols_left != 0; cols_left--)
          {
              const int8_t a_r0 = *row_a++;
              const int16_t b_v0 = *vec_b0++;
              const int8_t a_r1 = *row_a_next++;
              const int16_t b_v1 = *vec_b1++;

              acc_r0_v0 += a_r0 * b_v0;
              acc_r0_v1 += a_r0 * b_v1;
              acc_r1_v0 += a_r1 * b_v0;
              acc_r1_v1 += a_r1 * b_v1;
          }

          /* Channel N: one result per vector of B */
          *out_0++ = mat_mult_s8_s16_requant_clamp(
              acc_r0_v0, out_mult[0], out_shift[0], out_offset, activation_min, activation_max);
          *out_1++ = mat_mult_s8_s16_requant_clamp(
              acc_r0_v1, out_mult[0], out_shift[0], out_offset, activation_min, activation_max);

          /* Channel N + 1 uses the next multiplier/shift pair */
          *out_0++ = mat_mult_s8_s16_requant_clamp(
              acc_r1_v0, out_mult[1], out_shift[1], out_offset, activation_min, activation_max);
          *out_1++ = mat_mult_s8_s16_requant_clamp(
              acc_r1_v1, out_mult[1], out_shift[1], out_offset, activation_min, activation_max);

          out_mult += 2;
          out_shift += 2;

          /* row_a already sits at the end of its own row; step over the second row too */
          row_a += num_col_a;
      }

      /* A trailing unpaired row remains when output_ch is odd */
      if (output_ch & 0x1)
      {
          const int16_t *vec_b0 = input_b;
          const int16_t *vec_b1 = input_b + aligned_num_col_a;

          int32_t acc_v0 = 0;
          int32_t acc_v1 = 0;
          if (bias_ptr)
          {
              acc_v0 = *bias_ptr;
              acc_v1 = *bias_ptr;
          }

  #if defined(ARM_MATH_DSP)
          int32_t cols_left = num_col_a >> 2;
          for (; cols_left != 0; cols_left--)
          {
              int32_t r0_pack0, r0_pack1;
              int32_t b0_pack = arm_nn_read_q15x2_ia(&vec_b0);
              int32_t b1_pack = arm_nn_read_q15x2_ia(&vec_b1);

              row_a = read_and_pad_reordered(row_a, &r0_pack0, &r0_pack1);

              acc_v0 = SMLAD(r0_pack0, b0_pack, acc_v0);
              acc_v1 = SMLAD(r0_pack0, b1_pack, acc_v1);

              b0_pack = arm_nn_read_q15x2_ia(&vec_b0);
              b1_pack = arm_nn_read_q15x2_ia(&vec_b1);

              acc_v0 = SMLAD(r0_pack1, b0_pack, acc_v0);
              acc_v1 = SMLAD(r0_pack1, b1_pack, acc_v1);
          }
          cols_left = num_col_a & 0x3;
  #else
          int32_t cols_left = num_col_a;
  #endif
          for (; cols_left != 0; cols_left--)
          {
              const int8_t a_r0 = *row_a++;
              const int16_t b_v0 = *vec_b0++;
              const int16_t b_v1 = *vec_b1++;

              acc_v0 += a_r0 * b_v0;
              acc_v1 += a_r0 * b_v1;
          }

          *out_0++ = mat_mult_s8_s16_requant_clamp(
              acc_v0, *out_mult, *out_shift, out_offset, activation_min, activation_max);
          *out_1 = mat_mult_s8_s16_requant_clamp(
              acc_v1, *out_mult, *out_shift, out_offset, activation_min, activation_max);
      }

      /* out_0 has walked over the first result vector; step over the second one as well */
      return out_0 + output_ch;
  #endif
  }