arm_mat_vec_mult_f32.c 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. /* ----------------------------------------------------------------------
  2. * Project: CMSIS DSP Library
  3. * Title: arm_mat_vec_mult_f32.c
  4. * Description: Floating-point matrix and vector multiplication
  5. *
  6. * $Date: 23 April 2021
  7. *
  8. * $Revision: V1.9.0
  9. *
  10. * Target Processor: Cortex-M and Cortex-A cores
  11. * -------------------------------------------------------------------- */
  12. /*
  13. * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
  14. *
  15. * SPDX-License-Identifier: Apache-2.0
  16. *
  17. * Licensed under the Apache License, Version 2.0 (the License); you may
  18. * not use this file except in compliance with the License.
  19. * You may obtain a copy of the License at
  20. *
  21. * www.apache.org/licenses/LICENSE-2.0
  22. *
  23. * Unless required by applicable law or agreed to in writing, software
  24. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  25. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  26. * See the License for the specific language governing permissions and
  27. * limitations under the License.
  28. */
  29. #include "dsp/matrix_functions.h"
  30. /**
  31. * @ingroup groupMatrix
  32. */
  33. /**
  34. * @defgroup MatrixVectMult Matrix Vector Multiplication
  35. *
  36. * Multiplies a matrix and a vector.
  37. *
  38. */
  39. /**
  40. * @addtogroup MatrixVectMult
  41. * @{
  42. */
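/**
 * @brief Floating-point matrix and vector multiplication.
 *
 * Computes pDst = pSrcMat * pVec: output element r is the dot product of
 * matrix row r with the input vector.
 *
 * @param[in]  pSrcMat points to the input matrix structure (row-major data)
 * @param[in]  pVec    points to the input vector (pSrcMat->numCols elements)
 * @param[out] pDst    points to the output vector (pSrcMat->numRows elements)
 */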
  49. #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
  50. #include "arm_helium_utils.h"
  51. void arm_mat_vec_mult_f32(
  52. const arm_matrix_instance_f32 *pSrcMat,
  53. const float32_t *pSrcVec,
  54. float32_t *pDstVec)
  55. {
  56. uint32_t numRows = pSrcMat->numRows;
  57. uint32_t numCols = pSrcMat->numCols;
  58. const float32_t *pSrcA = pSrcMat->pData;
  59. const float32_t *pInA0;
  60. const float32_t *pInA1;
  61. float32_t *px;
  62. int32_t row;
  63. uint32_t blkCnt; /* loop counters */
  64. row = numRows;
  65. px = pDstVec;
  66. /*
  67. * compute 4 rows in parallel
  68. */
  69. while (row >= 4)
  70. {
  71. const float32_t *pInA2, *pInA3;
  72. float32_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec, *pInVec;
  73. f32x4_t vecIn, acc0, acc1, acc2, acc3;
  74. float32_t const *pSrcVecPtr = pSrcVec;
  75. /*
  76. * Initialize the pointers to 4 consecutive MatrixA rows
  77. */
  78. pInA0 = pSrcA;
  79. pInA1 = pInA0 + numCols;
  80. pInA2 = pInA1 + numCols;
  81. pInA3 = pInA2 + numCols;
  82. /*
  83. * Initialize the vector pointer
  84. */
  85. pInVec = pSrcVecPtr;
  86. /*
  87. * reset accumulators
  88. */
  89. acc0 = vdupq_n_f32(0.0f);
  90. acc1 = vdupq_n_f32(0.0f);
  91. acc2 = vdupq_n_f32(0.0f);
  92. acc3 = vdupq_n_f32(0.0f);
  93. pSrcA0Vec = pInA0;
  94. pSrcA1Vec = pInA1;
  95. pSrcA2Vec = pInA2;
  96. pSrcA3Vec = pInA3;
  97. blkCnt = numCols >> 2;
  98. while (blkCnt > 0U)
  99. {
  100. f32x4_t vecA;
  101. vecIn = vld1q(pInVec);
  102. pInVec += 4;
  103. vecA = vld1q(pSrcA0Vec);
  104. pSrcA0Vec += 4;
  105. acc0 = vfmaq(acc0, vecIn, vecA);
  106. vecA = vld1q(pSrcA1Vec);
  107. pSrcA1Vec += 4;
  108. acc1 = vfmaq(acc1, vecIn, vecA);
  109. vecA = vld1q(pSrcA2Vec);
  110. pSrcA2Vec += 4;
  111. acc2 = vfmaq(acc2, vecIn, vecA);
  112. vecA = vld1q(pSrcA3Vec);
  113. pSrcA3Vec += 4;
  114. acc3 = vfmaq(acc3, vecIn, vecA);
  115. blkCnt--;
  116. }
  117. /*
  118. * tail
  119. * (will be merged thru tail predication)
  120. */
  121. blkCnt = numCols & 3;
  122. if (blkCnt > 0U)
  123. {
  124. mve_pred16_t p0 = vctp32q(blkCnt);
  125. f32x4_t vecA;
  126. vecIn = vldrwq_z_f32(pInVec, p0);
  127. vecA = vld1q(pSrcA0Vec);
  128. acc0 = vfmaq(acc0, vecIn, vecA);
  129. vecA = vld1q(pSrcA1Vec);
  130. acc1 = vfmaq(acc1, vecIn, vecA);
  131. vecA = vld1q(pSrcA2Vec);
  132. acc2 = vfmaq(acc2, vecIn, vecA);
  133. vecA = vld1q(pSrcA3Vec);
  134. acc3 = vfmaq(acc3, vecIn, vecA);
  135. }
  136. /*
  137. * Sum the partial parts
  138. */
  139. *px++ = vecAddAcrossF32Mve(acc0);
  140. *px++ = vecAddAcrossF32Mve(acc1);
  141. *px++ = vecAddAcrossF32Mve(acc2);
  142. *px++ = vecAddAcrossF32Mve(acc3);
  143. pSrcA += numCols * 4;
  144. /*
  145. * Decrement the row loop counter
  146. */
  147. row -= 4;
  148. }
  149. /*
  150. * compute 2 rows in parrallel
  151. */
  152. if (row >= 2)
  153. {
  154. float32_t const *pSrcA0Vec, *pSrcA1Vec, *pInVec;
  155. f32x4_t vecIn, acc0, acc1;
  156. float32_t const *pSrcVecPtr = pSrcVec;
  157. /*
  158. * Initialize the pointers to 2 consecutive MatrixA rows
  159. */
  160. pInA0 = pSrcA;
  161. pInA1 = pInA0 + numCols;
  162. /*
  163. * Initialize the vector pointer
  164. */
  165. pInVec = pSrcVecPtr;
  166. /*
  167. * reset accumulators
  168. */
  169. acc0 = vdupq_n_f32(0.0f);
  170. acc1 = vdupq_n_f32(0.0f);
  171. pSrcA0Vec = pInA0;
  172. pSrcA1Vec = pInA1;
  173. blkCnt = numCols >> 2;
  174. while (blkCnt > 0U)
  175. {
  176. f32x4_t vecA;
  177. vecIn = vld1q(pInVec);
  178. pInVec += 4;
  179. vecA = vld1q(pSrcA0Vec);
  180. pSrcA0Vec += 4;
  181. acc0 = vfmaq(acc0, vecIn, vecA);
  182. vecA = vld1q(pSrcA1Vec);
  183. pSrcA1Vec += 4;
  184. acc1 = vfmaq(acc1, vecIn, vecA);
  185. blkCnt--;
  186. }
  187. /*
  188. * tail
  189. * (will be merged thru tail predication)
  190. */
  191. blkCnt = numCols & 3;
  192. if (blkCnt > 0U)
  193. {
  194. mve_pred16_t p0 = vctp32q(blkCnt);
  195. f32x4_t vecA;
  196. vecIn = vldrwq_z_f32(pInVec, p0);
  197. vecA = vld1q(pSrcA0Vec);
  198. acc0 = vfmaq(acc0, vecIn, vecA);
  199. vecA = vld1q(pSrcA1Vec);
  200. acc1 = vfmaq(acc1, vecIn, vecA);
  201. }
  202. /*
  203. * Sum the partial parts
  204. */
  205. *px++ = vecAddAcrossF32Mve(acc0);
  206. *px++ = vecAddAcrossF32Mve(acc1);
  207. pSrcA += numCols * 2;
  208. row -= 2;
  209. }
  210. if (row >= 1)
  211. {
  212. f32x4_t vecIn, acc0;
  213. float32_t const *pSrcA0Vec, *pInVec;
  214. float32_t const *pSrcVecPtr = pSrcVec;
  215. /*
  216. * Initialize the pointers to last MatrixA row
  217. */
  218. pInA0 = pSrcA;
  219. /*
  220. * Initialize the vector pointer
  221. */
  222. pInVec = pSrcVecPtr;
  223. /*
  224. * reset accumulators
  225. */
  226. acc0 = vdupq_n_f32(0.0f);
  227. pSrcA0Vec = pInA0;
  228. blkCnt = numCols >> 2;
  229. while (blkCnt > 0U)
  230. {
  231. f32x4_t vecA;
  232. vecIn = vld1q(pInVec);
  233. pInVec += 4;
  234. vecA = vld1q(pSrcA0Vec);
  235. pSrcA0Vec += 4;
  236. acc0 = vfmaq(acc0, vecIn, vecA);
  237. blkCnt--;
  238. }
  239. /*
  240. * tail
  241. * (will be merged thru tail predication)
  242. */
  243. blkCnt = numCols & 3;
  244. if (blkCnt > 0U)
  245. {
  246. mve_pred16_t p0 = vctp32q(blkCnt);
  247. f32x4_t vecA;
  248. vecIn = vldrwq_z_f32(pInVec, p0);
  249. vecA = vld1q(pSrcA0Vec);
  250. acc0 = vfmaq(acc0, vecIn, vecA);
  251. }
  252. /*
  253. * Sum the partial parts
  254. */
  255. *px++ = vecAddAcrossF32Mve(acc0);
  256. }
  257. }
  258. #else
  259. void arm_mat_vec_mult_f32(const arm_matrix_instance_f32 *pSrcMat, const float32_t *pVec, float32_t *pDst)
  260. {
  261. uint32_t numRows = pSrcMat->numRows;
  262. uint32_t numCols = pSrcMat->numCols;
  263. const float32_t *pSrcA = pSrcMat->pData;
  264. const float32_t *pInA1; /* input data matrix pointer A of Q31 type */
  265. const float32_t *pInA2; /* input data matrix pointer A of Q31 type */
  266. const float32_t *pInA3; /* input data matrix pointer A of Q31 type */
  267. const float32_t *pInA4; /* input data matrix pointer A of Q31 type */
  268. const float32_t *pInVec; /* input data matrix pointer B of Q31 type */
  269. float32_t *px; /* Temporary output data matrix pointer */
  270. uint16_t i, row, colCnt; /* loop counters */
  271. float32_t matData, matData2, vecData, vecData2;
  272. /* Process 4 rows at a time */
  273. row = numRows >> 2;
  274. i = 0u;
  275. px = pDst;
  276. /* The following loop performs the dot-product of each row in pSrcA with the vector */
  277. /* row loop */
  278. while (row > 0) {
  279. /* Initialize accumulators */
  280. float32_t sum1 = 0.0f;
  281. float32_t sum2 = 0.0f;
  282. float32_t sum3 = 0.0f;
  283. float32_t sum4 = 0.0f;
  284. /* For every row wise process, the pInVec pointer is set
  285. ** to the starting address of the vector */
  286. pInVec = pVec;
  287. /* Loop unrolling: process 2 columns per iteration */
  288. colCnt = numCols;
  289. /* Initialize pointers to the starting address of the column being processed */
  290. pInA1 = pSrcA + i;
  291. pInA2 = pInA1 + numCols;
  292. pInA3 = pInA2 + numCols;
  293. pInA4 = pInA3 + numCols;
  294. // Main loop: matrix-vector multiplication
  295. while (colCnt > 0u) {
  296. // Read 2 values from vector
  297. vecData = *(pInVec)++;
  298. // Read 8 values from the matrix - 2 values from each of 4 rows, and do multiply accumulate
  299. matData = *(pInA1)++;
  300. sum1 += matData * vecData;
  301. matData = *(pInA2)++;
  302. sum2 += matData * vecData;
  303. matData = *(pInA3)++;
  304. sum3 += matData * vecData;
  305. matData = *(pInA4)++;
  306. sum4 += matData * vecData;
  307. // Decrement the loop counter
  308. colCnt--;
  309. }
  310. /* Saturate and store the result in the destination buffer */
  311. *px++ = sum1;
  312. *px++ = sum2;
  313. *px++ = sum3;
  314. *px++ = sum4;
  315. i = i + numCols * 4;
  316. /* Decrement the row loop counter */
  317. row--;
  318. }
  319. /* process any remaining rows */
  320. row = numRows & 3u;
  321. while (row > 0) {
  322. float32_t sum = 0.0f;
  323. pInVec = pVec;
  324. pInA1 = pSrcA + i;
  325. colCnt = numCols >> 1;
  326. while (colCnt > 0) {
  327. vecData = *(pInVec)++;
  328. vecData2 = *(pInVec)++;
  329. matData = *(pInA1)++;
  330. matData2 = *(pInA1)++;
  331. sum += matData * vecData;
  332. sum += matData2 * vecData2;
  333. colCnt--;
  334. }
  335. // process remainder of row
  336. colCnt = numCols & 1u;
  337. while (colCnt > 0) {
  338. sum += *pInA1++ * *pInVec++;
  339. colCnt--;
  340. }
  341. *px++ = sum;
  342. i = i + numCols;
  343. row--;
  344. }
  345. }
  346. #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
  347. /**
* @} end of MatrixVectMult group
  349. */