arm_nn_vec_mat_mult_t_s4.c 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
  1. /*
  2. * SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
  3. *
  4. * SPDX-License-Identifier: Apache-2.0
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the License); you may
  7. * not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  14. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. /* ----------------------------------------------------------------------
  19. * Project: CMSIS NN Library
  20. * Title: arm_nn_vec_mat_mult_t_s4
  21. * Description: s4 vector by matrix (transposed) multiplication
  22. *
  23. * $Date: 10 October 2023
  24. * $Revision: V.1.0.0
  25. *
  26. * Target : Arm(R) M-Profile Architecture
  27. *
  28. * -------------------------------------------------------------------- */
  29. #include "arm_nnsupportfunctions.h"
  30. /**
  31. */
  32. /**
  33. * @defgroup supportFC Fully Connected
  34. *
  35. * Support functions for Fully Connected
  36. *
  37. */
  38. /**
  39. * @addtogroup supportFC
  40. * @{
  41. */
  42. /*
  43. * s4 vector(lhs) by matrix (transposed) multiplication
  44. *
  45. * Refer header file for details.
  46. *
  47. */
  48. arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s4(const int8_t *lhs,
  49. const int8_t *packed_rhs,
  50. const int32_t *bias,
  51. int8_t *dst,
  52. const int32_t lhs_offset,
  53. const int32_t dst_offset,
  54. const int32_t dst_multiplier,
  55. const int32_t dst_shift,
  56. const int32_t rhs_cols,
  57. const int32_t rhs_rows,
  58. const int32_t activation_min,
  59. const int32_t activation_max,
  60. const int32_t address_offset)
  61. {
  62. #if defined(ARM_MATH_DSP)
  63. const int8_t *rhs_ptr = &packed_rhs[0];
  64. const int rhs_offset = rhs_cols * (rhs_rows / 4);
  65. int32_t spillover0, spillover1;
  66. const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
  67. const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
  68. for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < rhs_rows / 4; ++i_row_loop_cnt)
  69. {
  70. const int8_t *lhs_ptr = &lhs[0];
  71. int32_t res0 = 0;
  72. int32_t res1 = 0;
  73. if (bias)
  74. {
  75. res0 += *bias;
  76. res1 += bias[2 * (rhs_rows / 4)];
  77. ++bias;
  78. }
  79. for (int rhs_cols_idx = 0; rhs_cols_idx < (rhs_cols / 4); ++rhs_cols_idx)
  80. {
  81. int32_t lhs_high, rhs_high0, rhs_low0, lhs_low, rhs_high1, rhs_low1;
  82. read_and_pad_s4(rhs_ptr, &rhs_low0, &rhs_high0);
  83. read_and_pad_s4((const int8_t *)&rhs_ptr[rhs_offset], &rhs_low1, &rhs_high1);
  84. rhs_ptr += 2;
  85. lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
  86. lhs_low = SXTAB16(lhs_offset_s16x2, lhs_high);
  87. lhs_high = SXTAB16_RORn(lhs_offset_s16x2, lhs_high, 8);
  88. res0 = SMLAD(lhs_low, rhs_low0, res0);
  89. res0 = SMLAD(lhs_high, rhs_high0, res0);
  90. res1 = SMLAD(lhs_low, rhs_low1, res1);
  91. res1 = SMLAD(lhs_high, rhs_high1, res1);
  92. }
  93. if (((rhs_cols % 4) == 2) || ((rhs_cols % 4) == 3))
  94. {
  95. const int32_t rhs_value0 = rhs_ptr[0];
  96. const int32_t lower0 = (int8_t)(rhs_value0 << 4) >> 4;
  97. const int32_t higher0 = rhs_value0 >> 4;
  98. const int32_t rhs_value1 = rhs_ptr[rhs_offset];
  99. const int32_t lower1 = (int8_t)(rhs_value1 << 4) >> 4;
  100. const int32_t higher1 = rhs_value1 >> 4;
  101. const int32_t lhs_value_0 = lhs_ptr[0] + lhs_offset;
  102. const int32_t lhs_value_1 = lhs_ptr[1] + lhs_offset;
  103. res0 += lhs_value_0 * lower0;
  104. res0 += lhs_value_1 * higher0;
  105. res1 += lhs_value_0 * lower1;
  106. res1 += lhs_value_1 * higher1;
  107. ++rhs_ptr;
  108. lhs_ptr += 2;
  109. }
  110. if (rhs_cols % 2 == 1)
  111. {
  112. const int32_t rhs_low0 = (int8_t)(rhs_ptr[0] << 4) >> 4;
  113. const int32_t rhs_high0 = rhs_ptr[0] >> 4;
  114. const int32_t rhs_low1 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
  115. const int32_t rhs_high1 = rhs_ptr[rhs_offset] >> 4;
  116. const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
  117. lhs_ptr = &lhs[0];
  118. const int32_t lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;
  119. ++lhs_ptr;
  120. res0 += lhs_low * rhs_low0;
  121. spillover0 = lhs_high * rhs_high0;
  122. res1 += lhs_low * rhs_low1;
  123. spillover1 = lhs_high * rhs_high1;
  124. ++rhs_ptr;
  125. }
  126. else
  127. {
  128. spillover0 = 0;
  129. spillover1 = 0;
  130. lhs_ptr = &lhs[0];
  131. }
  132. // Quantize down
  133. res0 = arm_nn_requantize(res0, dst_multiplier, dst_shift);
  134. res1 = arm_nn_requantize(res1, dst_multiplier, dst_shift);
  135. // Add offset
  136. res0 += dst_offset;
  137. res1 += dst_offset;
  138. // Clamp the result
  139. res0 = MAX(res0, activation_min);
  140. res0 = MIN(res0, activation_max);
  141. res1 = MAX(res1, activation_min);
  142. res1 = MIN(res1, activation_max);
  143. *dst = (int8_t)res0;
  144. *(dst + 2 * address_offset * ((rhs_rows) / 4)) = (int8_t)res1;
  145. dst += address_offset;
  146. res0 = spillover0;
  147. res1 = spillover1;
  148. if (bias)
  149. {
  150. res0 += *bias;
  151. res1 += bias[2 * (rhs_rows / 4)];
  152. ++bias;
  153. }
  154. for (int rhs_cols_idx = 0; rhs_cols_idx < rhs_cols / 4; ++rhs_cols_idx)
  155. {
  156. int32_t lhs_high, rhs_high0, rhs_low0, lhs_low, rhs_high1, rhs_low1;
  157. read_and_pad_s4(rhs_ptr, &rhs_low0, &rhs_high0);
  158. read_and_pad_s4((const int8_t *)&rhs_ptr[rhs_offset], &rhs_low1, &rhs_high1);
  159. rhs_ptr += 2;
  160. lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
  161. lhs_low = SXTAB16(lhs_offset_s16x2, lhs_high);
  162. lhs_high = SXTAB16_RORn(lhs_offset_s16x2, lhs_high, 8);
  163. res0 = SMLAD(lhs_low, rhs_low0, res0);
  164. res0 = SMLAD(lhs_high, rhs_high0, res0);
  165. res1 = SMLAD(lhs_low, rhs_low1, res1);
  166. res1 = SMLAD(lhs_high, rhs_high1, res1);
  167. }
  168. if (((rhs_cols % 4) == 2) || ((rhs_cols % 4) == 3))
  169. {
  170. const int32_t rhs_value0 = rhs_ptr[0];
  171. const int32_t lower0 = (int8_t)(rhs_value0 << 4) >> 4;
  172. const int32_t higher0 = rhs_value0 >> 4;
  173. const int32_t rhs_value1 = rhs_ptr[rhs_offset];
  174. const int32_t lower1 = (int8_t)(rhs_value1 << 4) >> 4;
  175. const int32_t higher1 = rhs_value1 >> 4;
  176. const int32_t lhs_value_0 = lhs_ptr[0] + lhs_offset;
  177. const int32_t lhs_value_1 = lhs_ptr[1] + lhs_offset;
  178. res0 += lhs_value_0 * lower0;
  179. res0 += lhs_value_1 * higher0;
  180. res1 += lhs_value_0 * lower1;
  181. res1 += lhs_value_1 * higher1;
  182. ++rhs_ptr;
  183. lhs_ptr += 2;
  184. }
  185. // Quantize down
  186. res0 = arm_nn_requantize(res0, dst_multiplier, dst_shift);
  187. res1 = arm_nn_requantize(res1, dst_multiplier, dst_shift);
  188. // Add offset
  189. res0 += dst_offset;
  190. res1 += dst_offset;
  191. // Clamp the result
  192. res0 = MAX(res0, activation_min);
  193. res0 = MIN(res0, activation_max);
  194. res1 = MAX(res1, activation_min);
  195. res1 = MIN(res1, activation_max);
  196. *dst = (int8_t)res0;
  197. *(dst + 2 * address_offset * ((rhs_rows) / 4)) = (int8_t)res1;
  198. dst += address_offset;
  199. }
  200. const int8_t *lhs_ptr = &lhs[0];
  201. spillover0 = 0;
  202. for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < rhs_rows % 4; ++i_row_loop_cnt)
  203. {
  204. int32_t res0 = spillover0;
  205. if (bias)
  206. {
  207. res0 += bias[2 * (rhs_rows / 4)];
  208. ++bias;
  209. }
  210. for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols / 4; ++rhs_cols_idx)
  211. {
  212. int32_t lhs_high, rhs_high0, rhs_low0, lhs_low;
  213. read_and_pad_s4((const int8_t *)&rhs_ptr[rhs_offset], &rhs_high0, &rhs_low0);
  214. rhs_ptr += 2;
  215. lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
  216. lhs_low = SXTAB16(lhs_offset_s16x2, lhs_high);
  217. lhs_high = SXTAB16_RORn(lhs_offset_s16x2, lhs_high, 8);
  218. res0 = SMLAD(lhs_low, rhs_high0, res0);
  219. res0 = SMLAD(lhs_high, rhs_low0, res0);
  220. }
  221. if ((rhs_cols % 4) == 2 || (rhs_cols % 4 == 3))
  222. {
  223. const int32_t rhs_value0 = rhs_ptr[rhs_offset];
  224. const int32_t lower0 = (int8_t)(rhs_value0 << 4) >> 4;
  225. const int32_t higher0 = rhs_value0 >> 4;
  226. const int32_t lhs_value_0 = lhs_ptr[0] + lhs_offset;
  227. const int32_t lhs_value_1 = lhs_ptr[1] + lhs_offset;
  228. res0 += lhs_value_0 * lower0;
  229. res0 += lhs_value_1 * higher0;
  230. ++rhs_ptr;
  231. lhs_ptr += 2;
  232. }
  233. if ((rhs_cols % 2 == 1) && (i_row_loop_cnt % 2 == 0))
  234. {
  235. const int32_t rhs_low0 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
  236. const int32_t rhs_high0 = rhs_ptr[rhs_offset] >> 4;
  237. const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
  238. lhs_ptr = &lhs[0];
  239. const int32_t lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;
  240. ++lhs_ptr;
  241. res0 += lhs_low * rhs_low0;
  242. spillover0 = lhs_high * rhs_high0;
  243. ++rhs_ptr;
  244. }
  245. else
  246. {
  247. spillover0 = 0;
  248. lhs_ptr = &lhs[0];
  249. }
  250. // Quantize down
  251. res0 = arm_nn_requantize(res0, dst_multiplier, dst_shift);
  252. // Add offset
  253. res0 += dst_offset;
  254. // Clamp the result
  255. res0 = MAX(res0, activation_min);
  256. res0 = MIN(res0, activation_max);
  257. *(dst + 2 * address_offset * ((rhs_rows) / 4)) = (int8_t)res0;
  258. dst += address_offset;
  259. }
  260. #else
  261. const int8_t *rhs_ptr = &packed_rhs[0];
  262. int32_t spillover0, spillover1;
  263. const int rhs_offset = rhs_cols * ((rhs_rows) / 4);
  264. for (int i_row_loop_cnt = 0; i_row_loop_cnt < rhs_rows / 4; ++i_row_loop_cnt)
  265. {
  266. const int8_t *lhs_ptr = &lhs[0];
  267. int32_t res0 = 0;
  268. int32_t res1 = 0;
  269. if (bias)
  270. {
  271. res0 += *bias;
  272. res1 += bias[2 * (rhs_rows / 4)];
  273. ++bias;
  274. }
  275. for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols / 2; ++rhs_cols_idx)
  276. {
  277. const int32_t rhs_low0 = (int8_t)(rhs_ptr[0] << 4) >> 4;
  278. const int32_t rhs_high0 = rhs_ptr[0] >> 4;
  279. const int32_t rhs_low1 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
  280. const int32_t rhs_high1 = rhs_ptr[rhs_offset] >> 4;
  281. const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
  282. const int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
  283. res0 += lhs_low * rhs_low0;
  284. res0 += lhs_high * rhs_high0;
  285. res1 += lhs_low * rhs_low1;
  286. res1 += lhs_high * rhs_high1;
  287. ++rhs_ptr;
  288. lhs_ptr += 2;
  289. }
  290. if (rhs_cols % 2 == 1)
  291. {
  292. const int32_t rhs_low0 = (int8_t)(rhs_ptr[0] << 4) >> 4;
  293. const int32_t rhs_high0 = rhs_ptr[0] >> 4;
  294. const int32_t rhs_low1 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
  295. const int32_t rhs_high1 = rhs_ptr[rhs_offset] >> 4;
  296. const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
  297. lhs_ptr = &lhs[0];
  298. const int32_t lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;
  299. ++lhs_ptr;
  300. res0 += lhs_low * rhs_low0;
  301. spillover0 = lhs_high * rhs_high0;
  302. res1 += lhs_low * rhs_low1;
  303. spillover1 = lhs_high * rhs_high1;
  304. ++rhs_ptr;
  305. }
  306. else
  307. {
  308. spillover0 = 0;
  309. spillover1 = 0;
  310. lhs_ptr = &lhs[0];
  311. }
  312. // Quantize down
  313. res0 = arm_nn_requantize(res0, dst_multiplier, dst_shift);
  314. res1 = arm_nn_requantize(res1, dst_multiplier, dst_shift);
  315. // Add offset
  316. res0 += dst_offset;
  317. res1 += dst_offset;
  318. // Clamp the result
  319. res0 = MAX(res0, activation_min);
  320. res0 = MIN(res0, activation_max);
  321. res1 = MAX(res1, activation_min);
  322. res1 = MIN(res1, activation_max);
  323. *dst = (int8_t)res0;
  324. *(dst + 2 * address_offset * ((rhs_rows) / 4)) = (int8_t)res1;
  325. dst += address_offset;
  326. res0 = spillover0;
  327. res1 = spillover1;
  328. if (bias)
  329. {
  330. res0 += *bias;
  331. res1 += bias[2 * (rhs_rows / 4)];
  332. ++bias;
  333. }
  334. for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols / 2; ++rhs_cols_idx)
  335. {
  336. const int32_t rhs_low0 = (int8_t)(rhs_ptr[0] << 4) >> 4;
  337. const int32_t rhs_high0 = rhs_ptr[0] >> 4;
  338. const int32_t rhs_low1 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
  339. const int32_t rhs_high1 = rhs_ptr[rhs_offset] >> 4;
  340. const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
  341. const int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
  342. res0 += lhs_low * rhs_low0;
  343. res0 += lhs_high * rhs_high0;
  344. res1 += lhs_low * rhs_low1;
  345. res1 += lhs_high * rhs_high1;
  346. ++rhs_ptr;
  347. lhs_ptr += 2;
  348. }
  349. // Quantize down
  350. res0 = arm_nn_requantize(res0, dst_multiplier, dst_shift);
  351. res1 = arm_nn_requantize(res1, dst_multiplier, dst_shift);
  352. // Add offset
  353. res0 += dst_offset;
  354. res1 += dst_offset;
  355. // Clamp the result
  356. res0 = MAX(res0, activation_min);
  357. res0 = MIN(res0, activation_max);
  358. res1 = MAX(res1, activation_min);
  359. res1 = MIN(res1, activation_max);
  360. *dst = (int8_t)res0;
  361. *(dst + 2 * address_offset * ((rhs_rows) / 4)) = (int8_t)res1;
  362. dst += address_offset;
  363. }
  364. const int8_t *lhs_ptr = &lhs[0];
  365. spillover0 = 0;
  366. for (int i_row_loop_cnt = 0; i_row_loop_cnt < rhs_rows % 4; ++i_row_loop_cnt)
  367. {
  368. int32_t res0 = spillover0;
  369. if (bias)
  370. {
  371. res0 += bias[2 * (rhs_rows / 4)];
  372. ++bias;
  373. }
  374. for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols / 2; ++rhs_cols_idx)
  375. {
  376. const int32_t rhs_low0 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
  377. const int32_t rhs_high0 = rhs_ptr[rhs_offset] >> 4;
  378. const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
  379. const int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
  380. res0 += lhs_low * rhs_low0;
  381. res0 += lhs_high * rhs_high0;
  382. ++rhs_ptr;
  383. lhs_ptr += 2;
  384. }
  385. if ((rhs_cols % 2 == 1) && (i_row_loop_cnt % 2 == 0))
  386. {
  387. const int32_t rhs_low0 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
  388. const int32_t rhs_high0 = rhs_ptr[rhs_offset] >> 4;
  389. const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
  390. lhs_ptr = &lhs[0];
  391. const int32_t lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;
  392. ++lhs_ptr;
  393. res0 += lhs_low * rhs_low0;
  394. spillover0 = lhs_high * rhs_high0;
  395. ++rhs_ptr;
  396. }
  397. else
  398. {
  399. spillover0 = 0;
  400. lhs_ptr = &lhs[0];
  401. }
  402. // Quantize down
  403. res0 = arm_nn_requantize(res0, dst_multiplier, dst_shift);
  404. // Add offset
  405. res0 += dst_offset;
  406. // Clamp the result
  407. res0 = MAX(res0, activation_min);
  408. res0 = MIN(res0, activation_max);
  409. *(dst + 2 * address_offset * ((rhs_rows) / 4)) = (int8_t)res0;
  410. dst += address_offset;
  411. }
  412. #endif
  413. return ARM_CMSIS_NN_SUCCESS;
  414. }
  415. /**
  416. * @} end of Doxygen group
  417. */