arm_mat_vec_mult_q31.c 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. /* ----------------------------------------------------------------------
  2. * Project: CMSIS DSP Library
  3. * Title: arm_mat_vec_mult_q31.c
  4. * Description: Q31 matrix and vector multiplication
  5. *
  6. * $Date: 23 April 2021
  7. *
  8. * $Revision: V1.9.0
  9. *
  10. * Target Processor: Cortex-M and Cortex-A cores
  11. * -------------------------------------------------------------------- */
  12. /*
  13. * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
  14. *
  15. * SPDX-License-Identifier: Apache-2.0
  16. *
  17. * Licensed under the Apache License, Version 2.0 (the License); you may
  18. * not use this file except in compliance with the License.
  19. * You may obtain a copy of the License at
  20. *
  21. * www.apache.org/licenses/LICENSE-2.0
  22. *
  23. * Unless required by applicable law or agreed to in writing, software
  24. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  25. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  26. * See the License for the specific language governing permissions and
  27. * limitations under the License.
  28. */
  29. #include "dsp/matrix_functions.h"
  30. /**
  31. * @ingroup groupMatrix
  32. */
  33. /**
  34. * @addtogroup MatrixVectMult
  35. * @{
  36. */
  37. /**
  38. * @brief Q31 matrix and vector multiplication.
  39. * @param[in] *pSrcMat points to the input matrix structure
  40. * @param[in] *pVec points to the input vector
  41. * @param[out] *pDst points to the output vector
  42. */
  43. #if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)
  44. void arm_mat_vec_mult_q31(
  45. const arm_matrix_instance_q31 * pSrcMat,
  46. const q31_t *pSrcVec,
  47. q31_t *pDstVec)
  48. {
  49. const q31_t *pMatSrc = pSrcMat->pData;
  50. const q31_t *pMat0, *pMat1;
  51. uint32_t numRows = pSrcMat->numRows;
  52. uint32_t numCols = pSrcMat->numCols;
  53. q31_t *px;
  54. int32_t row;
  55. uint16_t blkCnt; /* loop counters */
  56. row = numRows;
  57. px = pDstVec;
  58. /*
  59. * compute 3x64-bit accumulators per loop
  60. */
  61. while (row >= 3)
  62. {
  63. q31_t const *pMat0Vec, *pMat1Vec, *pMat2Vec, *pVec;
  64. const q31_t *pMat2;
  65. q31_t const *pSrcVecPtr = pSrcVec;
  66. q63_t acc0, acc1, acc2;
  67. q31x4_t vecMatA0, vecMatA1, vecMatA2, vecIn;
  68. pVec = pSrcVec;
  69. /*
  70. * Initialize the pointer pIn1 to point to the starting address of the column being processed
  71. */
  72. pMat0 = pMatSrc;
  73. pMat1 = pMat0 + numCols;
  74. pMat2 = pMat1 + numCols;
  75. acc0 = 0LL;
  76. acc1 = 0LL;
  77. acc2 = 0LL;
  78. pMat0Vec = pMat0;
  79. pMat1Vec = pMat1;
  80. pMat2Vec = pMat2;
  81. pVec = pSrcVecPtr;
  82. blkCnt = numCols >> 2;
  83. while (blkCnt > 0U)
  84. {
  85. vecMatA0 = vld1q(pMat0Vec);
  86. pMat0Vec += 4;
  87. vecMatA1 = vld1q(pMat1Vec);
  88. pMat1Vec += 4;
  89. vecMatA2 = vld1q(pMat2Vec);
  90. pMat2Vec += 4;
  91. vecIn = vld1q(pVec);
  92. pVec += 4;
  93. acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
  94. acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
  95. acc2 = vmlaldavaq(acc2, vecIn, vecMatA2);
  96. blkCnt--;
  97. }
  98. /*
  99. * tail
  100. * (will be merged thru tail predication)
  101. */
  102. blkCnt = numCols & 3;
  103. if (blkCnt > 0U)
  104. {
  105. mve_pred16_t p0 = vctp32q(blkCnt);
  106. vecMatA0 = vld1q(pMat0Vec);
  107. vecMatA1 = vld1q(pMat1Vec);
  108. vecMatA2 = vld1q(pMat2Vec);
  109. vecIn = vldrwq_z_s32(pVec, p0);
  110. acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
  111. acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
  112. acc2 = vmlaldavaq(acc2, vecIn, vecMatA2);
  113. }
  114. *px++ = asrl(acc0, 31);
  115. *px++ = asrl(acc1, 31);
  116. *px++ = asrl(acc2, 31);
  117. pMatSrc += numCols * 3;
  118. /*
  119. * Decrement the row loop counter
  120. */
  121. row -= 3;
  122. }
  123. /*
  124. * process any remaining rows pair
  125. */
  126. if (row >= 2)
  127. {
  128. q31_t const *pMat0Vec, *pMat1Vec, *pVec;
  129. q31_t const *pSrcVecPtr = pSrcVec;
  130. q63_t acc0, acc1;
  131. q31x4_t vecMatA0, vecMatA1, vecIn;
  132. /*
  133. * For every row wise process, the pInVec pointer is set
  134. * to the starting address of the vector
  135. */
  136. pVec = pSrcVec;
  137. /*
  138. * Initialize the pointer pIn1 to point to the starting address of the column being processed
  139. */
  140. pMat0 = pMatSrc;
  141. pMat1 = pMat0 + numCols;
  142. acc0 = 0LL;
  143. acc1 = 0LL;
  144. pMat0Vec = pMat0;
  145. pMat1Vec = pMat1;
  146. pVec = pSrcVecPtr;
  147. blkCnt = numCols >> 2;
  148. while (blkCnt > 0U)
  149. {
  150. vecMatA0 = vld1q(pMat0Vec);
  151. pMat0Vec += 4;
  152. vecMatA1 = vld1q(pMat1Vec);
  153. pMat1Vec += 4;
  154. vecIn = vld1q(pVec);
  155. pVec += 4;
  156. acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
  157. acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
  158. blkCnt--;
  159. }
  160. /*
  161. * tail
  162. * (will be merged thru tail predication)
  163. */
  164. blkCnt = numCols & 3;
  165. if (blkCnt > 0U)
  166. {
  167. mve_pred16_t p0 = vctp32q(blkCnt);
  168. vecMatA0 = vld1q(pMat0Vec);
  169. vecMatA1 = vld1q(pMat1Vec);
  170. vecIn = vldrwq_z_s32(pVec, p0);
  171. acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
  172. acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
  173. }
  174. *px++ = asrl(acc0, 31);
  175. *px++ = asrl(acc1, 31);
  176. pMatSrc += numCols * 2;
  177. /*
  178. * Decrement the row loop counter
  179. */
  180. row -= 2;
  181. }
  182. if (row >= 1)
  183. {
  184. q31_t const *pMat0Vec, *pVec;
  185. q31_t const *pSrcVecPtr = pSrcVec;
  186. q63_t acc0;
  187. q31x4_t vecMatA0, vecIn;
  188. /*
  189. * For every row wise process, the pInVec pointer is set
  190. * to the starting address of the vector
  191. */
  192. pVec = pSrcVec;
  193. /*
  194. * Initialize the pointer pIn1 to point to the starting address of the column being processed
  195. */
  196. pMat0 = pMatSrc;
  197. acc0 = 0LL;
  198. pMat0Vec = pMat0;
  199. pVec = pSrcVecPtr;
  200. blkCnt = numCols >> 2;
  201. while (blkCnt > 0U)
  202. {
  203. vecMatA0 = vld1q(pMat0Vec);
  204. pMat0Vec += 4;
  205. vecIn = vld1q(pVec);
  206. pVec += 4;
  207. acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
  208. blkCnt--;
  209. }
  210. /*
  211. * tail
  212. * (will be merged thru tail predication)
  213. */
  214. blkCnt = numCols & 3;
  215. if (blkCnt > 0U)
  216. {
  217. mve_pred16_t p0 = vctp32q(blkCnt);
  218. vecMatA0 = vld1q(pMat0Vec);
  219. vecIn = vldrwq_z_s32(pVec, p0);
  220. acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
  221. }
  222. *px++ = asrl(acc0, 31);
  223. }
  224. }
  225. #else
  226. void arm_mat_vec_mult_q31(const arm_matrix_instance_q31 *pSrcMat, const q31_t *pVec, q31_t *pDst)
  227. {
  228. uint32_t numRows = pSrcMat->numRows;
  229. uint32_t numCols = pSrcMat->numCols;
  230. const q31_t *pSrcA = pSrcMat->pData;
  231. const q31_t *pInA1; /* input data matrix pointer A of Q31 type */
  232. const q31_t *pInA2; /* input data matrix pointer A of Q31 type */
  233. const q31_t *pInA3; /* input data matrix pointer A of Q31 type */
  234. const q31_t *pInA4; /* input data matrix pointer A of Q31 type */
  235. const q31_t *pInVec; /* input data matrix pointer B of Q31 type */
  236. q31_t *px; /* Temporary output data matrix pointer */
  237. uint16_t i, row, colCnt; /* loop counters */
  238. q31_t matData, matData2, vecData, vecData2;
  239. /* Process 4 rows at a time */
  240. row = numRows >> 2;
  241. i = 0u;
  242. px = pDst;
  243. /* The following loop performs the dot-product of each row in pSrcA with the vector */
  244. /* row loop */
  245. while (row > 0) {
  246. /* Initialize accumulators */
  247. q63_t sum1 = 0;
  248. q63_t sum2 = 0;
  249. q63_t sum3 = 0;
  250. q63_t sum4 = 0;
  251. /* For every row wise process, the pInVec pointer is set
  252. ** to the starting address of the vector */
  253. pInVec = pVec;
  254. /* Loop unrolling: process 2 columns per iteration */
  255. colCnt = numCols;
  256. /* Initialize pointers to the starting address of the column being processed */
  257. pInA1 = pSrcA + i;
  258. pInA2 = pInA1 + numCols;
  259. pInA3 = pInA2 + numCols;
  260. pInA4 = pInA3 + numCols;
  261. // Main loop: matrix-vector multiplication
  262. while (colCnt > 0u) {
  263. // Read 2 values from vector
  264. vecData = *(pInVec)++;
  265. // Read 8 values from the matrix - 2 values from each of 4 rows, and do multiply accumulate
  266. matData = *(pInA1)++;
  267. sum1 += (q63_t)matData * vecData;
  268. matData = *(pInA2)++;
  269. sum2 += (q63_t)matData * vecData;
  270. matData = *(pInA3)++;
  271. sum3 += (q63_t)matData * vecData;
  272. matData = *(pInA4)++;
  273. sum4 += (q63_t)matData * vecData;
  274. // Decrement the loop counter
  275. colCnt--;
  276. }
  277. /* Saturate and store the result in the destination buffer */
  278. *px++ = (q31_t)(sum1 >> 31);
  279. *px++ = (q31_t)(sum2 >> 31);
  280. *px++ = (q31_t)(sum3 >> 31);
  281. *px++ = (q31_t)(sum4 >> 31);
  282. i = i + numCols * 4;
  283. /* Decrement the row loop counter */
  284. row--;
  285. }
  286. /* process any remaining rows */
  287. row = numRows & 3u;
  288. while (row > 0) {
  289. q63_t sum = 0;
  290. pInVec = pVec;
  291. pInA1 = pSrcA + i;
  292. colCnt = numCols >> 1;
  293. while (colCnt > 0) {
  294. vecData = *(pInVec)++;
  295. vecData2 = *(pInVec)++;
  296. matData = *(pInA1)++;
  297. matData2 = *(pInA1)++;
  298. sum += (q63_t)matData * vecData;
  299. sum += (q63_t)matData2 * vecData2;
  300. colCnt--;
  301. }
  302. // process remainder of row
  303. colCnt = numCols & 1u;
  304. while (colCnt > 0) {
  305. sum += (q63_t)*pInA1++ * *pInVec++;
  306. colCnt--;
  307. }
  308. *px++ = (q31_t)(sum >> 31);
  309. i = i + numCols;
  310. row--;
  311. }
  312. }
  313. #endif /* defined(ARM_MATH_MVEI) */
  314. /**
  315. * @} end of MatrixVectMult group
  316. */