arm_nn_mat_mult_nt_t_s8.c 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801
/*
 * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_nt_t_s8
 * Description:  Matrix multiplication support function with the right-hand-side (rhs) matrix transposed
 *
 * $Date:        04 January 2024
 * $Revision:    V.3.0.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */
  29. #include "arm_nnsupportfunctions.h"
/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportConvolution
 * @{
 */
  37. /*
  38. * s8 matrix multiplication with the right-hand-side matrix transposed
  39. *
  40. * Refer header file for details.
  41. *
  42. */
  43. arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8(const int8_t *lhs,
  44. const int8_t *rhs,
  45. const int32_t *bias,
  46. int8_t *dst,
  47. const int32_t *dst_multipliers,
  48. const int32_t *dst_shifts,
  49. const int32_t lhs_rows,
  50. const int32_t rhs_rows,
  51. const int32_t rhs_cols,
  52. const int32_t lhs_offset,
  53. const int32_t dst_offset,
  54. const int32_t activation_min,
  55. const int32_t activation_max,
  56. const int32_t row_address_offset,
  57. const int32_t lhs_cols_offset)
  58. {
  59. #if defined(ARM_MATH_MVEI)
  60. int i_items = 0;
  61. for (; i_items <= (lhs_rows - 4); i_items += 4)
  62. {
  63. for (int i = 0; i < rhs_rows; i++)
  64. {
  65. int32_t acc_n0 = 0;
  66. int32_t acc_n1 = 0;
  67. int32_t acc_n2 = 0;
  68. int32_t acc_n3 = 0;
  69. const int8_t *lhs_vec = lhs;
  70. const int8_t *ip_row_1 = lhs + lhs_cols_offset;
  71. const int8_t *ip_row_2 = lhs + (2 * lhs_cols_offset);
  72. const int8_t *ip_row_3 = lhs + (3 * lhs_cols_offset);
  73. const int8_t *col_base = rhs + i * rhs_cols;
  74. int32_t sum_tmp = 0;
  75. #if defined(ARM_MATH_AUTOVECTORIZE)
  76. for (int j = 0; j < rhs_cols; j++)
  77. {
  78. int32_t col = col_base[j];
  79. sum_tmp += col;
  80. acc_n0 += lhs_vec[j] * col;
  81. acc_n1 += ip_row_1[j] * col;
  82. acc_n2 += ip_row_2[j] * col;
  83. acc_n3 += ip_row_3[j] * col;
  84. }
  85. #else
  86. // Note: If operand initialization is moved around, use '&' constraint to
  87. // specify earlyclobber operands.
  88. __ASM volatile(" .p2align 2 \n"
  89. " wlstp.8 lr, %[cnt], 1f \n"
  90. " mov %[sum], 0 \n"
  91. " mov %[out0], 0 \n"
  92. " mov %[out1], 0 \n"
  93. " mov %[out2], 0 \n"
  94. " mov %[out3], 0 \n"
  95. " vldrb.8 q0, [%[col]], #16 \n"
  96. "2: \n"
  97. " vaddva.s8 %[sum], q0 \n"
  98. " vldrb.8 q1, [%[row0]], #16 \n"
  99. " vmladava.s8 %[out0], q0, q1 \n"
  100. " vldrb.8 q2, [%[row1]], #16 \n"
  101. " vmladava.s8 %[out1], q0, q2 \n"
  102. " vldrb.8 q3, [%[row2]], #16 \n"
  103. " vmladava.s8 %[out2], q0, q3 \n"
  104. " vldrb.8 q4, [%[row3]], #16 \n"
  105. " vmladava.s8 %[out3], q0, q4 \n"
  106. " vldrb.8 q0, [%[col]], #16 \n"
  107. " letp lr, 2b \n"
  108. "1: \n"
  109. : [col] "+r"(col_base),
  110. [sum] "=Te"(sum_tmp),
  111. [row0] "+r"(lhs_vec),
  112. [row1] "+r"(ip_row_1),
  113. [row2] "+r"(ip_row_2),
  114. [row3] "+r"(ip_row_3),
  115. [out0] "=Te"(acc_n0),
  116. [out1] "=Te"(acc_n1),
  117. [out2] "=Te"(acc_n2),
  118. [out3] "=Te"(acc_n3)
  119. : [cnt] "r"(rhs_cols)
  120. : "q0", "q1", "q2", "q3", "q4", "memory", "r14");
  121. #endif
  122. int32x4_t res = {acc_n0, acc_n1, acc_n2, acc_n3};
  123. sum_tmp *= lhs_offset;
  124. if (bias)
  125. {
  126. sum_tmp += bias[i];
  127. }
  128. res = vaddq_n_s32(res, sum_tmp);
  129. res = arm_requantize_mve(res, dst_multipliers[i], dst_shifts[i]);
  130. res = vaddq_n_s32(res, dst_offset);
  131. res = vmaxq_s32(res, vdupq_n_s32(activation_min));
  132. res = vminq_s32(res, vdupq_n_s32(activation_max));
  133. const uint32x4_t scatter_offset = {
  134. 0, (uint32_t)row_address_offset, (uint32_t)row_address_offset * 2, (uint32_t)row_address_offset * 3};
  135. vstrbq_scatter_offset_s32(dst, scatter_offset, res);
  136. dst++;
  137. }
  138. lhs += 4 * lhs_cols_offset;
  139. dst += 4 * row_address_offset - rhs_rows;
  140. }
  141. for (; i_items < lhs_rows; i_items++)
  142. {
  143. int32_t acc[4];
  144. const int32_t *multipliers = dst_multipliers;
  145. const int32_t *shifts = dst_shifts;
  146. for (int i = 0; i < rhs_rows; i++)
  147. {
  148. int32_t acc_n0 = 0;
  149. const int8_t *lhs_vec = lhs;
  150. const int8_t *col_base = rhs + i * rhs_cols;
  151. int32_t sum_tmp = 0;
  152. #if defined(ARM_MATH_AUTOVECTORIZE)
  153. for (int j = 0; j < rhs_cols; j++)
  154. {
  155. int32_t col = col_base[j];
  156. sum_tmp += col;
  157. acc_n0 += lhs_vec[j] * col;
  158. }
  159. #else
  160. __ASM volatile(" .p2align 2 \n"
  161. " wlstp.8 lr, %[cnt], 1f \n"
  162. " mov %[sum], 0 \n"
  163. " mov %[out0], 0 \n"
  164. " vldrb.8 q0, [%[col]], #16 \n"
  165. "2: \n"
  166. " vaddva.s8 %[sum], q0 \n"
  167. " vldrb.8 q1, [%[row0]], #16 \n"
  168. " vmladava.s8 %[out0], q0, q1 \n"
  169. " vldrb.8 q0, [%[col]], #16 \n"
  170. " letp lr, 2b \n"
  171. "1: \n"
  172. : [col] "+r"(col_base), [sum] "=Te"(sum_tmp), [row0] "+r"(lhs_vec), [out0] "=Te"(acc_n0)
  173. : [cnt] "r"(rhs_cols)
  174. : "q0", "q1", "memory", "r14");
  175. #endif
  176. sum_tmp *= lhs_offset;
  177. sum_tmp += acc_n0;
  178. if (bias)
  179. {
  180. sum_tmp += bias[i];
  181. }
  182. const int32_t index = i & 0x3;
  183. acc[index] = sum_tmp;
  184. if (index == 3)
  185. {
  186. int32x4_t res = vldrwq_s32(acc);
  187. res = arm_requantize_mve_32x4(res, vldrwq_s32(multipliers), vldrwq_s32(shifts));
  188. multipliers += 4;
  189. shifts += 4;
  190. res = vaddq_n_s32(res, dst_offset);
  191. res = vmaxq_s32(res, vdupq_n_s32(activation_min));
  192. res = vminq_s32(res, vdupq_n_s32(activation_max));
  193. vstrbq_s32(dst, res);
  194. dst += 4;
  195. }
  196. }
  197. lhs += lhs_cols_offset;
  198. const int32_t tail_rows = rhs_rows & 0x3;
  199. for (int i = 0; i < tail_rows; i++)
  200. {
  201. int32_t acc_n0 = acc[i];
  202. acc_n0 = arm_nn_requantize(acc_n0, multipliers[i], shifts[i]);
  203. acc_n0 += dst_offset;
  204. acc_n0 = MAX(acc_n0, activation_min);
  205. acc_n0 = MIN(acc_n0, activation_max);
  206. *dst++ = (int8_t)acc_n0;
  207. }
  208. dst += row_address_offset - rhs_rows;
  209. }
  210. #elif defined(ARM_MATH_DSP)
  211. (void)row_address_offset;
  212. const int32_t rhs_off0 = rhs_cols - 4;
  213. const int32_t lhs_off0 = lhs_cols_offset - 4;
  214. for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 2); rhs_rows_idx += 2)
  215. {
  216. const int8_t *lhs_ptr = &lhs[0];
  217. int8_t *dst_ptr = &dst[0];
  218. int32_t lhs_offset_contribution0 = 0;
  219. int32_t lhs_offset_contribution1 = 0;
  220. for (int32_t x = 0; x < rhs_cols; ++x)
  221. {
  222. lhs_offset_contribution0 += rhs[x];
  223. lhs_offset_contribution1 += rhs[x + rhs_cols];
  224. }
  225. lhs_offset_contribution0 *= lhs_offset;
  226. lhs_offset_contribution1 *= lhs_offset;
  227. if (bias)
  228. {
  229. lhs_offset_contribution0 += bias[rhs_rows_idx];
  230. lhs_offset_contribution1 += bias[rhs_rows_idx + 1];
  231. }
  232. int32_t lhs_rows_idx = lhs_rows >> 1;
  233. while (lhs_rows_idx)
  234. {
  235. const int8_t *rhs_ptr = &rhs[0];
  236. int32_t res00 = lhs_offset_contribution0;
  237. int32_t res01 = lhs_offset_contribution1;
  238. int32_t res10 = lhs_offset_contribution0;
  239. int32_t res11 = lhs_offset_contribution1;
  240. int32_t rhs_cols_idx = 0;
  241. int32_t val0, val1, val2, val3, val4, val5;
  242. for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
  243. {
  244. val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
  245. val2 = SXTB16(val1);
  246. val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
  247. val3 = SXTB16(val0);
  248. val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
  249. val1 = SXTB16_RORn(val1, 8);
  250. val0 = SXTB16_RORn(val0, 8);
  251. // 4 x MAC res00, res01
  252. res00 = SMLAD(val3, val2, res00);
  253. val5 = SXTB16(val4);
  254. res00 = SMLAD(val0, val1, res00);
  255. val4 = SXTB16_RORn(val4, 8);
  256. res01 = SMLAD(val3, val5, res01);
  257. res01 = SMLAD(val0, val4, res01);
  258. // 4 x MAC res10, res11
  259. val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
  260. val3 = SXTB16(val0);
  261. val0 = SXTB16_RORn(val0, 8);
  262. res10 = SMLAD(val3, val2, res10);
  263. res11 = SMLAD(val3, val5, res11);
  264. res10 = SMLAD(val0, val1, res10);
  265. val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
  266. res11 = SMLAD(val0, val4, res11);
  267. val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
  268. val2 = SXTB16(val1);
  269. val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
  270. val3 = SXTB16(val0);
  271. val1 = SXTB16_RORn(val1, 8);
  272. val0 = SXTB16_RORn(val0, 8);
  273. // 4 x MAC res00, res01
  274. res00 = SMLAD(val3, val2, res00);
  275. val5 = SXTB16(val4);
  276. res00 = SMLAD(val0, val1, res00);
  277. val4 = SXTB16_RORn(val4, 8);
  278. res01 = SMLAD(val3, val5, res01);
  279. res01 = SMLAD(val0, val4, res01);
  280. // 4 x MAC res10, res11
  281. val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
  282. val3 = SXTB16(val0);
  283. val0 = SXTB16_RORn(val0, 8);
  284. res10 = SMLAD(val3, val2, res10);
  285. res11 = SMLAD(val3, val5, res11);
  286. res10 = SMLAD(val0, val1, res10);
  287. val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
  288. res11 = SMLAD(val0, val4, res11);
  289. val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
  290. val2 = SXTB16(val1);
  291. val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
  292. val3 = SXTB16(val0);
  293. val1 = SXTB16_RORn(val1, 8);
  294. val0 = SXTB16_RORn(val0, 8);
  295. // 4 x MAC res00, res01
  296. res00 = SMLAD(val3, val2, res00);
  297. val5 = SXTB16(val4);
  298. res00 = SMLAD(val0, val1, res00);
  299. val4 = SXTB16_RORn(val4, 8);
  300. res01 = SMLAD(val3, val5, res01);
  301. res01 = SMLAD(val0, val4, res01);
  302. // 4 x MAC res10, res11
  303. val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
  304. val3 = SXTB16(val0);
  305. val0 = SXTB16_RORn(val0, 8);
  306. res10 = SMLAD(val3, val2, res10);
  307. res11 = SMLAD(val3, val5, res11);
  308. res10 = SMLAD(val0, val1, res10);
  309. val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
  310. res11 = SMLAD(val0, val4, res11);
  311. val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
  312. val2 = SXTB16(val1);
  313. val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
  314. val3 = SXTB16(val0);
  315. val1 = SXTB16_RORn(val1, 8);
  316. val0 = SXTB16_RORn(val0, 8);
  317. // 4 x MAC res00, res01
  318. res00 = SMLAD(val3, val2, res00);
  319. val5 = SXTB16(val4);
  320. res00 = SMLAD(val0, val1, res00);
  321. val4 = SXTB16_RORn(val4, 8);
  322. res01 = SMLAD(val3, val5, res01);
  323. res01 = SMLAD(val0, val4, res01);
  324. // 4 x MAC res10, res11
  325. val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
  326. val3 = SXTB16(val0);
  327. val0 = SXTB16_RORn(val0, 8);
  328. res10 = SMLAD(val3, val2, res10);
  329. res11 = SMLAD(val3, val5, res11);
  330. res10 = SMLAD(val0, val1, res10);
  331. res11 = SMLAD(val0, val4, res11);
  332. }
  333. for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
  334. {
  335. val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
  336. val2 = SXTB16(val1);
  337. val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
  338. val3 = SXTB16(val0);
  339. val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
  340. val1 = SXTB16_RORn(val1, 8);
  341. val0 = SXTB16_RORn(val0, 8);
  342. // 4 x MAC res00, res01
  343. res00 = SMLAD(val3, val2, res00);
  344. val5 = SXTB16(val4);
  345. res00 = SMLAD(val0, val1, res00);
  346. val4 = SXTB16_RORn(val4, 8);
  347. res01 = SMLAD(val3, val5, res01);
  348. res01 = SMLAD(val0, val4, res01);
  349. // 4 x MAC res10, res11
  350. val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
  351. val3 = SXTB16(val0);
  352. val0 = SXTB16_RORn(val0, 8);
  353. res10 = SMLAD(val3, val2, res10);
  354. res11 = SMLAD(val3, val5, res11);
  355. res10 = SMLAD(val0, val1, res10);
  356. res11 = SMLAD(val0, val4, res11);
  357. }
  358. for (; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
  359. {
  360. int8_t rhs_value0 = rhs_ptr[0];
  361. int8_t rhs_value1 = rhs_ptr[rhs_cols];
  362. int8_t lhs_value = lhs_ptr[0];
  363. res00 += lhs_value * rhs_value0;
  364. res01 += lhs_value * rhs_value1;
  365. lhs_value = lhs_ptr[lhs_cols_offset];
  366. res10 += lhs_value * rhs_value0;
  367. res11 += lhs_value * rhs_value1;
  368. ++rhs_ptr;
  369. ++lhs_ptr;
  370. }
  371. // Quantize down
  372. res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
  373. res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
  374. res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
  375. res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
  376. // Add offset
  377. res00 += dst_offset;
  378. res01 += dst_offset;
  379. res10 += dst_offset;
  380. res11 += dst_offset;
  381. // Clamp the result
  382. res00 = MAX(res00, activation_min);
  383. res00 = MIN(res00, activation_max);
  384. res01 = MAX(res01, activation_min);
  385. res01 = MIN(res01, activation_max);
  386. res10 = MAX(res10, activation_min);
  387. res10 = MIN(res10, activation_max);
  388. res11 = MAX(res11, activation_min);
  389. res11 = MIN(res11, activation_max);
  390. dst_ptr[0] = (int8_t)res00;
  391. dst_ptr[1] = (int8_t)res01;
  392. dst_ptr += rhs_rows;
  393. dst_ptr[0] = (int8_t)res10;
  394. dst_ptr[1] = (int8_t)res11;
  395. dst_ptr += rhs_rows;
  396. lhs_ptr -= rhs_cols;
  397. lhs_ptr += 2 * lhs_cols_offset;
  398. lhs_rows_idx--;
  399. }
  400. // Left-over rows
  401. if (lhs_rows % 2)
  402. {
  403. const int8_t *rhs_ptr = &rhs[0];
  404. int32_t res00 = lhs_offset_contribution0;
  405. int32_t res01 = lhs_offset_contribution1;
  406. int32_t rhs_cols_idx = 0;
  407. int32_t val0, val1, val2, val3, val4, val5;
  408. for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
  409. {
  410. val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
  411. val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
  412. val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
  413. val3 = SXTB16(val0);
  414. val5 = SXTB16(val2);
  415. val4 = SXTB16(val1);
  416. val0 = SXTB16_RORn(val0, 8);
  417. val2 = SXTB16_RORn(val2, 8);
  418. val1 = SXTB16_RORn(val1, 8);
  419. // 4 x MAC res00, res01
  420. res00 = SMLAD(val5, val3, res00);
  421. res00 = SMLAD(val2, val0, res00);
  422. res01 = SMLAD(val5, val4, res01);
  423. res01 = SMLAD(val2, val1, res01);
  424. val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
  425. val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
  426. val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
  427. val3 = SXTB16(val0);
  428. val5 = SXTB16(val2);
  429. val4 = SXTB16(val1);
  430. val0 = SXTB16_RORn(val0, 8);
  431. val2 = SXTB16_RORn(val2, 8);
  432. val1 = SXTB16_RORn(val1, 8);
  433. // 4 x MAC res00, res01
  434. res00 = SMLAD(val5, val3, res00);
  435. res00 = SMLAD(val2, val0, res00);
  436. res01 = SMLAD(val5, val4, res01);
  437. res01 = SMLAD(val2, val1, res01);
  438. val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
  439. val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
  440. val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
  441. val3 = SXTB16(val0);
  442. val5 = SXTB16(val2);
  443. val4 = SXTB16(val1);
  444. val0 = SXTB16_RORn(val0, 8);
  445. val2 = SXTB16_RORn(val2, 8);
  446. val1 = SXTB16_RORn(val1, 8);
  447. // 4 x MAC res00, res01
  448. res00 = SMLAD(val5, val3, res00);
  449. res00 = SMLAD(val2, val0, res00);
  450. res01 = SMLAD(val5, val4, res01);
  451. res01 = SMLAD(val2, val1, res01);
  452. val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
  453. val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
  454. val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
  455. val3 = SXTB16(val0);
  456. val5 = SXTB16(val2);
  457. val4 = SXTB16(val1);
  458. val0 = SXTB16_RORn(val0, 8);
  459. val2 = SXTB16_RORn(val2, 8);
  460. val1 = SXTB16_RORn(val1, 8);
  461. // 4 x MAC res00, res01
  462. res00 = SMLAD(val5, val3, res00);
  463. res00 = SMLAD(val2, val0, res00);
  464. res01 = SMLAD(val5, val4, res01);
  465. res01 = SMLAD(val2, val1, res01);
  466. }
  467. for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
  468. {
  469. val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
  470. val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
  471. val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
  472. val3 = SXTB16(val0);
  473. val5 = SXTB16(val2);
  474. val4 = SXTB16(val1);
  475. val0 = SXTB16_RORn(val0, 8);
  476. val2 = SXTB16_RORn(val2, 8);
  477. val1 = SXTB16_RORn(val1, 8);
  478. // 4 x MAC res00, res01
  479. res00 = SMLAD(val5, val3, res00);
  480. res00 = SMLAD(val2, val0, res00);
  481. res01 = SMLAD(val5, val4, res01);
  482. res01 = SMLAD(val2, val1, res01);
  483. }
  484. // Left-over accumulations
  485. for (; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
  486. {
  487. int8_t rhs_value0 = rhs_ptr[0];
  488. int8_t rhs_value1 = rhs_ptr[rhs_cols];
  489. int8_t lhs_value = lhs_ptr[0];
  490. res00 += lhs_value * rhs_value0;
  491. res01 += lhs_value * rhs_value1;
  492. ++rhs_ptr;
  493. ++lhs_ptr;
  494. }
  495. // Quantize down
  496. res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
  497. res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
  498. // Add offset
  499. res00 += dst_offset;
  500. res01 += dst_offset;
  501. // Clamp the result
  502. res00 = MAX(res00, activation_min);
  503. res00 = MIN(res00, activation_max);
  504. res01 = MAX(res01, activation_min);
  505. res01 = MIN(res01, activation_max);
  506. dst_ptr[0] = (int8_t)res00;
  507. dst_ptr[1] = (int8_t)res01;
  508. }
  509. rhs += 2 * rhs_cols;
  510. dst += 2;
  511. }
  512. if (rhs_rows % 2)
  513. {
  514. const int8_t *lhs_ptr = &lhs[0];
  515. int8_t *dst_ptr = &dst[0];
  516. for (int32_t lhs_rows_idx = 0; lhs_rows_idx < lhs_rows; ++lhs_rows_idx)
  517. {
  518. const int8_t *rhs_ptr = &rhs[0];
  519. int32_t res00 = 0;
  520. if (bias)
  521. {
  522. res00 = bias[rhs_rows - 1];
  523. }
  524. for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
  525. {
  526. int32_t rhs_value = rhs_ptr[0];
  527. int32_t lhs_value = lhs_ptr[0] + lhs_offset;
  528. res00 += lhs_value * rhs_value;
  529. ++rhs_ptr;
  530. ++lhs_ptr;
  531. }
  532. lhs_ptr -= rhs_cols;
  533. lhs_ptr += lhs_cols_offset;
  534. // Quantize down
  535. res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows - 1], dst_shifts[rhs_rows - 1]);
  536. // Add offset
  537. res00 += dst_offset;
  538. // Clamp the result
  539. res00 = MAX(res00, activation_min);
  540. res00 = MIN(res00, activation_max);
  541. dst_ptr[0] = (int8_t)res00;
  542. dst_ptr += rhs_rows;
  543. }
  544. }
  545. #else
  546. (void)row_address_offset;
  547. for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 2); rhs_rows_idx += 2)
  548. {
  549. const int8_t *lhs_ptr = &lhs[0];
  550. int8_t *dst_ptr = &dst[0];
  551. int32_t lhs_offset_contribution0 = 0;
  552. int32_t lhs_offset_contribution1 = 0;
  553. for (int32_t x = 0; x < rhs_cols; ++x)
  554. {
  555. lhs_offset_contribution0 += rhs[x];
  556. lhs_offset_contribution1 += rhs[x + rhs_cols];
  557. }
  558. lhs_offset_contribution0 *= lhs_offset;
  559. lhs_offset_contribution1 *= lhs_offset;
  560. if (bias)
  561. {
  562. lhs_offset_contribution0 += bias[rhs_rows_idx];
  563. lhs_offset_contribution1 += bias[rhs_rows_idx + 1];
  564. }
  565. int32_t lhs_rows_idx = lhs_rows >> 1;
  566. while (lhs_rows_idx)
  567. {
  568. const int8_t *rhs_ptr = &rhs[0];
  569. int32_t res00 = lhs_offset_contribution0;
  570. int32_t res01 = lhs_offset_contribution1;
  571. int32_t res10 = lhs_offset_contribution0;
  572. int32_t res11 = lhs_offset_contribution1;
  573. for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
  574. {
  575. int8_t rhs_value0 = rhs_ptr[0];
  576. int8_t rhs_value1 = rhs_ptr[rhs_cols];
  577. int8_t lhs_value = lhs_ptr[0];
  578. res00 += lhs_value * rhs_value0;
  579. res01 += lhs_value * rhs_value1;
  580. lhs_value = lhs_ptr[lhs_cols_offset];
  581. res10 += lhs_value * rhs_value0;
  582. res11 += lhs_value * rhs_value1;
  583. ++rhs_ptr;
  584. ++lhs_ptr;
  585. }
  586. // Quantize down
  587. res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
  588. res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
  589. res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
  590. res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
  591. // Add offset
  592. res00 += dst_offset;
  593. res01 += dst_offset;
  594. res10 += dst_offset;
  595. res11 += dst_offset;
  596. // Clamp the result
  597. res00 = MAX(res00, activation_min);
  598. res00 = MIN(res00, activation_max);
  599. res01 = MAX(res01, activation_min);
  600. res01 = MIN(res01, activation_max);
  601. res10 = MAX(res10, activation_min);
  602. res10 = MIN(res10, activation_max);
  603. res11 = MAX(res11, activation_min);
  604. res11 = MIN(res11, activation_max);
  605. dst_ptr[0] = (int8_t)res00;
  606. dst_ptr[1] = (int8_t)res01;
  607. dst_ptr += rhs_rows;
  608. dst_ptr[0] = (int8_t)res10;
  609. dst_ptr[1] = (int8_t)res11;
  610. dst_ptr += rhs_rows;
  611. lhs_ptr -= rhs_cols;
  612. lhs_ptr += 2 * lhs_cols_offset;
  613. lhs_rows_idx--;
  614. }
  615. // Left-over rows
  616. if (lhs_rows % 2)
  617. {
  618. const int8_t *rhs_ptr = &rhs[0];
  619. int32_t res00 = lhs_offset_contribution0;
  620. int32_t res01 = lhs_offset_contribution1;
  621. for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
  622. {
  623. int8_t rhs_value0 = rhs_ptr[0];
  624. int8_t rhs_value1 = rhs_ptr[rhs_cols];
  625. int8_t lhs_value = lhs_ptr[0];
  626. res00 += lhs_value * rhs_value0;
  627. res01 += lhs_value * rhs_value1;
  628. ++rhs_ptr;
  629. ++lhs_ptr;
  630. }
  631. // Quantize down
  632. res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
  633. res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
  634. // Add offset
  635. res00 += dst_offset;
  636. res01 += dst_offset;
  637. // Clamp the result
  638. res00 = MAX(res00, activation_min);
  639. res00 = MIN(res00, activation_max);
  640. res01 = MAX(res01, activation_min);
  641. res01 = MIN(res01, activation_max);
  642. dst_ptr[0] = (int8_t)res00;
  643. dst_ptr[1] = (int8_t)res01;
  644. }
  645. rhs += 2 * rhs_cols;
  646. dst += 2;
  647. }
  648. if (rhs_rows % 2)
  649. {
  650. const int8_t *lhs_ptr = &lhs[0];
  651. int8_t *dst_ptr = &dst[0];
  652. for (int32_t lhs_rows_idx = 0; lhs_rows_idx < lhs_rows; ++lhs_rows_idx)
  653. {
  654. const int8_t *rhs_ptr = &rhs[0];
  655. int32_t res00 = 0;
  656. if (bias)
  657. {
  658. res00 = bias[rhs_rows - 1];
  659. }
  660. for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
  661. {
  662. int32_t rhs_value = rhs_ptr[0];
  663. int32_t lhs_value = lhs_ptr[0] + lhs_offset;
  664. res00 += lhs_value * rhs_value;
  665. ++rhs_ptr;
  666. ++lhs_ptr;
  667. }
  668. lhs_ptr -= rhs_cols;
  669. lhs_ptr += lhs_cols_offset;
  670. // Quantize down
  671. res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows - 1], dst_shifts[rhs_rows - 1]);
  672. // Add offset
  673. res00 += dst_offset;
  674. // Clamp the result
  675. res00 = MAX(res00, activation_min);
  676. res00 = MIN(res00, activation_max);
  677. dst_ptr[0] = (int8_t)res00;
  678. dst_ptr += rhs_rows;
  679. }
  680. }
  681. #endif
  682. return ARM_CMSIS_NN_SUCCESS;
  683. }
/**
 * @} end of Doxygen group
 */