  1. /*
  2. * SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
  3. *
  4. * SPDX-License-Identifier: Apache-2.0
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the License); you may
  7. * not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. /* ----------------------------------------------------------------------
  19. * Project: CMSIS NN Library
  20. * Title: arm_nn_vec_mat_mult_t_s16
  21. * Description: s16 vector by matrix (transposed) multiplication
  22. *
  23. * $Date: 5 January 2023
  24. * $Revision: V.2.2.0
  25. *
  26. * Target : Arm(R) M-Profile Architecture
  27. *
  28. * -------------------------------------------------------------------- */
  29. #include "arm_nnsupportfunctions.h"
  30. #define MAX_COL_COUNT (512)
  31. /**
  32. * @ingroup groupSupport
  33. */
  34. /**
  35. * @addtogroup supportFC
  36. * @{
  37. */
  38. /*
  39. * s16 vector(lhs) by matrix (transposed) multiplication
  40. *
  41. * Refer header file for details.
  42. *
  43. */
  44. arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s16(const int16_t *lhs,
  45. const int8_t *rhs,
  46. const int64_t *bias,
  47. int16_t *dst,
  48. const int32_t dst_multiplier,
  49. const int32_t dst_shift,
  50. const int32_t rhs_cols,
  51. const int32_t rhs_rows,
  52. const int32_t activation_min,
  53. const int32_t activation_max)
  54. {
  55. #if defined(ARM_MATH_DSP)
  56. int32_t rhs_cols_fast = rhs_cols;
  57. if (rhs_cols > MAX_COL_COUNT)
  58. {
  59. rhs_cols_fast = MAX_COL_COUNT;
  60. }
  61. #if defined(ARM_MATH_MVEI)
  62. int32_t row_loop_cnt = rhs_rows / 4;
  63. int32_t col_loop_cnt = (rhs_cols_fast + 7) / 8;
  64. for (int32_t i_row_loop_count = 0; i_row_loop_count < row_loop_cnt; i_row_loop_count++)
  65. {
  66. int32_t col_cnt = rhs_cols_fast;
  67. const int16_t *lhs_ptr = lhs;
  68. const int8_t *rhs_ptr_0 = rhs;
  69. const int8_t *rhs_ptr_1 = rhs + rhs_cols;
  70. const int8_t *rhs_ptr_2 = rhs + rhs_cols * 2;
  71. const int8_t *rhs_ptr_3 = rhs + rhs_cols * 3;
  72. int32_t result_0 = 0;
  73. int32_t result_1 = 0;
  74. int32_t result_2 = 0;
  75. int32_t result_3 = 0;
  76. for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
  77. {
  78. mve_pred16_t pred = vctp16q(col_cnt);
  79. col_cnt -= 8;
  80. int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);
  81. int16x8_t rhs_input_0 = vldrbq_z_s16(rhs_ptr_0, pred);
  82. int16x8_t rhs_input_1 = vldrbq_z_s16(rhs_ptr_1, pred);
  83. int16x8_t rhs_input_2 = vldrbq_z_s16(rhs_ptr_2, pred);
  84. int16x8_t rhs_input_3 = vldrbq_z_s16(rhs_ptr_3, pred);
  85. result_0 = vmladavaq_s16(result_0, lhs_input, rhs_input_0);
  86. result_1 = vmladavaq_s16(result_1, lhs_input, rhs_input_1);
  87. result_2 = vmladavaq_s16(result_2, lhs_input, rhs_input_2);
  88. result_3 = vmladavaq_s16(result_3, lhs_input, rhs_input_3);
  89. lhs_ptr += 8;
  90. rhs_ptr_0 += 8;
  91. rhs_ptr_1 += 8;
  92. rhs_ptr_2 += 8;
  93. rhs_ptr_3 += 8;
  94. }
  95. int64_t result_64_0 = result_0;
  96. int64_t result_64_1 = result_1;
  97. int64_t result_64_2 = result_2;
  98. int64_t result_64_3 = result_3;
  99. if (rhs_cols > MAX_COL_COUNT)
  100. {
  101. for (int i_rhs_cols = MAX_COL_COUNT; i_rhs_cols < rhs_cols; i_rhs_cols++)
  102. {
  103. const int16_t lhs_temp = *lhs_ptr++;
  104. result_64_0 += *rhs_ptr_0++ * lhs_temp;
  105. result_64_1 += *rhs_ptr_1++ * lhs_temp;
  106. result_64_2 += *rhs_ptr_2++ * lhs_temp;
  107. result_64_3 += *rhs_ptr_3++ * lhs_temp;
  108. }
  109. }
  110. if (bias)
  111. {
  112. result_64_0 += *bias++;
  113. result_64_1 += *bias++;
  114. result_64_2 += *bias++;
  115. result_64_3 += *bias++;
  116. }
  117. int32_t tmp;
  118. tmp = arm_nn_requantize_s64(result_64_0, dst_multiplier, dst_shift);
  119. tmp = MAX(tmp, activation_min);
  120. tmp = MIN(tmp, activation_max);
  121. *dst++ = (int16_t)tmp;
  122. tmp = 0;
  123. tmp = arm_nn_requantize_s64(result_64_1, dst_multiplier, dst_shift);
  124. tmp = MAX(tmp, activation_min);
  125. tmp = MIN(tmp, activation_max);
  126. *dst++ = (int16_t)tmp;
  127. tmp = 0;
  128. tmp = arm_nn_requantize_s64(result_64_2, dst_multiplier, dst_shift);
  129. tmp = MAX(tmp, activation_min);
  130. tmp = MIN(tmp, activation_max);
  131. *dst++ = (int16_t)tmp;
  132. tmp = 0;
  133. tmp = arm_nn_requantize_s64(result_64_3, dst_multiplier, dst_shift);
  134. tmp = MAX(tmp, activation_min);
  135. tmp = MIN(tmp, activation_max);
  136. *dst++ = (int16_t)tmp;
  137. rhs += 4 * rhs_cols;
  138. }
  139. for (int8_t rows_left = rhs_rows & 0x3; rows_left > 0; rows_left--)
  140. {
  141. int32_t result = 0;
  142. col_loop_cnt = (rhs_cols_fast + 7) / 8;
  143. const int16_t *lhs_ptr = lhs;
  144. const int8_t *rhs_ptr = rhs;
  145. int32_t col_cnt = (int32_t)rhs_cols_fast;
  146. for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
  147. {
  148. mve_pred16_t pred = vctp16q(col_cnt);
  149. col_cnt -= 8;
  150. int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);
  151. int16x8_t rhs_input = vldrbq_z_s16(rhs_ptr, pred);
  152. result = vmladavaq_p_s16(result, lhs_input, rhs_input, pred);
  153. lhs_ptr += 8;
  154. rhs_ptr += 8;
  155. }
  156. int64_t result_64 = result;
  157. if (bias)
  158. {
  159. result_64 += *bias++;
  160. }
  161. if (rhs_cols > MAX_COL_COUNT)
  162. {
  163. for (int i_rhs_cols = MAX_COL_COUNT; i_rhs_cols < rhs_cols; i_rhs_cols++)
  164. {
  165. const int16_t lhs_temp = *lhs_ptr++;
  166. result_64 += *rhs_ptr++ * lhs_temp;
  167. }
  168. }
  169. int32_t tmp = 0;
  170. tmp = arm_nn_requantize_s64(result_64, dst_multiplier, dst_shift);
  171. tmp = MAX(tmp, activation_min);
  172. tmp = MIN(tmp, activation_max);
  173. *dst++ = (int16_t)tmp;
  174. rhs += rhs_cols;
  175. }
  176. #else // ARM_MATH_MVEI
  177. const int32_t row_loop_cnt = rhs_rows / 2;
  178. for (int32_t i = 0; i < row_loop_cnt; i++)
  179. {
  180. int64_t acc_64_0 = 0;
  181. int64_t acc_64_1 = 0;
  182. int32_t acc_0 = 0;
  183. int32_t acc_1 = 0;
  184. const int32_t col_loop_cnt = rhs_cols_fast / 4;
  185. const int16_t *lhs_vec = lhs;
  186. const int8_t *rhs_0 = rhs;
  187. const int8_t *rhs_1 = rhs + rhs_cols;
  188. rhs += 2 * rhs_cols;
  189. for (int j = col_loop_cnt; j != 0; j--)
  190. {
  191. int32_t ker_0, ker_1, vec_part_0, vec_part_1;
  192. vec_part_0 = arm_nn_read_q15x2_ia(&lhs_vec);
  193. vec_part_1 = arm_nn_read_q15x2_ia(&lhs_vec);
  194. rhs_0 = read_and_pad(rhs_0, &ker_0, &ker_1);
  195. acc_0 = SMLAD(ker_0, vec_part_0, acc_0);
  196. acc_0 = SMLAD(ker_1, vec_part_1, acc_0);
  197. rhs_1 = read_and_pad(rhs_1, &ker_0, &ker_1);
  198. acc_1 = SMLAD(ker_0, vec_part_0, acc_1);
  199. acc_1 = SMLAD(ker_1, vec_part_1, acc_1);
  200. }
  201. acc_64_0 += acc_0;
  202. acc_64_1 += acc_1;
  203. for (int k = col_loop_cnt * 4; k < rhs_cols; k++)
  204. {
  205. const int32_t lhs_temp = (*lhs_vec);
  206. lhs_vec++;
  207. acc_64_0 += lhs_temp * (*rhs_0);
  208. rhs_0++;
  209. acc_64_1 += lhs_temp * (*rhs_1);
  210. rhs_1++;
  211. }
  212. if (bias)
  213. {
  214. acc_64_0 += *bias++;
  215. acc_64_1 += *bias++;
  216. }
  217. int32_t tmp;
  218. tmp = arm_nn_requantize_s64(acc_64_0, dst_multiplier, dst_shift);
  219. tmp = MAX(tmp, activation_min);
  220. tmp = MIN(tmp, activation_max);
  221. *dst++ = (int16_t)tmp;
  222. tmp = arm_nn_requantize_s64(acc_64_1, dst_multiplier, dst_shift);
  223. tmp = MAX(tmp, activation_min);
  224. tmp = MIN(tmp, activation_max);
  225. *dst++ = (int16_t)tmp;
  226. }
  227. if (rhs_rows & 0x1)
  228. {
  229. int64_t acc_64_0 = 0;
  230. int32_t acc_0 = 0;
  231. const int32_t col_loop_cnt = rhs_cols_fast / 4;
  232. const int16_t *lhs_vec = lhs;
  233. const int8_t *rhs_0 = rhs;
  234. for (int i = col_loop_cnt; i != 0; i--)
  235. {
  236. int32_t ker_0, ker_1, vec;
  237. rhs_0 = read_and_pad(rhs_0, &ker_0, &ker_1);
  238. vec = arm_nn_read_q15x2_ia(&lhs_vec);
  239. acc_0 = SMLAD(ker_0, vec, acc_0);
  240. vec = arm_nn_read_q15x2_ia(&lhs_vec);
  241. acc_0 = SMLAD(ker_1, vec, acc_0);
  242. }
  243. acc_64_0 += acc_0;
  244. for (int j = col_loop_cnt * 4; j < rhs_cols; j++)
  245. {
  246. const int32_t lhs_temp = (*lhs_vec);
  247. lhs_vec++;
  248. acc_64_0 += lhs_temp * (*rhs_0);
  249. rhs_0++;
  250. }
  251. if (bias)
  252. {
  253. acc_64_0 += *bias++;
  254. }
  255. int32_t tmp;
  256. tmp = arm_nn_requantize_s64(acc_64_0, dst_multiplier, dst_shift);
  257. tmp = MAX(tmp, activation_min);
  258. tmp = MIN(tmp, activation_max);
  259. *dst++ = (int16_t)tmp;
  260. }
  261. #endif // ARM_MATH_MVEI
  262. #else // ARM_MATH_DSP
  263. for (int i_row_loop_cnt = 0; i_row_loop_cnt < rhs_rows; i_row_loop_cnt++)
  264. {
  265. const int16_t *lhs_ptr = lhs;
  266. const int8_t *rhs_ptr_0 = &rhs[0];
  267. int64_t result = 0;
  268. for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
  269. {
  270. const int64_t rhs_value0 = (int8_t)*rhs_ptr_0;
  271. const int64_t lhs_value = *lhs_ptr;
  272. result += lhs_value * rhs_value0;
  273. ++rhs_ptr_0;
  274. ++lhs_ptr;
  275. }
  276. if (bias)
  277. {
  278. result += *bias++;
  279. }
  280. // Quantize down
  281. result = arm_nn_requantize_s64(result, dst_multiplier, dst_shift);
  282. // Clamp the result
  283. result = ((result) > (activation_min) ? (result) : (activation_min));
  284. result = ((result) < (activation_max) ? (result) : (activation_max));
  285. *dst++ = (int16_t)result;
  286. rhs += rhs_cols;
  287. }
  288. #endif // ARM_MATH_DSP
  289. return ARM_CMSIS_NN_SUCCESS;
  290. }
  291. /**
  292. * @} end of Doxygen group
  293. */