/* arm_mat_vec_mult_q15.c */
/* ----------------------------------------------------------------------
 * Project: CMSIS DSP Library
 * Title: arm_mat_vec_mult_q15.c
 * Description: Q15 matrix and vector multiplication
 *
 * $Date: 23 April 2021
 *
 * $Revision: V1.9.0
 *
 * Target Processor: Cortex-M and Cortex-A cores
 * -------------------------------------------------------------------- */
/*
 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dsp/matrix_functions.h"
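/**
 * @ingroup groupMatrix
 */

/**
 * @addtogroup MatrixVectMult
 * @{
 */

/**
 * @brief Q15 matrix and vector multiplication.
 * @param[in]  pSrcMat  points to the input matrix structure (numRows x numCols)
 * @param[in]  pVec     points to the input vector (numCols elements)
 * @param[out] pDst     points to the output vector (numRows elements)
 *
 * Each output element is the dot product of one matrix row with the input
 * vector, accumulated in 64 bits, shifted right by 15 and saturated to Q15.
 */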
  43. #if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)
  44. #include "arm_helium_utils.h"
  45. void arm_mat_vec_mult_q15(
  46. const arm_matrix_instance_q15 * pSrcMat,
  47. const q15_t *pSrcVec,
  48. q15_t *pDstVec)
  49. {
  50. const q15_t *pMatSrc = pSrcMat->pData;
  51. const q15_t *pMat0, *pMat1;
  52. uint32_t numRows = pSrcMat->numRows;
  53. uint32_t numCols = pSrcMat->numCols;
  54. q15_t *px;
  55. int32_t row;
  56. uint16_t blkCnt; /* loop counters */
  57. row = numRows;
  58. px = pDstVec;
  59. /*
  60. * compute 3x64-bit accumulators per loop
  61. */
  62. while (row >= 3)
  63. {
  64. q15_t const *pMat0Vec, *pMat1Vec, *pMat2Vec, *pVec;
  65. const q15_t *pMat2;
  66. q15_t const *pSrcVecPtr = pSrcVec;
  67. q63_t acc0, acc1, acc2;
  68. q15x8_t vecMatA0, vecMatA1, vecMatA2, vecIn;
  69. pVec = pSrcVec;
  70. /*
  71. * Initialize the pointer pIn1 to point to the starting address of the column being processed
  72. */
  73. pMat0 = pMatSrc;
  74. pMat1 = pMat0 + numCols;
  75. pMat2 = pMat1 + numCols;
  76. acc0 = 0LL;
  77. acc1 = 0LL;
  78. acc2 = 0LL;
  79. pMat0Vec = pMat0;
  80. pMat1Vec = pMat1;
  81. pMat2Vec = pMat2;
  82. pVec = pSrcVecPtr;
  83. blkCnt = numCols >> 3;
  84. while (blkCnt > 0U)
  85. {
  86. vecMatA0 = vld1q(pMat0Vec);
  87. pMat0Vec += 8;
  88. vecMatA1 = vld1q(pMat1Vec);
  89. pMat1Vec += 8;
  90. vecMatA2 = vld1q(pMat2Vec);
  91. pMat2Vec += 8;
  92. vecIn = vld1q(pVec);
  93. pVec += 8;
  94. acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
  95. acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
  96. acc2 = vmlaldavaq(acc2, vecIn, vecMatA2);
  97. blkCnt--;
  98. }
  99. /*
  100. * tail
  101. * (will be merged thru tail predication)
  102. */
  103. blkCnt = numCols & 7;
  104. if (blkCnt > 0U)
  105. {
  106. mve_pred16_t p0 = vctp16q(blkCnt);
  107. vecMatA0 = vld1q(pMat0Vec);
  108. vecMatA1 = vld1q(pMat1Vec);
  109. vecMatA2 = vld1q(pMat2Vec);
  110. vecIn = vldrhq_z_s16(pVec, p0);
  111. acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
  112. acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
  113. acc2 = vmlaldavaq(acc2, vecIn, vecMatA2);
  114. }
  115. *px++ = MVE_ASRL_SAT16(acc0, 15);
  116. *px++ = MVE_ASRL_SAT16(acc1, 15);
  117. *px++ = MVE_ASRL_SAT16(acc2, 15);
  118. pMatSrc += numCols * 3;
  119. /*
  120. * Decrement the row loop counter
  121. */
  122. row -= 3;
  123. }
  124. /*
  125. * process any remaining rows pair
  126. */
  127. if (row >= 2)
  128. {
  129. q15_t const *pMat0Vec, *pMat1Vec, *pVec;
  130. q15_t const *pSrcVecPtr = pSrcVec;
  131. q63_t acc0, acc1;
  132. q15x8_t vecMatA0, vecMatA1, vecIn;
  133. /*
  134. * For every row wise process, the pInVec pointer is set
  135. * to the starting address of the vector
  136. */
  137. pVec = pSrcVec;
  138. /*
  139. * Initialize the pointer pIn1 to point to the starting address of the column being processed
  140. */
  141. pMat0 = pMatSrc;
  142. pMat1 = pMat0 + numCols;
  143. acc0 = 0LL;
  144. acc1 = 0LL;
  145. pMat0Vec = pMat0;
  146. pMat1Vec = pMat1;
  147. pVec = pSrcVecPtr;
  148. blkCnt = numCols >> 3;
  149. while (blkCnt > 0U)
  150. {
  151. vecMatA0 = vld1q(pMat0Vec);
  152. pMat0Vec += 8;
  153. vecMatA1 = vld1q(pMat1Vec);
  154. pMat1Vec += 8;
  155. vecIn = vld1q(pVec);
  156. pVec += 8;
  157. acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
  158. acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
  159. blkCnt--;
  160. }
  161. /*
  162. * tail
  163. * (will be merged thru tail predication)
  164. */
  165. blkCnt = numCols & 7;
  166. if (blkCnt > 0U)
  167. {
  168. mve_pred16_t p0 = vctp16q(blkCnt);
  169. vecMatA0 = vld1q(pMat0Vec);
  170. vecMatA1 = vld1q(pMat1Vec);
  171. vecIn = vldrhq_z_s16(pVec, p0);
  172. acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
  173. acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
  174. }
  175. *px++ = MVE_ASRL_SAT16(acc0, 15);
  176. *px++ = MVE_ASRL_SAT16(acc1, 15);
  177. pMatSrc += numCols * 2;
  178. /*
  179. * Decrement the row loop counter
  180. */
  181. row -= 2;
  182. }
  183. if (row >= 1)
  184. {
  185. q15_t const *pMat0Vec, *pVec;
  186. q15_t const *pSrcVecPtr = pSrcVec;
  187. q63_t acc0;
  188. q15x8_t vecMatA0, vecIn;
  189. /*
  190. * For every row wise process, the pInVec pointer is set
  191. * to the starting address of the vector
  192. */
  193. pVec = pSrcVec;
  194. /*
  195. * Initialize the pointer pIn1 to point to the starting address of the column being processed
  196. */
  197. pMat0 = pMatSrc;
  198. acc0 = 0LL;
  199. pMat0Vec = pMat0;
  200. pVec = pSrcVecPtr;
  201. blkCnt = numCols >> 3;
  202. while (blkCnt > 0U)
  203. {
  204. vecMatA0 = vld1q(pMat0Vec);
  205. pMat0Vec += 8;
  206. vecIn = vld1q(pVec);
  207. pVec += 8;
  208. acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
  209. blkCnt--;
  210. }
  211. /*
  212. * tail
  213. * (will be merged thru tail predication)
  214. */
  215. blkCnt = numCols & 7;
  216. if (blkCnt > 0U)
  217. {
  218. mve_pred16_t p0 = vctp16q(blkCnt);
  219. vecMatA0 = vld1q(pMat0Vec);
  220. vecIn = vldrhq_z_s16(pVec, p0);
  221. acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
  222. }
  223. *px++ = MVE_ASRL_SAT16(acc0, 15);
  224. }
  225. }
  226. #else
  227. void arm_mat_vec_mult_q15(const arm_matrix_instance_q15 *pSrcMat, const q15_t *pVec, q15_t *pDst)
  228. {
  229. uint32_t numRows = pSrcMat->numRows;
  230. uint32_t numCols = pSrcMat->numCols;
  231. const q15_t *pSrcA = pSrcMat->pData;
  232. const q15_t *pInA1; /* input data matrix pointer A of Q15 type */
  233. const q15_t *pInA2; /* input data matrix pointer A of Q15 type */
  234. const q15_t *pInA3; /* input data matrix pointer A of Q15 type */
  235. const q15_t *pInA4; /* input data matrix pointer A of Q15 type */
  236. const q15_t *pInVec; /* input data matrix pointer B of Q15 type */
  237. q15_t *px; /* Temporary output data matrix pointer */
  238. uint16_t i, row, colCnt; /* loop counters */
  239. q31_t matData, matData2, vecData, vecData2;
  240. /* Process 4 rows at a time */
  241. row = numRows >> 2;
  242. i = 0u;
  243. px = pDst;
  244. /* The following loop performs the dot-product of each row in pSrcA with the vector */
  245. /* row loop */
  246. while (row > 0) {
  247. /* Initialize accumulators */
  248. q63_t sum1 = 0;
  249. q63_t sum2 = 0;
  250. q63_t sum3 = 0;
  251. q63_t sum4 = 0;
  252. /* For every row wise process, the pInVec pointer is set
  253. ** to the starting address of the vector */
  254. pInVec = pVec;
  255. /* Loop unrolling: process 2 columns per iteration */
  256. colCnt = numCols >> 1;
  257. /* Initialize pointers to the starting address of the column being processed */
  258. pInA1 = pSrcA + i;
  259. pInA2 = pInA1 + numCols;
  260. pInA3 = pInA2 + numCols;
  261. pInA4 = pInA3 + numCols;
  262. // Main loop: matrix-vector multiplication
  263. while (colCnt > 0u) {
  264. // Read 2 values from vector
  265. vecData = read_q15x2_ia (&pInVec);
  266. // Read 8 values from the matrix - 2 values from each of 4 rows, and do multiply accumulate
  267. matData = read_q15x2_ia (&pInA1);
  268. sum1 = __SMLALD(matData, vecData, sum1);
  269. matData = read_q15x2_ia (&pInA2);
  270. sum2 = __SMLALD(matData, vecData, sum2);
  271. matData = read_q15x2_ia (&pInA3);
  272. sum3 = __SMLALD(matData, vecData, sum3);
  273. matData = read_q15x2_ia (&pInA4);
  274. sum4 = __SMLALD(matData, vecData, sum4);
  275. // Decrement the loop counter
  276. colCnt--;
  277. }
  278. /* process any remaining columns */
  279. colCnt = numCols & 1u;
  280. if (numCols & 1u) {
  281. vecData = *pInVec++;
  282. sum1 += (q63_t)*pInA1++ * vecData;
  283. sum2 += (q63_t)*pInA2++ * vecData;
  284. sum3 += (q63_t)*pInA3++ * vecData;
  285. sum4 += (q63_t)*pInA4++ * vecData;
  286. }
  287. /* Saturate and store the result in the destination buffer */
  288. *px++ = (q15_t)(__SSAT((sum1 >> 15), 16));
  289. *px++ = (q15_t)(__SSAT((sum2 >> 15), 16));
  290. *px++ = (q15_t)(__SSAT((sum3 >> 15), 16));
  291. *px++ = (q15_t)(__SSAT((sum4 >> 15), 16));
  292. i = i + numCols * 4;
  293. /* Decrement the row loop counter */
  294. row--;
  295. }
  296. /* process any remaining rows */
  297. row = numRows & 3u;
  298. while (row > 0) {
  299. q63_t sum = 0;
  300. pInVec = pVec;
  301. pInA1 = pSrcA + i;
  302. // loop unrolling - process 4 elements at a time
  303. colCnt = numCols >> 2;
  304. while (colCnt > 0) {
  305. vecData = read_q15x2_ia (&pInVec);
  306. vecData2 = read_q15x2_ia (&pInVec);
  307. matData = read_q15x2_ia (&pInA1);
  308. matData2 = read_q15x2_ia (&pInA1);
  309. sum = __SMLALD(matData, vecData, sum);
  310. sum = __SMLALD(matData2, vecData2, sum);
  311. colCnt--;
  312. }
  313. // process remainder of row
  314. colCnt = numCols & 3u;
  315. while (colCnt > 0) {
  316. sum += (q63_t)*pInA1++ * *pInVec++;
  317. colCnt--;
  318. }
  319. *px++ = (q15_t)(__SSAT((sum >> 15), 16));
  320. i = i + numCols;
  321. row--;
  322. }
  323. }
  324. #endif /* defined(ARM_MATH_MVEI) */
/**
 * @} end of MatrixVectMult group
 */