arm_nn_mat_mult_kernel_s8_s16.c 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. /*
  2. * Copyright (C) 2010-2021 Arm Limited or its affiliates. All rights reserved.
  3. *
  4. * SPDX-License-Identifier: Apache-2.0
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the License); you may
  7. * not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. /* ----------------------------------------------------------------------
  19. * Project: CMSIS NN Library
  20. * Title: arm_nn_mat_mult_kernel_s8_s16.c
  21. * Description: Matrix-multiplication function for convolution
  22. *
  23. * $Date: 09. October 2020
  24. * $Revision: V.1.0.3
  25. *
  26. * Target Processor: Cortex-M cores
  27. * -------------------------------------------------------------------- */
  28. #include "arm_nnfunctions.h"
  29. #include "arm_nnsupportfunctions.h"
  30. /*
  31. * Matrix-multiplication function for convolution with per-channel requantization.
  32. *
  33. * Refer header file for details.
  34. *
  35. */
  36. q7_t *arm_nn_mat_mult_kernel_s8_s16(const q7_t *input_a,
  37. const q15_t *input_b,
  38. const uint16_t output_ch,
  39. const int32_t *out_shift,
  40. const int32_t *out_mult,
  41. const int32_t out_offset,
  42. const int16_t activation_min,
  43. const int16_t activation_max,
  44. const uint16_t num_col_a,
  45. const int32_t *const output_bias,
  46. q7_t *out_0)
  47. {
  48. #if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
  49. /* set up the second output pointers */
  50. q7_t *out_1 = out_0 + output_ch;
  51. const int32_t *bias = output_bias;
  52. uint16_t row_count = output_ch / 2;
  53. const q7_t *ip_a0 = input_a;
  54. /* this loop over rows in A */
  55. while (row_count)
  56. {
  57. /* setup pointers for B */
  58. const q15_t *ip_b0 = input_b;
  59. const q15_t *ip_b1 = ip_b0 + num_col_a;
  60. /* align the second pointer for A */
  61. const q7_t *ip_a1 = ip_a0 + num_col_a;
  62. q31_t ch_0_out_0 = 0;
  63. q31_t ch_0_out_1 = 0;
  64. q31_t ch_1_out_0 = 0;
  65. q31_t ch_1_out_1 = 0;
  66. /* Init accumulator with bias for channel N and N + 1 */
  67. if (bias)
  68. {
  69. ch_0_out_0 = *bias;
  70. ch_0_out_1 = *bias++;
  71. ch_1_out_0 = *bias;
  72. ch_1_out_1 = *bias++;
  73. }
  74. uint16_t col_count = num_col_a / 4;
  75. /* accumulate over the vector */
  76. while (col_count)
  77. {
  78. q31_t a01, a02, a11, a12;
  79. q31_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
  80. q31_t b1 = arm_nn_read_q15x2_ia(&ip_b1);
  81. ip_a0 = read_and_pad(ip_a0, &a01, &a02);
  82. ip_a1 = read_and_pad(ip_a1, &a11, &a12);
  83. ch_0_out_0 = __SMLAD(a01, b0, ch_0_out_0);
  84. ch_0_out_1 = __SMLAD(a01, b1, ch_0_out_1);
  85. ch_1_out_0 = __SMLAD(a11, b0, ch_1_out_0);
  86. ch_1_out_1 = __SMLAD(a11, b1, ch_1_out_1);
  87. b0 = arm_nn_read_q15x2_ia(&ip_b0);
  88. b1 = arm_nn_read_q15x2_ia(&ip_b1);
  89. ch_0_out_0 = __SMLAD(a02, b0, ch_0_out_0);
  90. ch_0_out_1 = __SMLAD(a02, b1, ch_0_out_1);
  91. ch_1_out_0 = __SMLAD(a12, b0, ch_1_out_0);
  92. ch_1_out_1 = __SMLAD(a12, b1, ch_1_out_1);
  93. col_count--;
  94. } /* while over col_count */
  95. col_count = num_col_a & 0x3;
  96. while (col_count)
  97. {
  98. q7_t a0 = *ip_a0++;
  99. q15_t b0 = *ip_b0++;
  100. q7_t a1 = *ip_a1++;
  101. q15_t b1 = *ip_b1++;
  102. ch_0_out_0 += a0 * b0;
  103. ch_0_out_1 += a0 * b1;
  104. ch_1_out_0 += a1 * b0;
  105. ch_1_out_1 += a1 * b1;
  106. col_count--;
  107. } /* while over col_count */
  108. ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
  109. ch_0_out_0 += out_offset;
  110. ch_0_out_0 = MAX(ch_0_out_0, activation_min);
  111. ch_0_out_0 = MIN(ch_0_out_0, activation_max);
  112. *out_0++ = (q7_t)ch_0_out_0;
  113. ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
  114. ch_0_out_1 += out_offset;
  115. ch_0_out_1 = MAX(ch_0_out_1, activation_min);
  116. ch_0_out_1 = MIN(ch_0_out_1, activation_max);
  117. *out_1++ = (q7_t)ch_0_out_1;
  118. out_mult++;
  119. out_shift++;
  120. ch_1_out_0 = arm_nn_requantize(ch_1_out_0, *out_mult, *out_shift);
  121. ch_1_out_0 += out_offset;
  122. ch_1_out_0 = MAX(ch_1_out_0, activation_min);
  123. ch_1_out_0 = MIN(ch_1_out_0, activation_max);
  124. *out_0++ = (q7_t)ch_1_out_0;
  125. ch_1_out_1 = arm_nn_requantize(ch_1_out_1, *out_mult, *out_shift);
  126. ch_1_out_1 += out_offset;
  127. ch_1_out_1 = MAX(ch_1_out_1, activation_min);
  128. ch_1_out_1 = MIN(ch_1_out_1, activation_max);
  129. *out_1++ = (q7_t)ch_1_out_1;
  130. out_mult++;
  131. out_shift++;
  132. /* skip row */
  133. ip_a0 += num_col_a;
  134. row_count--;
  135. }
  136. /* compute the last odd numbered row if any */
  137. if (output_ch & 0x1)
  138. {
  139. /* setup pointers for B */
  140. const q15_t *ip_b0 = input_b;
  141. const q15_t *ip_b1 = ip_b0 + num_col_a;
  142. q31_t ch_0_out_0 = 0;
  143. q31_t ch_0_out_1 = 0;
  144. /* load the bias */
  145. if (bias)
  146. {
  147. ch_0_out_0 = *bias;
  148. ch_0_out_1 = *bias++;
  149. }
  150. uint16_t col_count = num_col_a >> 2;
  151. while (col_count)
  152. {
  153. q31_t a01, a02;
  154. q31_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
  155. q31_t b1 = arm_nn_read_q15x2_ia(&ip_b1);
  156. ip_a0 = read_and_pad(ip_a0, &a01, &a02);
  157. ch_0_out_0 = __SMLAD(a01, b0, ch_0_out_0);
  158. ch_0_out_1 = __SMLAD(a01, b1, ch_0_out_1);
  159. b0 = arm_nn_read_q15x2_ia(&ip_b0);
  160. b1 = arm_nn_read_q15x2_ia(&ip_b1);
  161. ch_0_out_0 = __SMLAD(a02, b0, ch_0_out_0);
  162. ch_0_out_1 = __SMLAD(a02, b1, ch_0_out_1);
  163. col_count--;
  164. }
  165. col_count = num_col_a & 0x3;
  166. while (col_count)
  167. {
  168. q7_t a0 = *ip_a0++;
  169. q15_t b0 = *ip_b0++;
  170. q15_t b1 = *ip_b1++;
  171. ch_0_out_0 += a0 * b0;
  172. ch_0_out_1 += a0 * b1;
  173. col_count--;
  174. }
  175. ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
  176. ch_0_out_0 += out_offset;
  177. ch_0_out_0 = MAX(ch_0_out_0, activation_min);
  178. ch_0_out_0 = MIN(ch_0_out_0, activation_max);
  179. *out_0++ = (q7_t)ch_0_out_0;
  180. ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
  181. ch_0_out_1 += out_offset;
  182. ch_0_out_1 = MAX(ch_0_out_1, activation_min);
  183. ch_0_out_1 = MIN(ch_0_out_1, activation_max);
  184. *out_1++ = (q7_t)ch_0_out_1;
  185. out_mult++;
  186. out_shift++;
  187. }
  188. out_0 += output_ch;
  189. /* return the new output pointer with offset */
  190. return out_0;
  191. #else
  192. (void)input_a;
  193. (void)input_b;
  194. (void)output_ch;
  195. (void)out_shift;
  196. (void)out_mult;
  197. (void)out_offset;
  198. (void)activation_min;
  199. (void)activation_max;
  200. (void)num_col_a;
  201. (void)output_bias;
  202. (void)out_0;
  203. /* To be completed */
  204. return NULL;
  205. #endif
  206. }