/* arm_fully_connected_q15_opt.c */
/*
 * Copyright (C) 2010-2021 Arm Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_fully_connected_q15_opt.c
 * Description:  Q15 opt fully-connected layer function
 *
 * $Date:        20. July 2021
 * $Revision:    V.1.1.1
 *
 * Target Processor:  Cortex-M cores
 *
 * -------------------------------------------------------------------- */
  29. #include "arm_nnfunctions.h"
  30. #include "arm_nnsupportfunctions.h"
  31. /**
  32. * @ingroup groupNN
  33. */
  34. /**
  35. * @addtogroup FC
  36. * @{
  37. */
  38. /**
  39. * @brief Q15 opt fully-connected layer function
  40. * @param[in] pV pointer to input vector
  41. * @param[in] pM pointer to matrix weights
  42. * @param[in] dim_vec length of the vector
  43. * @param[in] num_of_rows number of rows in weight matrix
  44. * @param[in] bias_shift amount of left-shift for bias
  45. * @param[in] out_shift amount of right-shift for output
  46. * @param[in] bias pointer to bias
  47. * @param[in,out] pOut pointer to output vector
  48. * @param[in,out] vec_buffer pointer to buffer space for input
  49. * @return The function returns <code>ARM_MATH_SUCCESS</code>
  50. *
  51. *
  52. * @details
  53. *
  54. * <b>Buffer size:</b>
  55. *
  56. * vec_buffer size: 0
  57. *
  58. * Here we use only one pointer to read 4 rows in the weight
  59. * matrix. So if the original matrix looks like this:
  60. *
  61. * | a11 | a12 | a13 |
  62. *
  63. * | a21 | a22 | a23 |
  64. *
  65. * | a31 | a32 | a33 |
  66. *
  67. * | a41 | a42 | a43 |
  68. *
  69. * | a51 | a52 | a53 |
  70. *
  71. * | a61 | a62 | a63 |
  72. *
  73. * We operates on multiple-of-4 rows, so the first four rows becomes
  74. *
  75. * | a11 | a12 | a21 | a22 | a31 | a32 | a41 | a42 |
  76. *
  77. * | a13 | a23 | a33 | a43 |
  78. *
  79. * Remaining rows are kept the same original order.
  80. *
  81. * So the stored weight matrix looks like this:
  82. *
  83. *
  84. * | a11 | a12 | a21 | a22 | a31 | a32 | a41 | a42 |
  85. *
  86. * | a13 | a23 | a33 | a43 | a51 | a52 | a53 | a61 |
  87. *
  88. * | a62 | a63 |
  89. */
  90. arm_status arm_fully_connected_q15_opt(const q15_t *pV,
  91. const q15_t *pM,
  92. const uint16_t dim_vec,
  93. const uint16_t num_of_rows,
  94. const uint16_t bias_shift,
  95. const uint16_t out_shift,
  96. const q15_t *bias,
  97. q15_t *pOut,
  98. q15_t *vec_buffer)
  99. {
  100. (void)vec_buffer;
  101. #if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
  102. /* Run the following code for Cortex-M4 and Cortex-M7 */
  103. const q15_t *pB = pM;
  104. q15_t *pO = pOut;
  105. const q15_t *pBias = bias;
  106. const q15_t *pA = pV;
  107. uint16_t rowCnt = num_of_rows >> 2;
  108. while (rowCnt)
  109. {
  110. q31_t sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  111. q31_t sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  112. q31_t sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  113. q31_t sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  114. uint16_t colCnt = dim_vec >> 1;
  115. pA = pV;
  116. #ifdef USE_INTRINSIC
  117. while (colCnt)
  118. {
  119. q31_t inM11, inM12, inM13, inM14;
  120. q31_t inV;
  121. inV = arm_nn_read_q15x2_ia(&pA);
  122. inM11 = arm_nn_read_q15x2_ia(&pB);
  123. sum = __SMLAD(inV, inM11, sum);
  124. inM12 = arm_nn_read_q15x2_ia(&pB);
  125. sum2 = __SMLAD(inV, inM12, sum2);
  126. inM13 = arm_nn_read_q15x2_ia(&pB);
  127. sum3 = __SMLAD(inV, inM13, sum3);
  128. inM14 = arm_nn_read_q15x2_ia(&pB);
  129. sum4 = __SMLAD(inV, inM14, sum4);
  130. colCnt--;
  131. }
  132. #else
  133. /*
  134. * register needed:
  135. * loop counter: colCnt
  136. * accumulators: sum, sum2, sum3, sum4
  137. * pointers: pB, pA
  138. * weight data: inM11, inM12, inM13, inM14
  139. * activation data: inV
  140. */
  141. asm volatile("COL_LOOP_%=:\n"
  142. "ldr.w r4, [%[pA]], #4\n"
  143. "ldr.w r0, [%[pB]], #16\n"
  144. "smlad %[sum], r4, r0, %[sum]\n"
  145. "ldr.w r1, [%[pB] , #-12]\n"
  146. "smlad %[sum2], r4, r1, %[sum2]\n"
  147. "ldr.w r2, [%[pB] , #-8]\n"
  148. "smlad %[sum3], r4, r2, %[sum3]\n"
  149. "ldr.w r3, [%[pB] , #-4]\n"
  150. "smlad %[sum4], r4, r3, %[sum4]\n"
  151. "subs %[colCnt], #1\n"
  152. "bne COL_LOOP_%=\n"
  153. : [ sum ] "+r"(sum),
  154. [ sum2 ] "+r"(sum2),
  155. [ sum3 ] "+r"(sum3),
  156. [ sum4 ] "+r"(sum4),
  157. [ pB ] "+r"(pB),
  158. [ pA ] "+r"(pA)
  159. : [ colCnt ] "r"(colCnt)
  160. : "r0", "r1", "r2", "r3", "r4");
  161. #endif /* USE_INTRINSIC */
  162. colCnt = dim_vec & 0x1;
  163. while (colCnt)
  164. {
  165. q15_t inV = *pA++;
  166. q15_t inM = *pB++;
  167. q15_t inM2 = *pB++;
  168. q15_t inM3 = *pB++;
  169. q15_t inM4 = *pB++;
  170. sum += inV * inM;
  171. sum2 += inV * inM2;
  172. sum3 += inV * inM3;
  173. sum4 += inV * inM4;
  174. colCnt--;
  175. } /* while over colCnt */
  176. *pO++ = (q15_t)(__SSAT((sum >> out_shift), 16));
  177. *pO++ = (q15_t)(__SSAT((sum2 >> out_shift), 16));
  178. *pO++ = (q15_t)(__SSAT((sum3 >> out_shift), 16));
  179. *pO++ = (q15_t)(__SSAT((sum4 >> out_shift), 16));
  180. /* adjust the pointers and counters */
  181. rowCnt--;
  182. }
  183. /* left-over part of the rows */
  184. rowCnt = num_of_rows & 0x3;
  185. while (rowCnt)
  186. {
  187. q31_t sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  188. uint16_t colCnt = dim_vec >> 2;
  189. pA = pV;
  190. while (colCnt)
  191. {
  192. q31_t inV1, inV2, inM1, inM2;
  193. inM1 = arm_nn_read_q15x2_ia(&pB);
  194. inV1 = arm_nn_read_q15x2_ia(&pA);
  195. sum = __SMLAD(inV1, inM1, sum);
  196. inM2 = arm_nn_read_q15x2_ia(&pB);
  197. inV2 = arm_nn_read_q15x2_ia(&pA);
  198. sum = __SMLAD(inV2, inM2, sum);
  199. colCnt--;
  200. }
  201. /* left-over of the vector */
  202. colCnt = dim_vec & 0x3;
  203. while (colCnt)
  204. {
  205. q15_t inV = *pA++;
  206. q15_t inM = *pB++;
  207. sum += inV * inM;
  208. colCnt--;
  209. }
  210. *pO++ = (q15_t)(__SSAT((sum >> out_shift), 16));
  211. rowCnt--;
  212. }
  213. #else
  214. /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
  215. uint16_t rowCnt = num_of_rows >> 2;
  216. const q15_t *pB = pM;
  217. const q15_t *pA;
  218. q15_t *pO = pOut;
  219. const q15_t *pBias = bias;
  220. while (rowCnt)
  221. {
  222. q31_t sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  223. q31_t sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  224. q31_t sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  225. q31_t sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  226. uint16_t colCnt = dim_vec >> 1;
  227. pA = pV;
  228. while (colCnt)
  229. {
  230. q15_t inA1 = *pA++;
  231. q15_t inA2 = *pA++;
  232. q15_t inB1 = *pB++;
  233. q15_t inB2 = *pB++;
  234. sum += inA1 * inB1 + inA2 * inB2;
  235. inB1 = *pB++;
  236. inB2 = *pB++;
  237. sum2 += inA1 * inB1 + inA2 * inB2;
  238. inB1 = *pB++;
  239. inB2 = *pB++;
  240. sum3 += inA1 * inB1 + inA2 * inB2;
  241. inB1 = *pB++;
  242. inB2 = *pB++;
  243. sum4 += inA1 * inB1 + inA2 * inB2;
  244. colCnt--;
  245. }
  246. colCnt = dim_vec & 0x1;
  247. while (colCnt)
  248. {
  249. q15_t inA = *pA++;
  250. q15_t inB = *pB++;
  251. sum += inA * inB;
  252. inB = *pB++;
  253. sum2 += inA * inB;
  254. inB = *pB++;
  255. sum3 += inA * inB;
  256. inB = *pB++;
  257. sum4 += inA * inB;
  258. colCnt--;
  259. }
  260. *pO++ = (q15_t)__SSAT((sum >> out_shift), 16);
  261. *pO++ = (q15_t)__SSAT((sum2 >> out_shift), 16);
  262. *pO++ = (q15_t)__SSAT((sum3 >> out_shift), 16);
  263. *pO++ = (q15_t)__SSAT((sum4 >> out_shift), 16);
  264. rowCnt--;
  265. }
  266. rowCnt = num_of_rows & 0x3;
  267. while (rowCnt)
  268. {
  269. int ip_out = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
  270. int j;
  271. pA = pV;
  272. for (j = 0; j < dim_vec; j++)
  273. {
  274. q15_t inA = *pA++;
  275. q15_t inB = *pB++;
  276. ip_out += inA * inB;
  277. }
  278. *pO++ = (q15_t)__SSAT((ip_out >> out_shift), 16);
  279. rowCnt--;
  280. }
  281. #endif /* ARM_MATH_DSP */
  282. /* Return to ARM_MATH_SUCCESS */
  283. return (ARM_MATH_SUCCESS);
  284. }
  285. /**
  286. * @} end of FC group
  287. */