/*
 * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_vec_mat_mult_t_s8
 * Description:  s8 vector by matrix (transposed) multiplication
 *
 * $Date:        14 Feb 2023
 * $Revision:    V.6.0.0
 *
 * Target : Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @defgroup supportFC Fully Connected
 *
 * Support functions for Fully Connected
 *
 */

/**
 * @addtogroup supportFC
 * @{
 */

/*
 * s8 vector(lhs) by matrix (transposed) multiplication
 *
 * Refer header file for details.
 *
 */
#if defined(ARM_MATH_DSP) && !defined(__ARMCC_VERSION) && !defined(__ICCARM__)
    #pragma GCC optimize("unroll-loops")
#endif
  52. arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s8(const int8_t *lhs,
  53. const int8_t *rhs,
  54. const int32_t *kernel_sum,
  55. const int32_t *bias,
  56. int8_t *dst,
  57. const int32_t lhs_offset,
  58. const int32_t dst_offset,
  59. const int32_t dst_multiplier,
  60. const int32_t dst_shift,
  61. const int32_t rhs_cols,
  62. const int32_t rhs_rows,
  63. const int32_t activation_min,
  64. const int32_t activation_max,
  65. const int32_t address_offset,
  66. const int32_t rhs_offset)
  67. {
  68. if (rhs_offset)
  69. {
  70. #if defined(ARM_MATH_MVEI)
  71. const int32_t row_loop_cnt = rhs_rows / 4;
  72. const uint32x4_t address_offset_array = {0, address_offset, address_offset * 2, address_offset * 3};
  73. for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
  74. {
  75. int32_t acc_0 = 0;
  76. int32_t acc_1 = 0;
  77. int32_t acc_2 = 0;
  78. int32_t acc_3 = 0;
  79. const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
  80. const int8_t *lhs_vec = lhs;
  81. const int8_t *rhs_0_ptr = rhs;
  82. const int8_t *rhs_1_ptr = rhs + rhs_cols;
  83. const int8_t *rhs_2_ptr = rhs + 2 * rhs_cols;
  84. const int8_t *rhs_3_ptr = rhs + 3 * rhs_cols;
  85. int32_t lhs_sum = 0;
  86. if (bias)
  87. {
  88. acc_0 = *bias++;
  89. acc_1 = *bias++;
  90. acc_2 = *bias++;
  91. acc_3 = *bias++;
  92. }
  93. uint32_t col_cnt = (uint32_t)rhs_cols;
  94. for (int32_t i = 0; i < col_loop_cnt; i++)
  95. {
  96. mve_pred16_t p = vctp8q(col_cnt);
  97. col_cnt -= 16;
  98. const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
  99. lhs_sum = vaddvaq_s8(lhs_sum, input);
  100. const int8x16_t ker_0 = vldrbq_z_s8(rhs_0_ptr, p);
  101. acc_0 = vmladavaq_s8(acc_0, ker_0, input);
  102. const int8x16_t ker_1 = vldrbq_z_s8(rhs_1_ptr, p);
  103. acc_1 = vmladavaq_s8(acc_1, ker_1, input);
  104. const int8x16_t ker_2 = vldrbq_z_s8(rhs_2_ptr, p);
  105. acc_2 = vmladavaq_s8(acc_2, ker_2, input);
  106. const int8x16_t ker_3 = vldrbq_z_s8(rhs_3_ptr, p);
  107. acc_3 = vmladavaq_s8(acc_3, ker_3, input);
  108. lhs_vec += 16;
  109. rhs_0_ptr += 16;
  110. rhs_1_ptr += 16;
  111. rhs_2_ptr += 16;
  112. rhs_3_ptr += 16;
  113. }
  114. rhs += 4 * rhs_cols;
  115. int32x4_t acc = {acc_0, acc_1, acc_2, acc_3};
  116. const int32x4_t rhs_sum = {kernel_sum[0], kernel_sum[1], kernel_sum[2], kernel_sum[3]};
  117. acc += vdupq_n_s32(lhs_offset) * rhs_sum;
  118. kernel_sum += 4;
  119. acc += vdupq_n_s32(rhs_offset) * vdupq_n_s32(lhs_sum);
  120. acc += vdupq_n_s32(rhs_offset * lhs_offset) * vdupq_n_s32(rhs_cols);
  121. acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
  122. acc = vaddq_s32(acc, vdupq_n_s32(dst_offset));
  123. acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
  124. acc = vminq_s32(acc, vdupq_n_s32(activation_max));
  125. vstrbq_scatter_offset_s32(dst, address_offset_array, acc);
  126. dst += 4 * address_offset;
  127. }
  128. const int loop_cnt = rhs_rows % 4;
  129. for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
  130. {
  131. int32_t acc_0 = 0;
  132. const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
  133. const int8_t *lhs_vec = lhs;
  134. const int8_t *rhs_ptr = rhs;
  135. int32_t lhs_sum = 0;
  136. uint32_t col_cnt = (uint32_t)rhs_cols;
  137. for (int32_t i = 0; i < col_loop_cnt; i++)
  138. {
  139. mve_pred16_t p = vctp8q(col_cnt);
  140. col_cnt -= 16;
  141. const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
  142. lhs_sum = vaddvaq_s8(lhs_sum, input);
  143. const int8x16_t ker_0 = vldrbq_z_s8(rhs_ptr, p);
  144. acc_0 = vmladavaq_s8(acc_0, ker_0, input);
  145. lhs_vec += 16;
  146. rhs_ptr += 16;
  147. }
  148. rhs += rhs_cols;
  149. if (bias)
  150. {
  151. acc_0 += *bias;
  152. bias++;
  153. }
  154. const int32_t rhs_sum = kernel_sum[i_row_loop_cnt];
  155. acc_0 += rhs_sum * lhs_offset;
  156. acc_0 += lhs_sum * rhs_offset;
  157. acc_0 += rhs_cols * lhs_offset * rhs_offset;
  158. acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
  159. acc_0 += dst_offset;
  160. // Clamp the result
  161. acc_0 = MAX(acc_0, activation_min);
  162. *dst = MIN(acc_0, activation_max);
  163. dst += address_offset;
  164. }
  165. #elif defined(ARM_MATH_DSP)
  166. (void)kernel_sum;
  167. const int32_t row_loop_cnt = rhs_rows / 2;
  168. const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
  169. const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
  170. const int16_t rhs_offset_s16 = (int16_t)rhs_offset;
  171. const uint32_t rhs_offset_s16x2 = PKHBT(rhs_offset_s16, rhs_offset_s16, 16);
  172. for (int32_t i = 0; i < row_loop_cnt; i++)
  173. {
  174. int32_t acc_0 = 0;
  175. int32_t acc_1 = 0;
  176. if (bias)
  177. {
  178. acc_0 = *bias++;
  179. acc_1 = *bias++;
  180. }
  181. const int32_t col_loop_cnt = rhs_cols / 4;
  182. const int8_t *lhs_vec = lhs;
  183. const int8_t *rhs_0_ptr = rhs;
  184. const int8_t *rhs_1_ptr = rhs + rhs_cols;
  185. rhs += 2 * rhs_cols;
  186. for (int32_t j = col_loop_cnt; j != 0; j--)
  187. {
  188. int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
  189. int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
  190. vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
  191. int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_0_ptr);
  192. int32_t ker_1 = SXTAB16_RORn(rhs_offset_s16x2, (uint32_t)ker_0, 8);
  193. ker_0 = SXTAB16(rhs_offset_s16x2, ker_0);
  194. acc_0 = SMLAD(ker_1, vec_1, acc_0);
  195. acc_0 = SMLAD(ker_0, vec_0, acc_0);
  196. ker_0 = arm_nn_read_s8x4_ia(&rhs_1_ptr);
  197. ker_1 = SXTAB16_RORn(rhs_offset_s16x2, (uint32_t)ker_0, 8);
  198. ker_0 = SXTAB16(rhs_offset_s16x2, ker_0);
  199. acc_1 = SMLAD(ker_1, vec_1, acc_1);
  200. acc_1 = SMLAD(ker_0, vec_0, acc_1);
  201. }
  202. for (int32_t k = col_loop_cnt * 4; k < rhs_cols; k++)
  203. {
  204. const int32_t lhs_temp = (*lhs_vec + lhs_offset);
  205. lhs_vec++;
  206. acc_0 += lhs_temp * (*rhs_0_ptr + rhs_offset);
  207. rhs_0_ptr++;
  208. acc_1 += lhs_temp * (*rhs_1_ptr + rhs_offset);
  209. rhs_1_ptr++;
  210. }
  211. acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
  212. acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift);
  213. // Add offset
  214. acc_0 += dst_offset;
  215. acc_1 += dst_offset;
  216. // Clamp the result
  217. acc_0 = MAX(acc_0, activation_min);
  218. acc_0 = MIN(acc_0, activation_max);
  219. acc_1 = MAX(acc_1, activation_min);
  220. acc_1 = MIN(acc_1, activation_max);
  221. *dst = (int8_t)acc_0;
  222. *(dst + address_offset) = (int8_t)acc_1;
  223. dst += 2 * address_offset;
  224. }
  225. if (rhs_rows & 0x1)
  226. {
  227. int32_t acc_0 = 0;
  228. if (bias)
  229. {
  230. acc_0 = *bias++;
  231. }
  232. const int32_t col_loop_cnt = rhs_cols / 4;
  233. const int8_t *lhs_vec = lhs;
  234. const int8_t *rhs_ptr = rhs;
  235. for (int32_t i = col_loop_cnt; i != 0; i--)
  236. {
  237. int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
  238. int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
  239. vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
  240. int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_ptr);
  241. int32_t ker_1 = SXTAB16_RORn(rhs_offset_s16x2, (uint32_t)ker_0, 8);
  242. ker_0 = SXTAB16(rhs_offset_s16x2, ker_0);
  243. acc_0 = SMLAD(ker_1, vec_1, acc_0);
  244. acc_0 = SMLAD(ker_0, vec_0, acc_0);
  245. }
  246. for (int32_t j = col_loop_cnt * 4; j < rhs_cols; j++)
  247. {
  248. const int32_t lhs_temp = (*lhs_vec + lhs_offset);
  249. lhs_vec++;
  250. acc_0 += lhs_temp * (*rhs_ptr + rhs_offset);
  251. rhs_ptr++;
  252. }
  253. acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
  254. // Add offset
  255. acc_0 += dst_offset;
  256. // Clamp the result
  257. acc_0 = MAX(acc_0, activation_min);
  258. acc_0 = MIN(acc_0, activation_max);
  259. *dst = (int8_t)acc_0;
  260. dst += address_offset;
  261. }
  262. #else
  263. (void)kernel_sum;
  264. const int32_t row_loop_cnt = rhs_rows / 3;
  265. for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
  266. {
  267. const int8_t *lhs_ptr = lhs;
  268. const int8_t *rhs_ptr_0 = &rhs[0];
  269. const int8_t *rhs_ptr_1 = &rhs[rhs_cols];
  270. const int8_t *rhs_ptr_2 = &rhs[rhs_cols * 2];
  271. int32_t res00 = 0;
  272. int32_t res01 = 0;
  273. int32_t res02 = 0;
  274. if (bias)
  275. {
  276. res00 = *bias++;
  277. res01 = *bias++;
  278. res02 = *bias++;
  279. }
  280. for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
  281. {
  282. const int32_t rhs_value0 = (int8_t)*rhs_ptr_0 + rhs_offset;
  283. const int32_t rhs_value1 = (int8_t)*rhs_ptr_1 + rhs_offset;
  284. const int32_t rhs_value2 = (int8_t)*rhs_ptr_2 + rhs_offset;
  285. const int32_t lhs_value = (int8_t)*lhs_ptr + lhs_offset;
  286. res00 += lhs_value * rhs_value0;
  287. res01 += lhs_value * rhs_value1;
  288. res02 += lhs_value * rhs_value2;
  289. ++rhs_ptr_0;
  290. ++rhs_ptr_1;
  291. ++rhs_ptr_2;
  292. ++lhs_ptr;
  293. }
  294. // Quantize down
  295. res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
  296. res01 = arm_nn_requantize(res01, dst_multiplier, dst_shift);
  297. res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift);
  298. // Add offset
  299. res00 += dst_offset;
  300. res01 += dst_offset;
  301. res02 += dst_offset;
  302. // Clamp the result
  303. res00 = MAX(res00, activation_min);
  304. res00 = MIN(res00, activation_max);
  305. res01 = MAX(res01, activation_min);
  306. res01 = MIN(res01, activation_max);
  307. res02 = MAX(res02, activation_min);
  308. res02 = MIN(res02, activation_max);
  309. *dst = (int8_t)res00;
  310. *(dst + address_offset) = (int8_t)res01;
  311. *(dst + 2 * address_offset) = (int8_t)res02;
  312. dst += 3 * address_offset;
  313. rhs += 3 * rhs_cols;
  314. }
  315. const int loop_cnt = rhs_rows % 3;
  316. for (int32_t i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
  317. {
  318. const int8_t *lhs_ptr = &lhs[0];
  319. const int8_t *rhs_ptr = &rhs[0];
  320. int32_t res00 = 0;
  321. if (bias)
  322. {
  323. res00 = *bias++;
  324. }
  325. for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
  326. {
  327. int32_t rhs_value0 = (int8_t)rhs_ptr[0] + rhs_offset;
  328. int32_t lhs_value = (int8_t)lhs_ptr[0] + lhs_offset;
  329. res00 += lhs_value * rhs_value0;
  330. ++rhs_ptr;
  331. ++lhs_ptr;
  332. }
  333. // Quantize down
  334. res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
  335. // Add offset
  336. res00 += dst_offset;
  337. // Clamp the result
  338. res00 = MAX(res00, activation_min);
  339. res00 = MIN(res00, activation_max);
  340. *dst = (int8_t)res00;
  341. dst += address_offset;
  342. rhs += rhs_cols;
  343. }
  344. #endif
  345. }
  346. else
  347. {
  348. #if defined(ARM_MATH_MVEI)
  349. const int32_t row_loop_cnt = rhs_rows / 4;
  350. const uint32x4_t address_offset_array = {0, address_offset, address_offset * 2, address_offset * 3};
  351. for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
  352. {
  353. int32_t acc_0 = 0;
  354. int32_t acc_1 = 0;
  355. int32_t acc_2 = 0;
  356. int32_t acc_3 = 0;
  357. const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
  358. const int8_t *lhs_vec = lhs;
  359. const int8_t *rhs_0_ptr = rhs;
  360. const int8_t *rhs_1_ptr = rhs + rhs_cols;
  361. const int8_t *rhs_2_ptr = rhs + 2 * rhs_cols;
  362. const int8_t *rhs_3_ptr = rhs + 3 * rhs_cols;
  363. if (bias)
  364. {
  365. acc_0 = *bias++;
  366. acc_1 = *bias++;
  367. acc_2 = *bias++;
  368. acc_3 = *bias++;
  369. }
  370. uint32_t col_cnt = (uint32_t)rhs_cols;
  371. for (int32_t i = 0; i < col_loop_cnt; i++)
  372. {
  373. mve_pred16_t p = vctp8q(col_cnt);
  374. col_cnt -= 16;
  375. const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
  376. const int8x16_t ker_0 = vldrbq_z_s8(rhs_0_ptr, p);
  377. acc_0 = vmladavaq_s8(acc_0, ker_0, input);
  378. const int8x16_t ker_1 = vldrbq_z_s8(rhs_1_ptr, p);
  379. acc_1 = vmladavaq_s8(acc_1, ker_1, input);
  380. const int8x16_t ker_2 = vldrbq_z_s8(rhs_2_ptr, p);
  381. acc_2 = vmladavaq_s8(acc_2, ker_2, input);
  382. const int8x16_t ker_3 = vldrbq_z_s8(rhs_3_ptr, p);
  383. acc_3 = vmladavaq_s8(acc_3, ker_3, input);
  384. lhs_vec += 16;
  385. rhs_0_ptr += 16;
  386. rhs_1_ptr += 16;
  387. rhs_2_ptr += 16;
  388. rhs_3_ptr += 16;
  389. }
  390. rhs += 4 * rhs_cols;
  391. int32x4_t acc = {acc_0, acc_1, acc_2, acc_3};
  392. const int32x4_t rhs_sum = {kernel_sum[0], kernel_sum[1], kernel_sum[2], kernel_sum[3]};
  393. acc += vdupq_n_s32(lhs_offset) * rhs_sum;
  394. kernel_sum += 4;
  395. acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
  396. acc = vaddq_s32(acc, vdupq_n_s32(dst_offset));
  397. acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
  398. acc = vminq_s32(acc, vdupq_n_s32(activation_max));
  399. vstrbq_scatter_offset_s32(dst, address_offset_array, acc);
  400. dst += 4 * address_offset;
  401. }
  402. const int loop_cnt = rhs_rows % 4;
  403. for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
  404. {
  405. int32_t acc_0 = 0;
  406. const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
  407. const int8_t *lhs_vec = lhs;
  408. const int8_t *rhs_ptr = rhs;
  409. uint32_t col_cnt = (uint32_t)rhs_cols;
  410. for (int32_t i = 0; i < col_loop_cnt; i++)
  411. {
  412. mve_pred16_t p = vctp8q(col_cnt);
  413. col_cnt -= 16;
  414. const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
  415. const int8x16_t ker_0 = vldrbq_z_s8(rhs_ptr, p);
  416. acc_0 = vmladavaq_s8(acc_0, ker_0, input);
  417. lhs_vec += 16;
  418. rhs_ptr += 16;
  419. }
  420. rhs += rhs_cols;
  421. if (bias)
  422. {
  423. acc_0 += *bias;
  424. bias++;
  425. }
  426. const int32_t rhs_sum = kernel_sum[i_row_loop_cnt];
  427. const int32_t offsets = rhs_sum * lhs_offset;
  428. acc_0 += offsets;
  429. acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
  430. acc_0 += dst_offset;
  431. // Clamp the result
  432. acc_0 = MAX(acc_0, activation_min);
  433. *dst = MIN(acc_0, activation_max);
  434. dst += address_offset;
  435. }
  436. #elif defined(ARM_MATH_DSP)
  437. (void)kernel_sum;
  438. const int32_t row_loop_cnt = rhs_rows / 2;
  439. const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
  440. const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
  441. for (int32_t i = 0; i < row_loop_cnt; i++)
  442. {
  443. int32_t acc_0 = 0;
  444. int32_t acc_1 = 0;
  445. if (bias)
  446. {
  447. acc_0 = *bias++;
  448. acc_1 = *bias++;
  449. }
  450. const int32_t col_loop_cnt = rhs_cols / 4;
  451. const int8_t *lhs_vec = lhs;
  452. const int8_t *rhs_0_ptr = rhs;
  453. const int8_t *rhs_1_ptr = rhs + rhs_cols;
  454. rhs += 2 * rhs_cols;
  455. for (int32_t j = col_loop_cnt; j != 0; j--)
  456. {
  457. int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
  458. int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
  459. vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
  460. int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_0_ptr);
  461. int32_t ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
  462. ker_0 = SXTB16(ker_0);
  463. acc_0 = SMLAD(ker_1, vec_1, acc_0);
  464. acc_0 = SMLAD(ker_0, vec_0, acc_0);
  465. ker_0 = arm_nn_read_s8x4_ia(&rhs_1_ptr);
  466. ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
  467. ker_0 = SXTB16(ker_0);
  468. acc_1 = SMLAD(ker_1, vec_1, acc_1);
  469. acc_1 = SMLAD(ker_0, vec_0, acc_1);
  470. }
  471. for (int32_t k = col_loop_cnt * 4; k < rhs_cols; k++)
  472. {
  473. const int32_t lhs_temp = (*lhs_vec + lhs_offset);
  474. lhs_vec++;
  475. acc_0 += lhs_temp * (*rhs_0_ptr);
  476. rhs_0_ptr++;
  477. acc_1 += lhs_temp * (*rhs_1_ptr);
  478. rhs_1_ptr++;
  479. }
  480. acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
  481. acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift);
  482. // Add offset
  483. acc_0 += dst_offset;
  484. acc_1 += dst_offset;
  485. // Clamp the result
  486. acc_0 = MAX(acc_0, activation_min);
  487. acc_0 = MIN(acc_0, activation_max);
  488. acc_1 = MAX(acc_1, activation_min);
  489. acc_1 = MIN(acc_1, activation_max);
  490. *dst = (int8_t)acc_0;
  491. *(dst + address_offset) = (int8_t)acc_1;
  492. dst += 2 * address_offset;
  493. }
  494. if (rhs_rows & 0x1)
  495. {
  496. int32_t acc_0 = 0;
  497. if (bias)
  498. {
  499. acc_0 = *bias++;
  500. }
  501. const int32_t col_loop_cnt = rhs_cols / 4;
  502. const int8_t *lhs_vec = lhs;
  503. const int8_t *rhs_ptr = rhs;
  504. for (int32_t i = col_loop_cnt; i != 0; i--)
  505. {
  506. int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
  507. int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
  508. vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
  509. int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_ptr);
  510. int32_t ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
  511. ker_0 = SXTB16(ker_0);
  512. acc_0 = SMLAD(ker_1, vec_1, acc_0);
  513. acc_0 = SMLAD(ker_0, vec_0, acc_0);
  514. }
  515. for (int32_t j = col_loop_cnt * 4; j < rhs_cols; j++)
  516. {
  517. const int32_t lhs_temp = (*lhs_vec + lhs_offset);
  518. lhs_vec++;
  519. acc_0 += lhs_temp * (*rhs_ptr);
  520. rhs_ptr++;
  521. }
  522. acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
  523. // Add offset
  524. acc_0 += dst_offset;
  525. // Clamp the result
  526. acc_0 = MAX(acc_0, activation_min);
  527. acc_0 = MIN(acc_0, activation_max);
  528. *dst = (int8_t)acc_0;
  529. dst += address_offset;
  530. }
  531. #else
  532. (void)kernel_sum;
  533. const int32_t row_loop_cnt = rhs_rows / 3;
  534. for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
  535. {
  536. const int8_t *lhs_ptr = lhs;
  537. const int8_t *rhs_ptr_0 = &rhs[0];
  538. const int8_t *rhs_ptr_1 = &rhs[rhs_cols];
  539. const int8_t *rhs_ptr_2 = &rhs[rhs_cols * 2];
  540. int32_t res00 = 0;
  541. int32_t res01 = 0;
  542. int32_t res02 = 0;
  543. if (bias)
  544. {
  545. res00 = *bias++;
  546. res01 = *bias++;
  547. res02 = *bias++;
  548. }
  549. for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
  550. {
  551. const int32_t rhs_value0 = (int8_t)*rhs_ptr_0;
  552. const int32_t rhs_value1 = (int8_t)*rhs_ptr_1;
  553. const int32_t rhs_value2 = (int8_t)*rhs_ptr_2;
  554. const int32_t lhs_value = (int8_t)*lhs_ptr + lhs_offset;
  555. res00 += lhs_value * rhs_value0;
  556. res01 += lhs_value * rhs_value1;
  557. res02 += lhs_value * rhs_value2;
  558. ++rhs_ptr_0;
  559. ++rhs_ptr_1;
  560. ++rhs_ptr_2;
  561. ++lhs_ptr;
  562. }
  563. // Quantize down
  564. res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
  565. res01 = arm_nn_requantize(res01, dst_multiplier, dst_shift);
  566. res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift);
  567. // Add offset
  568. res00 += dst_offset;
  569. res01 += dst_offset;
  570. res02 += dst_offset;
  571. // Clamp the result
  572. res00 = MAX(res00, activation_min);
  573. res00 = MIN(res00, activation_max);
  574. res01 = MAX(res01, activation_min);
  575. res01 = MIN(res01, activation_max);
  576. res02 = MAX(res02, activation_min);
  577. res02 = MIN(res02, activation_max);
  578. *dst = (int8_t)res00;
  579. *(dst + address_offset) = (int8_t)res01;
  580. *(dst + 2 * address_offset) = (int8_t)res02;
  581. dst += 3 * address_offset;
  582. rhs += 3 * rhs_cols;
  583. }
  584. const int loop_cnt = rhs_rows % 3;
  585. for (int32_t i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
  586. {
  587. const int8_t *lhs_ptr = &lhs[0];
  588. const int8_t *rhs_ptr = &rhs[0];
  589. int32_t res00 = 0;
  590. if (bias)
  591. {
  592. res00 = *bias++;
  593. }
  594. for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
  595. {
  596. int32_t rhs_value0 = (int8_t)rhs_ptr[0];
  597. int32_t lhs_value = (int8_t)lhs_ptr[0] + lhs_offset;
  598. res00 += lhs_value * rhs_value0;
  599. ++rhs_ptr;
  600. ++lhs_ptr;
  601. }
  602. // Quantize down
  603. res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
  604. // Add offset
  605. res00 += dst_offset;
  606. // Clamp the result
  607. res00 = MAX(res00, activation_min);
  608. res00 = MIN(res00, activation_max);
  609. *dst = (int8_t)res00;
  610. dst += address_offset;
  611. rhs += rhs_cols;
  612. }
  613. #endif
  614. }
  615. return ARM_CMSIS_NN_SUCCESS;
  616. }
/**
 * @} end of Doxygen group
 */