/*
 * SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_kernel_s4_s16.c
 * Description:  Matrix-multiplication function for convolution
 *
 * $Date:        01 November 2023
 * $Revision:    V.1.0.0
 *
 * Target :  Arm(R) M-Profile Architecture
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/*
 * Matrix-multiplication function for convolution with per-channel requantization and 4bit weights.
 *
 * Refer header file for details.
 *
 */
/*
 * Matrix-multiplication kernel used by the s16 convolution with packed int4
 * (s4) weights and per-channel requantization.
 *
 * Two im2col input columns are processed at once: input_b holds two int16
 * vectors of num_col_a elements back to back. The weight matrix
 * packed_input_a stores two 4-bit values per byte, low nibble first (see the
 * `<< 4` / `>> 4` sign-extension of the low nibble in the scalar loops).
 *
 * packed_input_a   - packed s4 weight matrix: output_ch rows of num_col_a
 *                    nibbles stored contiguously (rows share a byte when
 *                    num_col_a is odd).
 * input_b          - two int16 input vectors, each num_col_a elements.
 * output_ch        - number of output channels (rows of A).
 * out_shift        - per-channel requantization shifts (output_ch entries).
 * out_mult         - per-channel requantization multipliers (output_ch entries).
 * out_offset       - output zero-point offset added after requantization.
 * activation_min   - lower clamp bound for the s8 output.
 * activation_max   - upper clamp bound for the s8 output.
 * num_col_a        - number of columns of A, i.e. length of one input vector.
 * output_bias      - optional per-channel bias; may be NULL.
 * out_0            - destination of the first output pixel; the second pixel
 *                    is written at out_0 + output_ch.
 *
 * Returns a pointer just past the two written pixels (out_0 + 2 * output_ch).
 */
int8_t *arm_nn_mat_mult_kernel_s4_s16(const int8_t *packed_input_a,
                                      const int16_t *input_b,
                                      const uint16_t output_ch,
                                      const int32_t *out_shift,
                                      const int32_t *out_mult,
                                      const int32_t out_offset,
                                      const int32_t activation_min,
                                      const int32_t activation_max,
                                      const int32_t num_col_a,
                                      const int32_t *const output_bias,
                                      int8_t *out_0)
{
    /* set up the second output pointers */
    int8_t *out_1 = out_0 + output_ch;
    const int32_t *bias = output_bias;

    /* Four weight rows (output channels) are consumed per iteration, in two
     * passes of two rows each: first rows N and N + 2, then rows N + 1 and
     * N + 3. The bias/multiplier/shift pointers and the output pointers all
     * follow the same 0, 2, 1, 3 order - see the +=2 / -- pointer steps below. */
    uint16_t row_count = output_ch / 4;
    const int8_t *packed_ip_a0 = packed_input_a;

    /* this loop over rows in A */
    while (row_count)
    {
        /* With an odd num_col_a the last nibble of a row shares its byte with
         * the first nibble of the following row. That nibble is parked here
         * after the first pass and consumed at the start of the second pass. */
        int8_t spillover0 = 0;
        int8_t spillover1 = 0;

        /* setup pointers for B */
        const int16_t *ip_b0 = input_b;
        const int16_t *ip_b1 = ip_b0 + num_col_a;

        /* Align the second pointer for A.
         * This will skip a row so that we can ensure the that spilled rows
         * don't offset the symmetry.
         * (num_col_a bytes hold two rows of nibbles, so packed_ip_a1 points
         * at row N + 2 while packed_ip_a0 points at row N.)
         */
        const int8_t *packed_ip_a1 = packed_ip_a0 + num_col_a;

        int32_t ch_0_out_0 = 0;
        int32_t ch_0_out_1 = 0;
        int32_t ch_1_out_0 = 0;
        int32_t ch_1_out_1 = 0;

        /* Init accumulator with bias for channel N and N + 1 */
        /* Note the interleaving: this pass loads bias[N] and bias[N + 2];
         * the trailing decrement leaves bias at element N + 1 for the
         * second pass. */
        if (bias)
        {
            ch_0_out_0 = *bias;
            ch_0_out_1 = *bias;
            bias += 2;
            ch_1_out_0 = *bias;
            ch_1_out_1 = *bias--;
        }

#if defined(ARM_MATH_DSP)
        /* SIMD path: 4 columns (2 packed weight bytes) per iteration. */
        int32_t col_count = num_col_a / 4;
        /* accumulate over the vector */
        while (col_count)
        {
            int32_t a01, a02, a11, a12;
            int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            /* Unpack 4 nibbles from each weight row into two 16x2 words. */
            read_and_pad_s4_ordered(packed_ip_a0, &a01, &a02);
            read_and_pad_s4_ordered(packed_ip_a1, &a11, &a12);
            packed_ip_a0 += 2;
            packed_ip_a1 += 2;

            ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a11, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a11, b1, ch_1_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a12, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a12, b1, ch_1_out_1);

            col_count--;
        } /* while over col_count */

        /* At most one full byte (2 columns) left for the scalar loop. */
        col_count = (num_col_a & 0x3) >> 1;
#else
        int32_t col_count = num_col_a >> 1;
#endif
        /* Scalar path: one packed byte (2 columns) per iteration. */
        while (col_count)
        {
            /* Sign-extend the low nibble; arithmetic shift extracts the high one. */
            int8_t lower_a0 = (int8_t)(packed_ip_a0[0] << 4) >> 4;
            int8_t higher_a0 = packed_ip_a0[0] >> 4;
            int16_t b0 = *ip_b0++;

            int8_t lower_a1 = (int8_t)(packed_ip_a1[0] << 4) >> 4;
            int8_t higher_a1 = packed_ip_a1[0] >> 4;
            int16_t b1 = *ip_b1++;

            packed_ip_a0++;
            packed_ip_a1++;

            ch_0_out_0 += lower_a0 * b0;
            ch_0_out_1 += lower_a0 * b1;
            ch_1_out_0 += lower_a1 * b0;
            ch_1_out_1 += lower_a1 * b1;

            b0 = *ip_b0++;
            b1 = *ip_b1++;

            ch_0_out_0 += higher_a0 * b0;
            ch_0_out_1 += higher_a0 * b1;
            ch_1_out_0 += higher_a1 * b0;
            ch_1_out_1 += higher_a1 * b1;

            col_count--;
        } /* while over col_count */

        /* left over column */
        if (num_col_a % 2)
        {
            int8_t lower_a0 = (int8_t)(packed_ip_a0[0] << 4) >> 4;
            /* The high nibble of this byte belongs to the next row; save it
             * for the second pass instead of consuming it here. */
            spillover0 = packed_ip_a0[0] >> 4;
            int16_t b0 = *ip_b0++;

            int8_t lower_a1 = (int8_t)(packed_ip_a1[0] << 4) >> 4;
            spillover1 = packed_ip_a1[0] >> 4;
            int16_t b1 = *ip_b1++;

            packed_ip_a0++;
            packed_ip_a1++;

            ch_0_out_0 += lower_a0 * b0;
            ch_0_out_1 += lower_a0 * b1;
            ch_1_out_0 += lower_a1 * b0;
            ch_1_out_1 += lower_a1 * b1;
        }

        /* Requantize, add output offset, clamp and store channel N for both
         * output pixels. */
        ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
        ch_0_out_0 += out_offset;
        ch_0_out_0 = MAX(ch_0_out_0, activation_min);
        ch_0_out_0 = MIN(ch_0_out_0, activation_max);
        *out_0 = (int8_t)ch_0_out_0;
        out_0 += 2;

        ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
        ch_0_out_1 += out_offset;
        ch_0_out_1 = MAX(ch_0_out_1, activation_min);
        ch_0_out_1 = MIN(ch_0_out_1, activation_max);
        *out_1 = (int8_t)ch_0_out_1;
        out_1 += 2;
        out_mult += 2;
        out_shift += 2;

        /* Store channel N + 2. The post-decrements leave the output,
         * multiplier and shift pointers at channel N + 1 for the second pass. */
        ch_1_out_0 = arm_nn_requantize(ch_1_out_0, *out_mult, *out_shift);
        ch_1_out_0 += out_offset;
        ch_1_out_0 = MAX(ch_1_out_0, activation_min);
        ch_1_out_0 = MIN(ch_1_out_0, activation_max);
        *out_0-- = (int8_t)ch_1_out_0;

        ch_1_out_1 = arm_nn_requantize(ch_1_out_1, *out_mult, *out_shift);
        ch_1_out_1 += out_offset;
        ch_1_out_1 = MAX(ch_1_out_1, activation_min);
        ch_1_out_1 = MIN(ch_1_out_1, activation_max);
        *out_1-- = (int8_t)ch_1_out_1;
        out_mult--;
        out_shift--;

        /* Second pass over the same inputs: rows N + 1 and N + 3. */
        /* setup pointers for B */
        ip_b0 = input_b;
        ip_b1 = ip_b0 + num_col_a;

        /* Align the second pointer for A.
         * This will skip a row so that we can ensure the that spilled rows
         * don't offset the symmetry.
         */
        packed_ip_a1 = packed_ip_a0 + num_col_a;

        ch_0_out_0 = 0;
        ch_0_out_1 = 0;
        ch_1_out_0 = 0;
        ch_1_out_1 = 0;

        /* Init accumulator with bias for channel N and N + 1 */
        /* This pass loads bias[N + 1] and bias[N + 3]; the trailing increment
         * leaves bias at element N + 4 for the next group of 4 rows. */
        if (bias)
        {
            ch_0_out_0 = *bias;
            ch_0_out_1 = *bias;
            bias += 2;
            ch_1_out_0 = *bias;
            ch_1_out_1 = *bias++;
        }

        /* Consume the nibbles spilled over from the first pass (column 0 of
         * rows N + 1 and N + 3 when num_col_a is odd). */
        if (num_col_a % 2)
        {
            int16_t b0 = *ip_b0++;
            int16_t b1 = *ip_b1++;

            ch_0_out_0 += spillover0 * b0;
            ch_0_out_1 += spillover0 * b1;
            ch_1_out_0 += spillover1 * b0;
            ch_1_out_1 += spillover1 * b1;
        }

#if defined(ARM_MATH_DSP)
        col_count = num_col_a / 4;
        /* accumulate over the vector */
        while (col_count)
        {
            int32_t a01, a02, a11, a12;
            int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            read_and_pad_s4_ordered(packed_ip_a0, &a01, &a02);
            read_and_pad_s4_ordered(packed_ip_a1, &a11, &a12);
            packed_ip_a0 += 2;
            packed_ip_a1 += 2;

            ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a11, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a11, b1, ch_1_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a12, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a12, b1, ch_1_out_1);

            col_count--;
        } /* while over col_count */

        col_count = (num_col_a & 0x3) >> 1;
#else
        col_count = num_col_a >> 1;
#endif
        while (col_count)
        {
            int8_t lower_a0 = (int8_t)(packed_ip_a0[0] << 4) >> 4;
            int8_t higher_a0 = packed_ip_a0[0] >> 4;
            int16_t b0 = *ip_b0++;

            int8_t lower_a1 = (int8_t)(packed_ip_a1[0] << 4) >> 4;
            int8_t higher_a1 = packed_ip_a1[0] >> 4;
            int16_t b1 = *ip_b1++;

            packed_ip_a0++;
            packed_ip_a1++;

            ch_0_out_0 += lower_a0 * b0;
            ch_0_out_1 += lower_a0 * b1;
            ch_1_out_0 += lower_a1 * b0;
            ch_1_out_1 += lower_a1 * b1;

            b0 = *ip_b0++;
            b1 = *ip_b1++;

            ch_0_out_0 += higher_a0 * b0;
            ch_0_out_1 += higher_a0 * b1;
            ch_1_out_0 += higher_a1 * b0;
            ch_1_out_1 += higher_a1 * b1;

            col_count--;
        } /* while over col_count */

        /* Requantize, clamp and store channel N + 1, then channel N + 3.
         * The post-increments leave the pointers at channel N + 4. */
        ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
        ch_0_out_0 += out_offset;
        ch_0_out_0 = MAX(ch_0_out_0, activation_min);
        ch_0_out_0 = MIN(ch_0_out_0, activation_max);
        *out_0 = (int8_t)ch_0_out_0;
        out_0 += 2;

        ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
        ch_0_out_1 += out_offset;
        ch_0_out_1 = MAX(ch_0_out_1, activation_min);
        ch_0_out_1 = MIN(ch_0_out_1, activation_max);
        *out_1 = (int8_t)ch_0_out_1;
        out_1 += 2;
        out_mult += 2;
        out_shift += 2;

        ch_1_out_0 = arm_nn_requantize(ch_1_out_0, *out_mult, *out_shift);
        ch_1_out_0 += out_offset;
        ch_1_out_0 = MAX(ch_1_out_0, activation_min);
        ch_1_out_0 = MIN(ch_1_out_0, activation_max);
        *out_0++ = (int8_t)ch_1_out_0;

        ch_1_out_1 = arm_nn_requantize(ch_1_out_1, *out_mult, *out_shift);
        ch_1_out_1 += out_offset;
        ch_1_out_1 = MAX(ch_1_out_1, activation_min);
        ch_1_out_1 = MIN(ch_1_out_1, activation_max);
        *out_1++ = (int8_t)ch_1_out_1;
        out_mult++;
        out_shift++;

        /* skip 2 rows */
        /* (rows N + 2 and N + 3 were already consumed via packed_ip_a1) */
        packed_ip_a0 += num_col_a;

        row_count--;
    }

    /* compute the 0 - 3 rows if any */
    /* Leftover channels are processed one row at a time, in natural order.
     * With an odd num_col_a, odd-indexed leftover rows start mid-byte (high
     * nibble first) and even-indexed ones end mid-byte (low nibble last). */
    int16_t left_over_rows = 0;
    while (left_over_rows < output_ch % 4)
    {
        /* setup pointers for B */
        const int16_t *ip_b0 = input_b;
        const int16_t *ip_b1 = ip_b0 + num_col_a;

        int32_t ch_0_out_0 = 0;
        int32_t ch_0_out_1 = 0;

        /* load the bias */
        if (bias)
        {
            ch_0_out_0 = *bias;
            ch_0_out_1 = *bias++;
        }

        /* Row starts on a high nibble: consume the spilled column first. */
        if (left_over_rows == 1 && num_col_a % 2)
        {
            int16_t b0 = *ip_b0++;
            int16_t b1 = *ip_b1++;
            int8_t spilled_column = packed_ip_a0[0] >> 4;
            ++packed_ip_a0;

            ch_0_out_0 += spilled_column * b0;
            ch_0_out_1 += spilled_column * b1;
        }

#if defined(ARM_MATH_DSP)
        int32_t col_count = num_col_a / 4;
        while (col_count)
        {
            int32_t a01, a02;
            int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            read_and_pad_s4_ordered(packed_ip_a0, &a01, &a02);
            packed_ip_a0 += 2;

            ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);

            col_count--;
        }

        col_count = (num_col_a & 0x3) >> 1;
#else
        int32_t col_count = num_col_a >> 1;
#endif
        while (col_count)
        {
            int8_t a0 = (int8_t)(packed_ip_a0[0] << 4) >> 4;
            int8_t a1 = packed_ip_a0[0] >> 4;

            int16_t b0 = *ip_b0++;
            int16_t b1 = *ip_b1++;
            ++packed_ip_a0;

            ch_0_out_0 += a0 * b0;
            ch_0_out_1 += a0 * b1;

            b0 = *ip_b0++;
            b1 = *ip_b1++;

            ch_0_out_0 += a1 * b0;
            ch_0_out_1 += a1 * b1;

            col_count--;
        }

        /* Row ends on a low nibble: consume the final column. The byte is
         * not advanced past, so the next row picks up its high nibble. */
        if (num_col_a % 2 && left_over_rows != 1)
        {
            int8_t a0 = (int8_t)(packed_ip_a0[0] << 4) >> 4;

            int16_t b0 = *ip_b0++;
            int16_t b1 = *ip_b1++;

            ch_0_out_0 += a0 * b0;
            ch_0_out_1 += a0 * b1;
        }

        ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
        ch_0_out_0 += out_offset;
        ch_0_out_0 = MAX(ch_0_out_0, activation_min);
        ch_0_out_0 = MIN(ch_0_out_0, activation_max);
        *out_0++ = (int8_t)ch_0_out_0;

        ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
        ch_0_out_1 += out_offset;
        ch_0_out_1 = MAX(ch_0_out_1, activation_min);
        ch_0_out_1 = MIN(ch_0_out_1, activation_max);
        *out_1++ = (int8_t)ch_0_out_1;

        out_mult++;
        out_shift++;

        ++left_over_rows;
    }

    /* Skip past the second pixel's channels so the caller gets the next
     * free output position. */
    out_0 += output_ch;

    /* return the new output pointer with offset */
    return out_0;
}