  1. /* ----------------------------------------------------------------------
  2. * Project: CMSIS DSP Library
  3. * Title: arm_mat_mult_f32.c
  4. * Description: Floating-point matrix multiplication
  5. *
  6. * $Date: 23 April 2021
  7. * $Revision: V1.9.0
  8. *
  9. * Target Processor: Cortex-M and Cortex-A cores
  10. * -------------------------------------------------------------------- */
  11. /*
  12. * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
  13. *
  14. * SPDX-License-Identifier: Apache-2.0
  15. *
  16. * Licensed under the Apache License, Version 2.0 (the License); you may
  17. * not use this file except in compliance with the License.
  18. * You may obtain a copy of the License at
  19. *
  20. * www.apache.org/licenses/LICENSE-2.0
  21. *
  22. * Unless required by applicable law or agreed to in writing, software
  23. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  24. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  25. * See the License for the specific language governing permissions and
  26. * limitations under the License.
  27. */
  28. #include "dsp/matrix_functions.h"
  29. #if defined(ARM_MATH_NEON)
  30. #define GROUPOFROWS 8
  31. #endif
  32. /**
  33. * @ingroup groupMatrix
  34. */
  35. /**
  36. * @defgroup MatrixMult Matrix Multiplication
  37. *
  38. * Multiplies two matrices.
  39. *
  40. * @par Multiplication of two 3x3 matrices:
  41. *
  42. * \f[
  43. * \begin{pmatrix}
  44. * a_{1,1} & a_{1,2} & a_{1,3} \\
  45. * a_{2,1} & a_{2,2} & a_{2,3} \\
  46. * a_{3,1} & a_{3,2} & a_{3,3} \\
  47. * \end{pmatrix}
  48. *
  49. * \begin{pmatrix}
  50. * b_{1,1} & b_{1,2} & b_{1,3} \\
  51. * b_{2,1} & b_{2,2} & b_{2,3} \\
  52. * b_{3,1} & b_{3,2} & b_{3,3} \\
  53. * \end{pmatrix}
  54. * =
  55. * \begin{pmatrix}
  56. * a_{1,1} b_{1,1}+a_{1,2} b_{2,1}+a_{1,3} b_{3,1} & a_{1,1} b_{1,2}+a_{1,2} b_{2,2}+a_{1,3} b_{3,2} & a_{1,1} b_{1,3}+a_{1,2} b_{2,3}+a_{1,3} b_{3,3} \\
  57. * a_{2,1} b_{1,1}+a_{2,2} b_{2,1}+a_{2,3} b_{3,1} & a_{2,1} b_{1,2}+a_{2,2} b_{2,2}+a_{2,3} b_{3,2} & a_{2,1} b_{1,3}+a_{2,2} b_{2,3}+a_{2,3} b_{3,3} \\
  58. * a_{3,1} b_{1,1}+a_{3,2} b_{2,1}+a_{3,3} b_{3,1} & a_{3,1} b_{1,2}+a_{3,2} b_{2,2}+a_{3,3} b_{3,2} & a_{3,1} b_{1,3}+a_{3,2} b_{2,3}+a_{3,3} b_{3,3} \\
  59. * \end{pmatrix}
  60. * \f]
  61. * Matrix multiplication is only defined if the number of columns of the
  62. * first matrix equals the number of rows of the second matrix.
  63. * Multiplying an <code>M x N</code> matrix with an <code>N x P</code> matrix results
  64. * in an <code>M x P</code> matrix.
  65. * When matrix size checking is enabled, the functions check: (1) that the inner dimensions of
  66. * <code>pSrcA</code> and <code>pSrcB</code> are equal; and (2) that the size of the output
  67. * matrix equals the outer dimensions of <code>pSrcA</code> and <code>pSrcB</code>.
  68. */
  69. /**
  70. * @addtogroup MatrixMult
  71. * @{
  72. */
  73. #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
  74. #define MATRIX_DIM3 3
  75. #define MATRIX_DIM4 4
  76. __STATIC_INLINE arm_status arm_mat_mult_f32_2x2_mve(
  77. const arm_matrix_instance_f32 *pSrcA,
  78. const arm_matrix_instance_f32 *pSrcB,
  79. arm_matrix_instance_f32 *pDst)
  80. {
  81. /* {a00, a00, a10, a10} */
  82. static const uint32_t offsetA0[4] = { 0, 0, 2, 2 };
  83. /* {b00, b01, b00, b01} */
  84. static const uint32_t offsetB0[4] = { 0, 1, 0, 1 };
  85. /* {a01, a01, a11, a11} */
  86. static const uint32_t offsetA1[4] = { 1, 1, 3, 3 };
  87. /* {b10, b11, b10, b11} */
  88. static const uint32_t offsetB1[4] = { 2, 3, 2, 3 };
  89. uint32x4_t vecOffsA, vecOffsB;
  90. f32x4_t vecInA, vecInB, vecDst;
  91. vecOffsA = vldrwq_u32((uint32_t const *) offsetA0);
  92. vecOffsB = vldrwq_u32((uint32_t const *) offsetB0);
  93. vecInA = vldrwq_gather_shifted_offset((float32_t const *) pSrcA->pData, vecOffsA);
  94. vecInB = vldrwq_gather_shifted_offset((float32_t const *) pSrcB->pData, vecOffsB);
  95. vecDst = vmulq(vecInA, vecInB);
  96. vecOffsA = vldrwq_u32((uint32_t const *) offsetA1);
  97. vecOffsB = vldrwq_u32((uint32_t const *) offsetB1);
  98. vecInA = vldrwq_gather_shifted_offset((float32_t const *) pSrcA->pData, vecOffsA);
  99. vecInB = vldrwq_gather_shifted_offset((float32_t const *) pSrcB->pData, vecOffsB);
  100. vecDst = vfmaq(vecDst, vecInA, vecInB);
  101. vstrwq_f32(pDst->pData, vecDst);
  102. return (ARM_MATH_SUCCESS);
  103. }
  104. /*
  105. * A = {{a00, a01, a02},
  106. * {a10, a11, a12},
  107. * {a20, a21, a22}}
  108. * B = {{b00, b01, b02},
  109. * {b10, b11, b12},
  110. * {b20, b21, b22}}
  111. *
  112. * Dst = {{a00 b00 + a01 b10 + a02 b20, a00 b01 + a01 b11 + a02 b21, a00 b02 + a01 b12 + a02 b22},
  113. * {a10 b00 + a11 b10 + a12 b20, a10 b01 + a11 b11 + a12 b21, a10 b02 + a11 b12 + a12 b22},
  114. * {a20 b00 + a21 b10 + a22 b20, a20 b01 + a21 b11 + a22 b21, a20 b02 + a21 b12 + a22 b22}}
  115. */
/**
 * @brief  3x3 floating-point matrix multiplication, MVE specialization.
 *
 * Processes one full row of B per predicated vector load (3 active lanes,
 * 4th lane masked off) and accumulates it into three row accumulators,
 * one per row of A, scaled by the matching scalar A element.
 *
 * @param[in]  pSrcA  points to the first input matrix (3x3)
 * @param[in]  pSrcB  points to the second input matrix (3x3)
 * @param[out] pDst   points to the output matrix (3x3)
 * @return     ARM_MATH_SUCCESS (sizes are assumed already validated by caller)
 */
__STATIC_INLINE arm_status arm_mat_mult_f32_3x3_mve(
    const arm_matrix_instance_f32 *pSrcA,
    const arm_matrix_instance_f32 *pSrcB,
    arm_matrix_instance_f32 *pDst)
{
    float32_t *pInB = pSrcB->pData; /* input data matrix pointer B */
    float32_t *pInA = pSrcA->pData; /* input data matrix pointer A */
    float32_t *pOut = pDst->pData; /* output data matrix pointer */
    float32_t *pInA0, *pInA1, *pInA2;   /* walking pointers, one per row of A */
    f32x4_t vecMac0, vecMac1, vecMac2;  /* accumulators, one per output row */
    f32x4_t vecInB;                     /* current row of B (3 lanes valid) */
    float32_t const *pSrBVec;
    pSrBVec = (float32_t const *) pInB;
    pInA0 = pInA;
    pInA1 = pInA0 + MATRIX_DIM3;
    pInA2 = pInA1 + MATRIX_DIM3;
    /* enable predication to disable last (4th) vector element */
    mve_pred16_t p0 = vctp32q(MATRIX_DIM3);
    /*
     * load {b0,0, b0,1, b0,2, 0}
     */
    vecInB = vldrwq_z_f32(pSrBVec, p0);
    pSrBVec += MATRIX_DIM3;
    /* initialize accumulators with the first term: a(i,0) * row0(B) */
    vecMac0 = vmulq(vecInB, *pInA0++);
    vecMac1 = vmulq(vecInB, *pInA1++);
    vecMac2 = vmulq(vecInB, *pInA2++);
    /*
     * load {b1,0, b1,1, b1,2, 0}
     */
    vecInB = vldrwq_z_f32(pSrBVec, p0);
    pSrBVec += MATRIX_DIM3;
    /* accumulate second term: a(i,1) * row1(B) */
    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    /*
     * load {b2,0, b2,1 , b2,2, 0}
     */
    vecInB = vldrwq_z_f32(pSrBVec, p0);
    pSrBVec += MATRIX_DIM3;
    /* accumulate third term: a(i,2) * row2(B) */
    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    /* partial vector stores (only 3 lanes written per row) */
    vstrwq_p_f32(pOut, vecMac0, p0);
    pOut += MATRIX_DIM3;
    vstrwq_p_f32(pOut, vecMac1, p0);
    pOut += MATRIX_DIM3;
    vstrwq_p_f32(pOut, vecMac2, p0);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}
/**
 * @brief  4x4 floating-point matrix multiplication, MVE specialization.
 *
 * Same scheme as the 3x3 kernel but with full (unpredicated) vector loads:
 * each row of B is loaded once and folded into four row accumulators,
 * scaled by the matching scalar element from each row of A.
 *
 * @param[in]  pSrcA  points to the first input matrix (4x4)
 * @param[in]  pSrcB  points to the second input matrix (4x4)
 * @param[out] pDst   points to the output matrix (4x4)
 * @return     ARM_MATH_SUCCESS (sizes are assumed already validated by caller)
 */
__STATIC_INLINE arm_status arm_mat_mult_f32_4x4_mve(
    const arm_matrix_instance_f32 *pSrcA,
    const arm_matrix_instance_f32 *pSrcB,
    arm_matrix_instance_f32 *pDst)
{
    float32_t const *pSrBVec;
    float32_t *pInB = pSrcB->pData; /* input data matrix pointer B */
    float32_t *pInA = pSrcA->pData; /* input data matrix pointer A */
    float32_t *pOut = pDst->pData; /* output data matrix pointer */
    float32_t *pInA0, *pInA1, *pInA2, *pInA3;    /* walking pointers, one per row of A */
    f32x4_t vecMac0, vecMac1, vecMac2, vecMac3;  /* accumulators, one per output row */
    f32x4_t vecInB;                              /* current row of B */
    pSrBVec = (float32_t const *) pInB;
    pInA0 = pInA;
    pInA1 = pInA0 + MATRIX_DIM4;
    pInA2 = pInA1 + MATRIX_DIM4;
    pInA3 = pInA2 + MATRIX_DIM4;
    /*
     * load {b0,0, b0,1, b0,2, b0,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;
    /* initialize accumulators with the first term: a(i,0) * row0(B) */
    vecMac0 = vmulq(vecInB, *pInA0++);
    vecMac1 = vmulq(vecInB, *pInA1++);
    vecMac2 = vmulq(vecInB, *pInA2++);
    vecMac3 = vmulq(vecInB, *pInA3++);
    /*
     * load {b1,0, b1,1, b1,2, b1,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;
    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
    /*
     * load {b2,0, b2,1, b2,2, b2,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;
    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
    /*
     * load {b3,0, b3,1, b3,2, b3,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;
    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
    /* store the four result rows */
    vst1q(pOut, vecMac0);
    pOut += MATRIX_DIM4;
    vst1q(pOut, vecMac1);
    pOut += MATRIX_DIM4;
    vst1q(pOut, vecMac2);
    pOut += MATRIX_DIM4;
    vst1q(pOut, vecMac3);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}
  234. /**
  235. * @brief Floating-point matrix multiplication.
  236. * @param[in] *pSrcA points to the first input matrix structure
  237. * @param[in] *pSrcB points to the second input matrix structure
  238. * @param[out] *pDst points to output matrix structure
  239. * @return The function returns either
  240. * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
  241. */
arm_status arm_mat_mult_f32(
  const arm_matrix_instance_f32 * pSrcA,
  const arm_matrix_instance_f32 * pSrcB,
  arm_matrix_instance_f32 * pDst)
{
  float32_t *pInB = pSrcB->pData; /* input data matrix pointer B */
  float32_t *pInA = pSrcA->pData; /* input data matrix pointer A */
  float32_t *pOut = pDst->pData; /* output data matrix pointer */
  int numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
  int numColsB = pSrcB->numCols; /* number of columns of input matrix B */
  int numColsA = pSrcA->numCols; /* number of columns of input matrix A */
  uint32_t blkCnt; /* loop counters */
  uint32_t i;
  arm_status status;
#ifdef ARM_MATH_MATRIX_CHECK
  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
      (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else
#endif /* #ifdef ARM_MATH_MATRIX_CHECK */
  {
    /* small squared matrix specialized routines */
    if(numRowsA == numColsB && numColsB == numColsA) {
      if (numRowsA == 1)
      {
        /* 1x1 case degenerates to a scalar product */
        pOut[0] = pInA[0] * pInB[0];
        return(ARM_MATH_SUCCESS);
      }
      else if(numRowsA == 2)
        return arm_mat_mult_f32_2x2_mve(pSrcA, pSrcB, pDst);
      else if(numRowsA == 3)
        return arm_mat_mult_f32_3x3_mve(pSrcA, pSrcB, pDst);
      else if(numRowsA == 4)
        return arm_mat_mult_f32_4x4_mve(pSrcA, pSrcB, pDst);
    }
    /* main loop process 4 rows */
    i = numRowsA >> 2;
    while (i > 0U)
    {
      float32_t *pInA0, *pInA1, *pInA2, *pInA3;
      float32_t *pInB0;
      float32_t *pOut0, *pOut1, *pOut2, *pOut3;
      f32x4_t vecMac0, vecMac1, vecMac2, vecMac3;
      f32x4_t vecInB;
      /* pointers to 4 consecutive output rows */
      pOut0 = pOut;
      pOut1 = pOut0 + numColsB;
      pOut2 = pOut1 + numColsB;
      pOut3 = pOut2 + numColsB;
      pInB0 = pInB;
      /* column loop: produce a 4x4 output tile per iteration */
      uint32_t k = numColsB >> 2;
      while (k > 0U)
      {
        /* pointers to 4 consecutive Matrix A rows */
        pInA0 = pInA;
        pInA1 = pInA0 + numColsA;
        pInA2 = pInA1 + numColsA;
        pInA3 = pInA2 + numColsA;
        vecMac0 = vdupq_n_f32(0.0f);
        vecMac1 = vdupq_n_f32(0.0f);
        vecMac2 = vdupq_n_f32(0.0f);
        vecMac3 = vdupq_n_f32(0.0f);
        /* inner dimension: walk down numColsA rows of B */
        blkCnt = numColsA;
        while (blkCnt > 0U)
        {
          /*
           * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
           */
          /* NOTE(review): direct vector-typed dereference of a float32_t*;
             presumably safe with the MVE toolchains this targets, as the
             commented-out vldrwq_f32 alternative suggests — confirm. */
          vecInB = *(f32x4_t *)pInB0; /* vldrwq_f32(pInB0, 0); */
          vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
          vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
          vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
          vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
          /* step to the same 4 columns of the next B row */
          pInB0 = pInB0 + numColsB;
          /*
           * Decrement the blockSize loop counter
           */
          blkCnt--;
        }
        /* Store the results (4 x 4 block) in the destination buffer */
        vst1q(pOut0, vecMac0);
        pOut0 += 4;
        vst1q(pOut1, vecMac1);
        pOut1 += 4;
        vst1q(pOut2, vecMac2);
        pOut2 += 4;
        vst1q(pOut3, vecMac3);
        pOut3 += 4;
        /*
         * rewind: back to the top row of B, advanced 4 columns right
         */
        pInB0 -= (numColsB * numColsA) - 4;
        k--;
      }
      /* tail: 1..3 remaining B columns, handled with lane predication */
      int colBLeft = numColsB & 3;
      if (colBLeft)
      {
        pInA0 = pInA;
        pInA1 = pInA0 + numColsA;
        pInA2 = pInA1 + numColsA;
        pInA3 = pInA2 + numColsA;
        mve_pred16_t p0 = vctp32q(colBLeft);
        vecMac0 = vdupq_n_f32(0.0f);
        vecMac1 = vdupq_n_f32(0.0f);
        vecMac2 = vdupq_n_f32(0.0f);
        vecMac3 = vdupq_n_f32(0.0f);
        blkCnt = numColsA;
        while (blkCnt > 0U)
        {
          /*
           * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
           */
          vecInB = vldrwq_z_f32(pInB0, p0);
          vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
          vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
          vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
          vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
          pInB0 = pInB0 + numColsB;
          /*
           * Decrement the blockSize loop counter
           */
          blkCnt--;
        }
        /* Store the results (4 x colBLeft block) in the destination buffer */
        vstrwq_p_f32(pOut0, vecMac0, p0);
        vstrwq_p_f32(pOut1, vecMac1, p0);
        vstrwq_p_f32(pOut2, vecMac2, p0);
        vstrwq_p_f32(pOut3, vecMac3, p0);
      }
      /* move to next rows */
      pInA += 4 * numColsA;
      pOut += 4 * numColsB;
      i--;
    }
    /*
     * non multiple of 4 rows for Matrix A
     * process single row
     */
    if (numRowsA & 3)
    {
      i = numRowsA & 3;
      while (i > 0U)
      {
        float32_t *pInA0;
        float32_t *pInB0;
        float32_t *pOut0;
        f32x4_t vecInB;
        f32x4_t vecMac0;
        pOut0 = pOut;
        pInB0 = pInB;
        /* column loop: 4 output columns per iteration */
        uint32_t k = numColsB >> 2;
        while (k > 0U)
        {
          pInA0 = pInA;
          vecMac0 = vdupq_n_f32(0.0f);
          blkCnt = numColsA;
          while (blkCnt > 0U)
          {
            /*
             * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
             */
            vecInB = *(f32x4_t *)pInB0; /* vldrwq_f32(pInB0, 0); */
            vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
            pInB0 = pInB0 + numColsB;
            /*
             * Decrement the blockSize loop counter
             */
            blkCnt--;
          }
          /* Store the results (1 x 4 block) in the destination buffer */
          vst1q(pOut0, vecMac0);
          pOut0 += 4;
          /*
           * rewind: back to the top row of B, advanced 4 columns right
           */
          pInB0 -= (numColsB * numColsA) - 4;
          k--;
        }
        /* tail: 1..3 remaining B columns, handled with lane predication */
        int colBLeft = numColsB & 3;
        if (colBLeft)
        {
          pInA0 = pInA;
          mve_pred16_t p0 = vctp32q(colBLeft);
          vecMac0 = vdupq_n_f32(0.0f);
          blkCnt = numColsA;
          while (blkCnt > 0U)
          {
            /*
             * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
             */
            vecInB = vldrwq_z_f32(pInB0, p0);
            vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
            pInB0 = pInB0 + numColsB;
            /*
             * Decrement the blockSize loop counter
             */
            blkCnt--;
          }
          /* Store the results (1 x colBLeft block) in the destination buffer */
          vstrwq_p_f32(pOut0, vecMac0, p0);
        }
        /* move to next row */
        pInA += 1 * numColsA;
        pOut += 1 * numColsB;
        i--;
      }
    }
    status = ARM_MATH_SUCCESS;
  }
  /* Return to application */
  return (status);
}
  458. #else
  459. #if defined(ARM_MATH_NEON)
  460. /**
  461. * @brief Floating-point matrix multiplication.
  462. * @param[in] *pSrcA points to the first input matrix structure
  463. * @param[in] *pSrcB points to the second input matrix structure
  464. * @param[out] *pDst points to output matrix structure
  465. * @return The function returns either
  466. * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
  467. */
arm_status arm_mat_mult_f32(
  const arm_matrix_instance_f32 * pSrcA,
  const arm_matrix_instance_f32 * pSrcB,
  arm_matrix_instance_f32 * pDst)
{
  float32_t *pIn1 = pSrcA->pData; /* input data matrix pointer A */
  float32_t *pIn2 = pSrcB->pData; /* input data matrix pointer B */
  float32_t *pInA = pSrcA->pData; /* input data matrix pointer A */
  float32_t *pOut = pDst->pData; /* output data matrix pointer */
  float32_t *px; /* Temporary output data matrix pointer */
  float32_t sum; /* Accumulator */
  uint16_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
  uint16_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
  uint16_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
  uint32_t col, i = 0U, j, row = numRowsA, rowCnt, colCnt; /* loop counters */
  arm_status status; /* status of matrix multiplication */
  float32x4_t a0V, a1V, a2V, a3V, a4V, a5V, a6V, a7V;
  float32x4_t acc0,acc1,acc2,acc3,acc4,acc5,acc6,acc7,temp;
  float32x2_t accum = vdup_n_f32(0);
  /* one walking pointer per row of the current GROUPOFROWS row group of A */
  float32_t *pIn1B = pSrcA->pData;
  float32_t *pIn1C = pSrcA->pData;
  float32_t *pIn1D = pSrcA->pData;
  float32_t *pIn1E = pSrcA->pData;
  float32_t *pIn1F = pSrcA->pData;
  float32_t *pIn1G = pSrcA->pData;
  float32_t *pIn1H = pSrcA->pData;
  float32_t *pxB,*pxC, *pxD, *pxE, *pxF, *pxG, *pxH; /* Temporary output data matrix pointer */
  float32_t sum0,sum1, sum2,sum3, sum4, sum5 , sum6, sum7;
#ifdef ARM_MATH_MATRIX_CHECK
  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
      (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else
#endif /* #ifdef ARM_MATH_MATRIX_CHECK */
  {
    /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
    /* Row loop: process GROUPOFROWS (8) rows of A against each B column at a time */
    rowCnt = row >> 3;
    while(rowCnt > 0)
    {
      /* Output pointer is set to starting address of the row being processed */
      px = pOut + GROUPOFROWS*i;
      pxB = px + numColsB;
      pxC = px + 2*numColsB;
      pxD = px + 3*numColsB;
      pxE = px + 4*numColsB;
      pxF = px + 5*numColsB;
      pxG = px + 6*numColsB;
      pxH = px + 7*numColsB;
      /* For every row wise process, the column loop counter is to be initiated */
      col = numColsB;
      /* For every row wise process, the pIn2 pointer is set
      ** to the starting address of the pSrcB data */
      pIn2 = pSrcB->pData;
      j = 0U;
      /* Column loop */
      do
      {
        /* Set the variable sum, that acts as accumulator, to zero */
        sum0 = 0.0f;
        sum1 = 0.0f;
        sum2 = 0.0f;
        sum3 = 0.0f;
        sum4 = 0.0f;
        sum5 = 0.0f;
        sum6 = 0.0f;
        sum7 = 0.0f;
        /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
        pIn1 = pInA;
        pIn1B = pIn1 + numColsA;
        pIn1C = pIn1 + 2*numColsA;
        pIn1D = pIn1 + 3*numColsA;
        pIn1E = pIn1 + 4*numColsA;
        pIn1F = pIn1 + 5*numColsA;
        pIn1G = pIn1 + 6*numColsA;
        pIn1H = pIn1 + 7*numColsA;
        acc0 = vdupq_n_f32(0.0);
        acc1 = vdupq_n_f32(0.0);
        acc2 = vdupq_n_f32(0.0);
        acc3 = vdupq_n_f32(0.0);
        acc4 = vdupq_n_f32(0.0);
        acc5 = vdupq_n_f32(0.0);
        acc6 = vdupq_n_f32(0.0);
        acc7 = vdupq_n_f32(0.0);
        /* Compute 4 MACs simultaneously. */
        colCnt = numColsA >> 2U;
        /* Matrix multiplication */
        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
          /* load 4 consecutive elements from each of the 8 A rows */
          a0V = vld1q_f32(pIn1);
          a1V = vld1q_f32(pIn1B);
          a2V = vld1q_f32(pIn1C);
          a3V = vld1q_f32(pIn1D);
          a4V = vld1q_f32(pIn1E);
          a5V = vld1q_f32(pIn1F);
          a6V = vld1q_f32(pIn1G);
          a7V = vld1q_f32(pIn1H);
          pIn1 += 4;
          pIn1B += 4;
          pIn1C += 4;
          pIn1D += 4;
          pIn1E += 4;
          pIn1F += 4;
          pIn1G += 4;
          pIn1H += 4;
          /* gather 4 column elements of B (stride numColsB) into temp.
             All 4 lanes are written before temp is consumed below. */
          temp = vsetq_lane_f32(*pIn2,temp,0);
          pIn2 += numColsB;
          temp = vsetq_lane_f32(*pIn2,temp,1);
          pIn2 += numColsB;
          temp = vsetq_lane_f32(*pIn2,temp,2);
          pIn2 += numColsB;
          temp = vsetq_lane_f32(*pIn2,temp,3);
          pIn2 += numColsB;
          acc0 = vmlaq_f32(acc0,a0V,temp);
          acc1 = vmlaq_f32(acc1,a1V,temp);
          acc2 = vmlaq_f32(acc2,a2V,temp);
          acc3 = vmlaq_f32(acc3,a3V,temp);
          acc4 = vmlaq_f32(acc4,a4V,temp);
          acc5 = vmlaq_f32(acc5,a5V,temp);
          acc6 = vmlaq_f32(acc6,a6V,temp);
          acc7 = vmlaq_f32(acc7,a7V,temp);
          /* Decrement the loop count */
          colCnt--;
        }
        /* horizontal reduction of each 4-lane accumulator into its scalar sum */
        accum = vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
        sum0 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
        accum = vpadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
        sum1 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
        accum = vpadd_f32(vget_low_f32(acc2), vget_high_f32(acc2));
        sum2 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
        accum = vpadd_f32(vget_low_f32(acc3), vget_high_f32(acc3));
        sum3 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
        accum = vpadd_f32(vget_low_f32(acc4), vget_high_f32(acc4));
        sum4 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
        accum = vpadd_f32(vget_low_f32(acc5), vget_high_f32(acc5));
        sum5 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
        accum = vpadd_f32(vget_low_f32(acc6), vget_high_f32(acc6));
        sum6 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
        accum = vpadd_f32(vget_low_f32(acc7), vget_high_f32(acc7));
        sum7 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
        /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
        ** No loop unrolling is used. */
        colCnt = numColsA & 3;
        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
          sum0 += *pIn1++ * (*pIn2);
          sum1 += *pIn1B++ * (*pIn2);
          sum2 += *pIn1C++ * (*pIn2);
          sum3 += *pIn1D++ * (*pIn2);
          sum4 += *pIn1E++ * (*pIn2);
          sum5 += *pIn1F++ * (*pIn2);
          sum6 += *pIn1G++ * (*pIn2);
          sum7 += *pIn1H++ * (*pIn2);
          pIn2 += numColsB;
          /* Decrement the loop counter */
          colCnt--;
        }
        /* Store the result in the destination buffer */
        *px++ = sum0;
        *pxB++ = sum1;
        *pxC++ = sum2;
        *pxD++ = sum3;
        *pxE++ = sum4;
        *pxF++ = sum5;
        *pxG++ = sum6;
        *pxH++ = sum7;
        /* Update the pointer pIn2 to point to the starting address of the next column */
        j++;
        pIn2 = pSrcB->pData + j;
        /* Decrement the column loop counter */
        col--;
      } while (col > 0U);
      /* Update the pointer pInA to point to the starting address of the next row */
      i = i + numColsB;
      pInA = pInA + GROUPOFROWS*numColsA;
      /* Decrement the row loop counter */
      rowCnt--;
    }
    /*
    i was the index of a group of rows computed by previous loop.
    Now i is the index of a row since below code is computing row per row
    and no more group of row per group of rows.
    */
    i = GROUPOFROWS*i;
    /* residual rows (numRowsA not a multiple of 8): one row at a time */
    rowCnt = row & 7;
    while(rowCnt > 0)
    {
      /* Output pointer is set to starting address of the row being processed */
      px = pOut + i;
      /* For every row wise process, the column loop counter is to be initiated */
      col = numColsB;
      /* For every row wise process, the pIn2 pointer is set
      ** to the starting address of the pSrcB data */
      pIn2 = pSrcB->pData;
      j = 0U;
      /* Column loop */
      do
      {
        /* Set the variable sum, that acts as accumulator, to zero */
        sum = 0.0f;
        /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
        pIn1 = pInA;
        acc0 = vdupq_n_f32(0.0);
        /* Compute 4 MACs simultaneously. */
        colCnt = numColsA >> 2U;
        /* Matrix multiplication */
        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
          a0V = vld1q_f32(pIn1); /* load 4 consecutive elements of the current A row */
          pIn1 += 4;
          /* gather 4 column elements of B (stride numColsB) into temp */
          temp = vsetq_lane_f32(*pIn2,temp,0);
          pIn2 += numColsB;
          temp = vsetq_lane_f32(*pIn2,temp,1);
          pIn2 += numColsB;
          temp = vsetq_lane_f32(*pIn2,temp,2);
          pIn2 += numColsB;
          temp = vsetq_lane_f32(*pIn2,temp,3);
          pIn2 += numColsB;
          acc0 = vmlaq_f32(acc0,a0V,temp);
          /* Decrement the loop count */
          colCnt--;
        }
        /* horizontal reduction of the 4-lane accumulator */
        accum = vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
        sum += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
        /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
        ** No loop unrolling is used. */
        colCnt = numColsA % 0x4U;
        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
          sum += *pIn1++ * (*pIn2);
          pIn2 += numColsB;
          /* Decrement the loop counter */
          colCnt--;
        }
        /* Store the result in the destination buffer */
        *px++ = sum;
        /* Update the pointer pIn2 to point to the starting address of the next column */
        j++;
        pIn2 = pSrcB->pData + j;
        /* Decrement the column loop counter */
        col--;
      } while (col > 0U);
      /* Update the pointer pInA to point to the starting address of the next row */
      i = i + numColsB;
      pInA = pInA + numColsA;
      /* Decrement the row loop counter */
      rowCnt--;
    }
    /* Set status as ARM_MATH_SUCCESS */
    status = ARM_MATH_SUCCESS;
  }
  /* Return to application */
  return (status);
}
  730. #else
  731. /**
  732. * @brief Floating-point matrix multiplication.
  733. * @param[in] *pSrcA points to the first input matrix structure
  734. * @param[in] *pSrcB points to the second input matrix structure
  735. * @param[out] *pDst points to output matrix structure
  736. * @return The function returns either
  737. * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
  738. */
  739. arm_status arm_mat_mult_f32(
  740. const arm_matrix_instance_f32 * pSrcA,
  741. const arm_matrix_instance_f32 * pSrcB,
  742. arm_matrix_instance_f32 * pDst)
  743. {
  744. float32_t *pIn1 = pSrcA->pData; /* Input data matrix pointer A */
  745. float32_t *pIn2 = pSrcB->pData; /* Input data matrix pointer B */
  746. float32_t *pInA = pSrcA->pData; /* Input data matrix pointer A */
  747. float32_t *pInB = pSrcB->pData; /* Input data matrix pointer B */
  748. float32_t *pOut = pDst->pData; /* Output data matrix pointer */
  749. float32_t *px; /* Temporary output data matrix pointer */
  750. float32_t sum; /* Accumulator */
  751. uint16_t numRowsA = pSrcA->numRows; /* Number of rows of input matrix A */
  752. uint16_t numColsB = pSrcB->numCols; /* Number of columns of input matrix B */
  753. uint16_t numColsA = pSrcA->numCols; /* Number of columns of input matrix A */
  754. uint32_t col, i = 0U, row = numRowsA, colCnt; /* Loop counters */
  755. arm_status status; /* Status of matrix multiplication */
  756. #ifdef ARM_MATH_MATRIX_CHECK
  757. /* Check for matrix mismatch condition */
  758. if ((pSrcA->numCols != pSrcB->numRows) ||
  759. (pSrcA->numRows != pDst->numRows) ||
  760. (pSrcB->numCols != pDst->numCols) )
  761. {
  762. /* Set status as ARM_MATH_SIZE_MISMATCH */
  763. status = ARM_MATH_SIZE_MISMATCH;
  764. }
  765. else
  766. #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
  767. {
  768. /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
  769. /* row loop */
  770. do
  771. {
  772. /* Output pointer is set to starting address of row being processed */
  773. px = pOut + i;
  774. /* For every row wise process, column loop counter is to be initiated */
  775. col = numColsB;
  776. /* For every row wise process, pIn2 pointer is set to starting address of pSrcB data */
  777. pIn2 = pSrcB->pData;
  778. /* column loop */
  779. do
  780. {
  781. /* Set the variable sum, that acts as accumulator, to zero */
  782. sum = 0.0f;
  783. /* Initialize pointer pIn1 to point to starting address of column being processed */
  784. pIn1 = pInA;
  785. #if defined (ARM_MATH_LOOPUNROLL)
  786. /* Loop unrolling: Compute 4 MACs at a time. */
  787. colCnt = numColsA >> 2U;
  788. /* matrix multiplication */
  789. while (colCnt > 0U)
  790. {
  791. /* c(m,p) = a(m,1) * b(1,p) + a(m,2) * b(2,p) + .... + a(m,n) * b(n,p) */
  792. /* Perform the multiply-accumulates */
  793. sum += *pIn1++ * *pIn2;
  794. pIn2 += numColsB;
  795. sum += *pIn1++ * *pIn2;
  796. pIn2 += numColsB;
  797. sum += *pIn1++ * *pIn2;
  798. pIn2 += numColsB;
  799. sum += *pIn1++ * *pIn2;
  800. pIn2 += numColsB;
  801. /* Decrement loop counter */
  802. colCnt--;
  803. }
  804. /* Loop unrolling: Compute remaining MACs */
  805. colCnt = numColsA % 0x4U;
  806. #else
  807. /* Initialize cntCnt with number of columns */
  808. colCnt = numColsA;
  809. #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
  810. while (colCnt > 0U)
  811. {
  812. /* c(m,p) = a(m,1) * b(1,p) + a(m,2) * b(2,p) + .... + a(m,n) * b(n,p) */
  813. /* Perform the multiply-accumulates */
  814. sum += *pIn1++ * *pIn2;
  815. pIn2 += numColsB;
  816. /* Decrement loop counter */
  817. colCnt--;
  818. }
  819. /* Store result in destination buffer */
  820. *px++ = sum;
  821. /* Decrement column loop counter */
  822. col--;
  823. /* Update pointer pIn2 to point to starting address of next column */
  824. pIn2 = pInB + (numColsB - col);
  825. } while (col > 0U);
  826. /* Update pointer pInA to point to starting address of next row */
  827. i = i + numColsB;
  828. pInA = pInA + numColsA;
  829. /* Decrement row loop counter */
  830. row--;
  831. } while (row > 0U);
  832. /* Set status as ARM_MATH_SUCCESS */
  833. status = ARM_MATH_SUCCESS;
  834. }
  835. /* Return to application */
  836. return (status);
  837. }
  838. #endif /* #if defined(ARM_MATH_NEON) */
  839. #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
  840. /**
  841. * @} end of MatrixMult group
  842. */