arm_nn_mat_mult_kernel_s8_s16.c 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
/*
 * Copyright (C) 2010-2020 Arm Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_kernel_s8_s16.c
 * Description:  Matrix-multiplication function for convolution
 *
 * $Date:        09. October 2020
 * $Revision:    V.1.0.3
 *
 * Target Processor:  Cortex-M cores
 * -------------------------------------------------------------------- */
  28. #include "arm_nnfunctions.h"
  29. #include "arm_nnsupportfunctions.h"
  30. /*
  31. * Matrix-multiplication function for convolution with per-channel requantization.
  32. *
  33. * Refer header file for details.
  34. *
  35. */
  36. q7_t *arm_nn_mat_mult_kernel_s8_s16(const q7_t *input_a,
  37. const q15_t *input_b,
  38. const uint16_t output_ch,
  39. const int32_t *out_shift,
  40. const int32_t *out_mult,
  41. const int32_t out_offset,
  42. const int16_t activation_min,
  43. const int16_t activation_max,
  44. const uint16_t num_col_a,
  45. const int32_t *const output_bias,
  46. q7_t *out_0)
  47. {
  48. #if defined(ARM_MATH_MVEI)
  49. #define ROW_PER_LOOP (4)
  50. #define COL_PER_LOOP (8)
  51. const q7_t *ip_a0_s8 = input_a;
  52. q7_t *out_1 = out_0 + output_ch;
  53. const int32_t *bias = output_bias;
  54. int32_t row_count = output_ch / ROW_PER_LOOP;
  55. while (row_count)
  56. {
  57. const q15_t *ip_b0_s16 = input_b;
  58. const q15_t *ip_b1_s16 = input_b + num_col_a;
  59. const q7_t *ip_a1_s8 = ip_a0_s8 + num_col_a;
  60. const q7_t *ip_a2_s8 = ip_a0_s8 + num_col_a * 2;
  61. const q7_t *ip_a3_s8 = ip_a0_s8 + num_col_a * 3;
  62. q31_t ch_0_out_n = bias[0];
  63. q31_t ch_1_out_n = bias[1];
  64. q31_t ch_2_out_n = bias[2];
  65. q31_t ch_3_out_n = bias[3];
  66. q31_t ch_0_out_n1 = ch_0_out_n;
  67. q31_t ch_1_out_n1 = ch_1_out_n;
  68. q31_t ch_2_out_n1 = ch_2_out_n;
  69. q31_t ch_3_out_n1 = ch_3_out_n;
  70. bias += 4;
  71. int32_t col_count = num_col_a / COL_PER_LOOP;
  72. while (col_count)
  73. {
  74. // Load inputs
  75. const int16x8_t ip_b0 = vld1q_s16(ip_b0_s16);
  76. ip_b0_s16 += COL_PER_LOOP;
  77. const int16x8_t ip_b1 = vld1q_s16(ip_b1_s16);
  78. ip_b1_s16 += COL_PER_LOOP;
  79. // Load filters
  80. const int16x8_t ip_a0 = vldrbq_s16(ip_a0_s8);
  81. ip_a0_s8 += COL_PER_LOOP;
  82. const int16x8_t ip_a1 = vldrbq_s16(ip_a1_s8);
  83. ip_a1_s8 += COL_PER_LOOP;
  84. const int16x8_t ip_a2 = vldrbq_s16(ip_a2_s8);
  85. ip_a2_s8 += COL_PER_LOOP;
  86. const int16x8_t ip_a3 = vldrbq_s16(ip_a3_s8);
  87. ip_a3_s8 += COL_PER_LOOP;
  88. // MAC
  89. ch_0_out_n += vmladavq_s16(ip_b0, ip_a0);
  90. ch_1_out_n += vmladavq_s16(ip_b0, ip_a1);
  91. ch_2_out_n += vmladavq_s16(ip_b0, ip_a2);
  92. ch_3_out_n += vmladavq_s16(ip_b0, ip_a3);
  93. ch_0_out_n1 += vmladavq_s16(ip_b1, ip_a0);
  94. ch_1_out_n1 += vmladavq_s16(ip_b1, ip_a1);
  95. ch_2_out_n1 += vmladavq_s16(ip_b1, ip_a2);
  96. ch_3_out_n1 += vmladavq_s16(ip_b1, ip_a3);
  97. col_count--;
  98. }
  99. /* Handle tail */
  100. col_count = (num_col_a & (COL_PER_LOOP - 1)) - 1;
  101. while (col_count >= 0)
  102. {
  103. const int32_t b0 = ip_b0_s16[col_count];
  104. const int32_t b1 = ip_b1_s16[col_count];
  105. ch_0_out_n += b0 * ip_a0_s8[col_count];
  106. ch_1_out_n += b0 * ip_a1_s8[col_count];
  107. ch_2_out_n += b0 * ip_a2_s8[col_count];
  108. ch_3_out_n += b0 * ip_a3_s8[col_count];
  109. ch_0_out_n1 += b1 * ip_a0_s8[col_count];
  110. ch_1_out_n1 += b1 * ip_a1_s8[col_count];
  111. ch_2_out_n1 += b1 * ip_a2_s8[col_count];
  112. ch_3_out_n1 += b1 * ip_a3_s8[col_count];
  113. col_count--;
  114. }
  115. ip_a0_s8 += (num_col_a & (COL_PER_LOOP - 1));
  116. int32x4_t out_vec_0;
  117. int32x4_t out_vec_1;
  118. out_vec_0[0] = ch_0_out_n;
  119. out_vec_0[1] = ch_1_out_n;
  120. out_vec_0[2] = ch_2_out_n;
  121. out_vec_0[3] = ch_3_out_n;
  122. out_vec_1[0] = ch_0_out_n1;
  123. out_vec_1[1] = ch_1_out_n1;
  124. out_vec_1[2] = ch_2_out_n1;
  125. out_vec_1[3] = ch_3_out_n1;
  126. int32x4_t mult = vldrwq_s32(out_mult);
  127. int32x4_t shift = vldrwq_s32(out_shift);
  128. out_mult += ROW_PER_LOOP;
  129. out_shift += ROW_PER_LOOP;
  130. out_vec_0 = arm_requantize_mve_32x4(out_vec_0, mult, shift);
  131. out_vec_1 = arm_requantize_mve_32x4(out_vec_1, mult, shift);
  132. out_vec_0 = vaddq_n_s32(out_vec_0, out_offset);
  133. out_vec_0 = vmaxq_s32(out_vec_0, vdupq_n_s32(activation_min));
  134. out_vec_0 = vminq_s32(out_vec_0, vdupq_n_s32(activation_max));
  135. vstrbq_s32(out_0, out_vec_0);
  136. out_0 += ROW_PER_LOOP;
  137. out_vec_1 = vaddq_n_s32(out_vec_1, out_offset);
  138. out_vec_1 = vmaxq_s32(out_vec_1, vdupq_n_s32(activation_min));
  139. out_vec_1 = vminq_s32(out_vec_1, vdupq_n_s32(activation_max));
  140. vstrbq_s32(out_1, out_vec_1);
  141. out_1 += ROW_PER_LOOP;
  142. row_count--;
  143. ip_a0_s8 += (num_col_a * 3);
  144. }
  145. row_count = output_ch & (ROW_PER_LOOP - 1);
  146. if (row_count)
  147. {
  148. ip_a0_s8 = input_a + num_col_a * (output_ch & ~3);
  149. const mve_pred16_t p = vctp32q((uint32_t)row_count);
  150. int32x4_t out_vec_0 = vdupq_n_s32(0);
  151. int32x4_t out_vec_1 = vdupq_n_s32(0);
  152. int32x4_t mult_tail;
  153. int32x4_t shift_tail;
  154. for (int i_ch = 0; i_ch < row_count; i_ch++)
  155. {
  156. int32_t output_0 = bias[i_ch];
  157. int32_t output_1 = bias[i_ch];
  158. const q15_t *ip_b0_s16 = input_b;
  159. const q15_t *ip_b1_s16 = input_b + num_col_a;
  160. for (int i_idx = 0; i_idx < num_col_a; i_idx++)
  161. {
  162. output_0 += ip_b0_s16[i_idx] * ip_a0_s8[i_idx];
  163. output_1 += ip_b1_s16[i_idx] * ip_a0_s8[i_idx];
  164. }
  165. ip_a0_s8 += num_col_a;
  166. out_vec_0[i_ch] = output_0;
  167. out_vec_1[i_ch] = output_1;
  168. mult_tail[i_ch] = out_mult[i_ch];
  169. shift_tail[i_ch] = out_shift[i_ch];
  170. }
  171. out_vec_0 = arm_requantize_mve_32x4(out_vec_0, mult_tail, shift_tail);
  172. out_vec_1 = arm_requantize_mve_32x4(out_vec_1, mult_tail, shift_tail);
  173. out_vec_0 = vaddq_n_s32(out_vec_0, out_offset);
  174. out_vec_0 = vmaxq_s32(out_vec_0, vdupq_n_s32(activation_min));
  175. out_vec_0 = vminq_s32(out_vec_0, vdupq_n_s32(activation_max));
  176. vstrbq_p_s32(out_0, out_vec_0, p);
  177. out_vec_1 = vaddq_n_s32(out_vec_1, out_offset);
  178. out_vec_1 = vmaxq_s32(out_vec_1, vdupq_n_s32(activation_min));
  179. out_vec_1 = vminq_s32(out_vec_1, vdupq_n_s32(activation_max));
  180. vstrbq_p_s32(out_1, out_vec_1, p);
  181. out_1 += row_count;
  182. }
  183. return out_1;
  184. #elif defined(ARM_MATH_DSP)
  185. /* set up the second output pointers */
  186. q7_t *out_1 = out_0 + output_ch;
  187. const int32_t *bias = output_bias;
  188. uint16_t row_count = output_ch / 2;
  189. const q7_t *ip_a0 = input_a;
  190. /* this loop over rows in A */
  191. while (row_count)
  192. {
  193. /* setup pointers for B */
  194. const q15_t *ip_b0 = input_b;
  195. const q15_t *ip_b1 = ip_b0 + num_col_a;
  196. /* align the second pointer for A */
  197. const q7_t *ip_a1 = ip_a0 + num_col_a;
  198. /* Init accumulator with bias for channel N and N + 1 */
  199. q31_t ch_0_out_0 = *bias;
  200. q31_t ch_0_out_1 = *bias++;
  201. q31_t ch_1_out_0 = *bias;
  202. q31_t ch_1_out_1 = *bias++;
  203. uint16_t col_count = num_col_a / 4;
  204. /* accumulate over the vector */
  205. while (col_count)
  206. {
  207. q31_t a01, a02, a11, a12;
  208. q31_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
  209. q31_t b1 = arm_nn_read_q15x2_ia(&ip_b1);
  210. ip_a0 = read_and_pad(ip_a0, &a01, &a02);
  211. ip_a1 = read_and_pad(ip_a1, &a11, &a12);
  212. ch_0_out_0 = __SMLAD(a01, b0, ch_0_out_0);
  213. ch_0_out_1 = __SMLAD(a01, b1, ch_0_out_1);
  214. ch_1_out_0 = __SMLAD(a11, b0, ch_1_out_0);
  215. ch_1_out_1 = __SMLAD(a11, b1, ch_1_out_1);
  216. b0 = arm_nn_read_q15x2_ia(&ip_b0);
  217. b1 = arm_nn_read_q15x2_ia(&ip_b1);
  218. ch_0_out_0 = __SMLAD(a02, b0, ch_0_out_0);
  219. ch_0_out_1 = __SMLAD(a02, b1, ch_0_out_1);
  220. ch_1_out_0 = __SMLAD(a12, b0, ch_1_out_0);
  221. ch_1_out_1 = __SMLAD(a12, b1, ch_1_out_1);
  222. col_count--;
  223. } /* while over col_count */
  224. col_count = num_col_a & 0x3;
  225. while (col_count)
  226. {
  227. q7_t a0 = *ip_a0++;
  228. q15_t b0 = *ip_b0++;
  229. q7_t a1 = *ip_a1++;
  230. q15_t b1 = *ip_b1++;
  231. ch_0_out_0 += a0 * b0;
  232. ch_0_out_1 += a0 * b1;
  233. ch_1_out_0 += a1 * b0;
  234. ch_1_out_1 += a1 * b1;
  235. col_count--;
  236. } /* while over col_count */
  237. ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
  238. ch_0_out_0 += out_offset;
  239. ch_0_out_0 = MAX(ch_0_out_0, activation_min);
  240. ch_0_out_0 = MIN(ch_0_out_0, activation_max);
  241. *out_0++ = (q7_t)ch_0_out_0;
  242. ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
  243. ch_0_out_1 += out_offset;
  244. ch_0_out_1 = MAX(ch_0_out_1, activation_min);
  245. ch_0_out_1 = MIN(ch_0_out_1, activation_max);
  246. *out_1++ = (q7_t)ch_0_out_1;
  247. out_mult++;
  248. out_shift++;
  249. ch_1_out_0 = arm_nn_requantize(ch_1_out_0, *out_mult, *out_shift);
  250. ch_1_out_0 += out_offset;
  251. ch_1_out_0 = MAX(ch_1_out_0, activation_min);
  252. ch_1_out_0 = MIN(ch_1_out_0, activation_max);
  253. *out_0++ = (q7_t)ch_1_out_0;
  254. ch_1_out_1 = arm_nn_requantize(ch_1_out_1, *out_mult, *out_shift);
  255. ch_1_out_1 += out_offset;
  256. ch_1_out_1 = MAX(ch_1_out_1, activation_min);
  257. ch_1_out_1 = MIN(ch_1_out_1, activation_max);
  258. *out_1++ = (q7_t)ch_1_out_1;
  259. out_mult++;
  260. out_shift++;
  261. /* skip row */
  262. ip_a0 += num_col_a;
  263. row_count--;
  264. }
  265. /* compute the last odd numbered row if any */
  266. if (output_ch & 0x1)
  267. {
  268. /* setup pointers for B */
  269. const q15_t *ip_b0 = input_b;
  270. const q15_t *ip_b1 = ip_b0 + num_col_a;
  271. /* load the bias */
  272. q31_t ch_0_out_0 = *bias;
  273. q31_t ch_0_out_1 = *bias++;
  274. uint16_t col_count = num_col_a >> 2;
  275. while (col_count)
  276. {
  277. q31_t a01, a02;
  278. q31_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
  279. q31_t b1 = arm_nn_read_q15x2_ia(&ip_b1);
  280. ip_a0 = read_and_pad(ip_a0, &a01, &a02);
  281. ch_0_out_0 = __SMLAD(a01, b0, ch_0_out_0);
  282. ch_0_out_1 = __SMLAD(a01, b1, ch_0_out_1);
  283. b0 = arm_nn_read_q15x2_ia(&ip_b0);
  284. b1 = arm_nn_read_q15x2_ia(&ip_b1);
  285. ch_0_out_0 = __SMLAD(a02, b0, ch_0_out_0);
  286. ch_0_out_1 = __SMLAD(a02, b1, ch_0_out_1);
  287. col_count--;
  288. }
  289. col_count = num_col_a & 0x3;
  290. while (col_count)
  291. {
  292. q7_t a0 = *ip_a0++;
  293. q15_t b0 = *ip_b0++;
  294. q15_t b1 = *ip_b1++;
  295. ch_0_out_0 += a0 * b0;
  296. ch_0_out_1 += a0 * b1;
  297. col_count--;
  298. }
  299. ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
  300. ch_0_out_0 += out_offset;
  301. ch_0_out_0 = MAX(ch_0_out_0, activation_min);
  302. ch_0_out_0 = MIN(ch_0_out_0, activation_max);
  303. *out_0++ = (q7_t)ch_0_out_0;
  304. ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
  305. ch_0_out_1 += out_offset;
  306. ch_0_out_1 = MAX(ch_0_out_1, activation_min);
  307. ch_0_out_1 = MIN(ch_0_out_1, activation_max);
  308. *out_1++ = (q7_t)ch_0_out_1;
  309. out_mult++;
  310. out_shift++;
  311. }
  312. out_0 += output_ch;
  313. /* return the new output pointer with offset */
  314. return out_0;
  315. #else
  316. (void)input_a;
  317. (void)input_b;
  318. (void)output_ch;
  319. (void)out_shift;
  320. (void)out_mult;
  321. (void)out_offset;
  322. (void)activation_min;
  323. (void)activation_max;
  324. (void)num_col_a;
  325. (void)output_bias;
  326. (void)out_0;
  327. /* To be completed */
  328. return NULL;
  329. #endif
  330. }