arm_mat_mult_f32.c

/* ----------------------------------------------------------------------
 * Project:      CMSIS DSP Library
 * Title:        arm_mat_mult_f32.c
 * Description:  Floating-point matrix multiplication
 *
 * $Date:        18. March 2019
 * $Revision:    V1.6.0
 *
 * Target Processor: Cortex-M cores
 * -------------------------------------------------------------------- */
/*
 * Copyright (C) 2010-2019 ARM Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "arm_math.h"

/**
 * @ingroup groupMatrix
 */

/**
 * @defgroup MatrixMult Matrix Multiplication
 *
 * Multiplies two matrices.
 *
 * \image html MatrixMultiplication.gif "Multiplication of two 3 x 3 matrices"
 *
 * Matrix multiplication is only defined if the number of columns of the
 * first matrix equals the number of rows of the second matrix.
 * Multiplying an <code>M x N</code> matrix with an <code>N x P</code> matrix results
 * in an <code>M x P</code> matrix.
 * When matrix size checking is enabled, the functions check: (1) that the inner dimensions of
 * <code>pSrcA</code> and <code>pSrcB</code> are equal; and (2) that the size of the output
 * matrix equals the outer dimensions of <code>pSrcA</code> and <code>pSrcB</code>.
 */
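/*
 * Usage sketch (illustrative, not part of the original source): the buffers,
 * dimensions, and values below are assumptions chosen for the example.
 * arm_mat_init_f32() simply fills in the fields of the instance structure.
 *
 *     float32_t aData[2 * 3] = { 1.0f, 2.0f, 3.0f,
 *                                4.0f, 5.0f, 6.0f };      // 2 x 3 matrix A
 *     float32_t bData[3 * 2] = { 7.0f,  8.0f,
 *                                9.0f, 10.0f,
 *                               11.0f, 12.0f };           // 3 x 2 matrix B
 *     float32_t dData[2 * 2];                             // 2 x 2 result D = A * B
 *
 *     arm_matrix_instance_f32 srcA, srcB, dst;
 *     arm_mat_init_f32(&srcA, 2, 3, aData);
 *     arm_mat_init_f32(&srcB, 3, 2, bData);
 *     arm_mat_init_f32(&dst,  2, 2, dData);
 *
 *     if (arm_mat_mult_f32(&srcA, &srcB, &dst) != ARM_MATH_SUCCESS)
 *     {
 *         // size mismatch (reported only when ARM_MATH_MATRIX_CHECK is defined)
 *     }
 */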
/**
 * @addtogroup MatrixMult
 * @{
 */

/**
 * @brief Floating-point matrix multiplication.
 * @param[in]  *pSrcA points to the first input matrix structure
 * @param[in]  *pSrcB points to the second input matrix structure
 * @param[out] *pDst  points to output matrix structure
 * @return The function returns either
 * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
 */
#if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)

#define MATRIX_DIM3 3
#define MATRIX_DIM4 4
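/*
 * A = {{a00, a01},    B = {{b00, b01},
 *      {a10, a11}}         {b10, b11}}
 *
 * Dst = {{a00 b00 + a01 b10, a00 b01 + a01 b11},
 *        {a10 b00 + a11 b10, a10 b01 + a11 b11}}
 *
 * The gathered loads below pair each A element with the matching B row, so the
 * whole 2x2 result is formed by a single vmulq followed by a single vfmaq.
 */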
__STATIC_INLINE arm_status arm_mat_mult_f32_2x2_mve(
    const arm_matrix_instance_f32 *pSrcA,
    const arm_matrix_instance_f32 *pSrcB,
    arm_matrix_instance_f32 *pDst)
{
    /* {a00, a00, a10, a10} */
    static const uint32_t offsetA0[4] = { 0, 0, 2, 2 };
    /* {b00, b01, b00, b01} */
    static const uint32_t offsetB0[4] = { 0, 1, 0, 1 };
    /* {a01, a01, a11, a11} */
    static const uint32_t offsetA1[4] = { 1, 1, 3, 3 };
    /* {b10, b11, b10, b11} */
    static const uint32_t offsetB1[4] = { 2, 3, 2, 3 };

    uint32x4_t vecOffsA, vecOffsB;
    f32x4_t vecInA, vecInB, vecDst;

    vecOffsA = vldrwq_u32((uint32_t const *) offsetA0);
    vecOffsB = vldrwq_u32((uint32_t const *) offsetB0);

    vecInA = vldrwq_gather_shifted_offset((float32_t const *) pSrcA->pData, vecOffsA);
    vecInB = vldrwq_gather_shifted_offset((float32_t const *) pSrcB->pData, vecOffsB);
    vecDst = vmulq(vecInA, vecInB);

    vecOffsA = vldrwq_u32((uint32_t const *) offsetA1);
    vecOffsB = vldrwq_u32((uint32_t const *) offsetB1);

    vecInA = vldrwq_gather_shifted_offset((float32_t const *) pSrcA->pData, vecOffsA);
    vecInB = vldrwq_gather_shifted_offset((float32_t const *) pSrcB->pData, vecOffsB);
    vecDst = vfmaq(vecDst, vecInA, vecInB);

    vstrwq_f32(pDst->pData, vecDst);

    return (ARM_MATH_SUCCESS);
}
/*
 * A = {{a00, a01, a02},
 *      {a10, a11, a12},
 *      {a20, a21, a22}}
 * B = {{b00, b01, b02},
 *      {b10, b11, b12},
 *      {b20, b21, b22}}
 *
 * Dst = {{a00 b00 + a01 b10 + a02 b20, a00 b01 + a01 b11 + a02 b21, a00 b02 + a01 b12 + a02 b22},
 *        {a10 b00 + a11 b10 + a12 b20, a10 b01 + a11 b11 + a12 b21, a10 b02 + a11 b12 + a12 b22},
 *        {a20 b00 + a21 b10 + a22 b20, a20 b01 + a21 b11 + a22 b21, a20 b02 + a21 b12 + a22 b22}}
 */
__STATIC_INLINE arm_status arm_mat_mult_f32_3x3_mve(
    const arm_matrix_instance_f32 *pSrcA,
    const arm_matrix_instance_f32 *pSrcB,
    arm_matrix_instance_f32 *pDst)
{
    float32_t *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float32_t *pInA = pSrcA->pData;  /* input data matrix pointer A */
    float32_t *pOut = pDst->pData;   /* output data matrix pointer */
    float32_t *pInA0, *pInA1, *pInA2;
    f32x4_t vecMac0, vecMac1, vecMac2;
    f32x4_t vecInB;
    float32_t const *pSrBVec;

    pSrBVec = (float32_t const *) pInB;

    pInA0 = pInA;
    pInA1 = pInA0 + MATRIX_DIM3;
    pInA2 = pInA1 + MATRIX_DIM3;

    /* enable predication to disable last (4th) vector element */
    mve_pred16_t p0 = vctp32q(MATRIX_DIM3);

    /* load {b0,0, b0,1, b0,2, 0} */
    vecInB = vldrwq_z_f32(pSrBVec, p0);
    pSrBVec += MATRIX_DIM3;

    vecMac0 = vmulq(vecInB, *pInA0++);
    vecMac1 = vmulq(vecInB, *pInA1++);
    vecMac2 = vmulq(vecInB, *pInA2++);

    /* load {b1,0, b1,1, b1,2, 0} */
    vecInB = vldrwq_z_f32(pSrBVec, p0);
    pSrBVec += MATRIX_DIM3;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);

    /* load {b2,0, b2,1, b2,2, 0} */
    vecInB = vldrwq_z_f32(pSrBVec, p0);
    pSrBVec += MATRIX_DIM3;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);

    /* partial vector stores */
    vstrwq_p_f32(pOut, vecMac0, p0);
    pOut += MATRIX_DIM3;
    vstrwq_p_f32(pOut, vecMac1, p0);
    pOut += MATRIX_DIM3;
    vstrwq_p_f32(pOut, vecMac2, p0);

    /* Return to application */
    return (ARM_MATH_SUCCESS);
}
__STATIC_INLINE arm_status arm_mat_mult_f32_4x4_mve(
    const arm_matrix_instance_f32 *pSrcA,
    const arm_matrix_instance_f32 *pSrcB,
    arm_matrix_instance_f32 *pDst)
{
    float32_t const *pSrBVec;
    float32_t *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float32_t *pInA = pSrcA->pData;  /* input data matrix pointer A */
    float32_t *pOut = pDst->pData;   /* output data matrix pointer */
    float32_t *pInA0, *pInA1, *pInA2, *pInA3;
    f32x4_t vecMac0, vecMac1, vecMac2, vecMac3;
    f32x4_t vecInB;

    pSrBVec = (float32_t const *) pInB;

    pInA0 = pInA;
    pInA1 = pInA0 + MATRIX_DIM4;
    pInA2 = pInA1 + MATRIX_DIM4;
    pInA3 = pInA2 + MATRIX_DIM4;

    /* load {b0,0, b0,1, b0,2, b0,3} */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    vecMac0 = vmulq(vecInB, *pInA0++);
    vecMac1 = vmulq(vecInB, *pInA1++);
    vecMac2 = vmulq(vecInB, *pInA2++);
    vecMac3 = vmulq(vecInB, *pInA3++);

    /* load {b1,0, b1,1, b1,2, b1,3} */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);

    /* load {b2,0, b2,1, b2,2, b2,3} */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);

    /* load {b3,0, b3,1, b3,2, b3,3} */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);

    vst1q(pOut, vecMac0);
    pOut += MATRIX_DIM4;
    vst1q(pOut, vecMac1);
    pOut += MATRIX_DIM4;
    vst1q(pOut, vecMac2);
    pOut += MATRIX_DIM4;
    vst1q(pOut, vecMac3);

    /* Return to application */
    return (ARM_MATH_SUCCESS);
}
arm_status arm_mat_mult_f32(
    const arm_matrix_instance_f32 * pSrcA,
    const arm_matrix_instance_f32 * pSrcB,
    arm_matrix_instance_f32 * pDst)
{
    float32_t *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float32_t *pInA = pSrcA->pData;  /* input data matrix pointer A */
    float32_t *pOut = pDst->pData;   /* output data matrix pointer */
    int numRowsA = pSrcA->numRows;   /* number of rows of input matrix A */
    int numColsB = pSrcB->numCols;   /* number of columns of input matrix B */
    int numColsA = pSrcA->numCols;   /* number of columns of input matrix A */
    uint32_t blkCnt;                 /* loop counters */
    uint32_t i;
    arm_status status;

#ifdef ARM_MATH_MATRIX_CHECK
    /* Check for matrix mismatch condition */
    if ((pSrcA->numCols != pSrcB->numRows) ||
        (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
    {
        /* Set status as ARM_MATH_SIZE_MISMATCH */
        status = ARM_MATH_SIZE_MISMATCH;
    }
    else
#endif /* #ifdef ARM_MATH_MATRIX_CHECK */
    {
        /* small square matrix specialized routines */
        if (numRowsA == numColsB && numColsB == numColsA)
        {
            if (numRowsA == 1)
            {
                pOut[0] = pInA[0] * pInB[0];
                return (ARM_MATH_SUCCESS);
            }
            else if (numRowsA == 2)
                return arm_mat_mult_f32_2x2_mve(pSrcA, pSrcB, pDst);
            else if (numRowsA == 3)
                return arm_mat_mult_f32_3x3_mve(pSrcA, pSrcB, pDst);
            else if (numRowsA == 4)
                return arm_mat_mult_f32_4x4_mve(pSrcA, pSrcB, pDst);
        }
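        /*
         * General case: rows of A are processed four at a time and, within
         * each group, columns of B four at a time, so every inner pass
         * accumulates a 4 x 4 block of the destination. Predicated loads and
         * stores handle the leftover columns when numColsB is not a multiple
         * of 4.
         */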
        /* main loop: process 4 rows */
        i = numRowsA >> 2;
        while (i > 0U)
        {
            float32_t *pInA0, *pInA1, *pInA2, *pInA3;
            float32_t *pInB0;
            float32_t *pOut0, *pOut1, *pOut2, *pOut3;
            f32x4_t vecMac0, vecMac1, vecMac2, vecMac3;
            f32x4_t vecInB;

            /* pointers to 4 consecutive output rows */
            pOut0 = pOut;
            pOut1 = pOut0 + numColsB;
            pOut2 = pOut1 + numColsB;
            pOut3 = pOut2 + numColsB;
            pInB0 = pInB;

            uint32_t k = numColsB >> 2;
            while (k > 0U)
            {
                /* pointers to 4 consecutive Matrix A rows */
                pInA0 = pInA;
                pInA1 = pInA0 + numColsA;
                pInA2 = pInA1 + numColsA;
                pInA3 = pInA2 + numColsA;

                vecMac0 = vdupq_n_f32(0.0f);
                vecMac1 = vdupq_n_f32(0.0f);
                vecMac2 = vdupq_n_f32(0.0f);
                vecMac3 = vdupq_n_f32(0.0f);

                blkCnt = numColsA;
                while (blkCnt > 0U)
                {
                    /* load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3} */
                    vecInB = *(f32x4_t *) pInB0;  /* vldrwq_f32(pInB0, 0); */

                    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
                    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
                    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
                    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);

                    pInB0 = pInB0 + numColsB;

                    /* Decrement the blockSize loop counter */
                    blkCnt--;
                }

                /* Store the results (4 x 4 block) in the destination buffer */
                vst1q(pOut0, vecMac0);
                pOut0 += 4;
                vst1q(pOut1, vecMac1);
                pOut1 += 4;
                vst1q(pOut2, vecMac2);
                pOut2 += 4;
                vst1q(pOut3, vecMac3);
                pOut3 += 4;

                /* rewind */
                pInB0 -= (numColsB * numColsA) - 4;
                k--;
            }

            int colBLeft = numColsB & 3;
            if (colBLeft)
            {
                pInA0 = pInA;
                pInA1 = pInA0 + numColsA;
                pInA2 = pInA1 + numColsA;
                pInA3 = pInA2 + numColsA;
                mve_pred16_t p0 = vctp32q(colBLeft);

                vecMac0 = vdupq_n_f32(0.0f);
                vecMac1 = vdupq_n_f32(0.0f);
                vecMac2 = vdupq_n_f32(0.0f);
                vecMac3 = vdupq_n_f32(0.0f);

                blkCnt = numColsA;
                while (blkCnt > 0U)
                {
                    /* load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3} */
                    vecInB = vldrwq_z_f32(pInB0, p0);

                    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
                    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
                    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
                    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);

                    pInB0 = pInB0 + numColsB;

                    /* Decrement the blockSize loop counter */
                    blkCnt--;
                }

                /* Store the results (4 x colBLeft block) in the destination buffer */
                vstrwq_p_f32(pOut0, vecMac0, p0);
                vstrwq_p_f32(pOut1, vecMac1, p0);
                vstrwq_p_f32(pOut2, vecMac2, p0);
                vstrwq_p_f32(pOut3, vecMac3, p0);
            }

            /* move to next rows */
            pInA += 4 * numColsA;
            pOut += 4 * numColsB;
            i--;
        }
        /*
         * non multiple of 4 rows for Matrix A
         * process single row
         */
        if (numRowsA & 3)
        {
            i = numRowsA & 3;
            while (i > 0U)
            {
                float32_t *pInA0;
                float32_t *pInB0;
                float32_t *pOut0;
                f32x4_t vecInB;
                f32x4_t vecMac0;

                pOut0 = pOut;
                pInB0 = pInB;

                uint32_t k = numColsB >> 2;
                while (k > 0U)
                {
                    pInA0 = pInA;

                    vecMac0 = vdupq_n_f32(0.0f);
                    blkCnt = numColsA;
                    while (blkCnt > 0U)
                    {
                        /* load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3} */
                        vecInB = *(f32x4_t *) pInB0;  /* vldrwq_f32(pInB0, 0); */

                        vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);

                        pInB0 = pInB0 + numColsB;

                        /* Decrement the blockSize loop counter */
                        blkCnt--;
                    }

                    /* Store the results (1 x 4 block) in the destination buffer */
                    vst1q(pOut0, vecMac0);
                    pOut0 += 4;

                    /* rewind */
                    pInB0 -= (numColsB * numColsA) - 4;
                    k--;
                }

                int colBLeft = numColsB & 3;
                if (colBLeft)
                {
                    pInA0 = pInA;
                    mve_pred16_t p0 = vctp32q(colBLeft);

                    vecMac0 = vdupq_n_f32(0.0f);
                    blkCnt = numColsA;
                    while (blkCnt > 0U)
                    {
                        /* load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3} */
                        vecInB = vldrwq_z_f32(pInB0, p0);

                        vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);

                        pInB0 = pInB0 + numColsB;

                        /* Decrement the blockSize loop counter */
                        blkCnt--;
                    }

                    /* Store the results (1 x colBLeft block) in the destination buffer */
                    vstrwq_p_f32(pOut0, vecMac0, p0);
                }

                /* move to next row */
                pInA += 1 * numColsA;
                pOut += 1 * numColsB;
                i--;
            }
        }
        status = ARM_MATH_SUCCESS;
    }

    /* Return to application */
    return (status);
}
#else

#if defined(ARM_MATH_NEON)

#define GROUPOFROWS 8
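/*
 * The NEON path below walks GROUPOFROWS (8) rows of A against each column of
 * B per inner pass, keeping one 128-bit accumulator per row and reducing each
 * accumulator to a scalar with pairwise adds (vpadd) once the column has been
 * consumed.
 */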
arm_status arm_mat_mult_f32(
    const arm_matrix_instance_f32 * pSrcA,
    const arm_matrix_instance_f32 * pSrcB,
    arm_matrix_instance_f32 * pDst)
{
    float32_t *pIn1 = pSrcA->pData;      /* input data matrix pointer A */
    float32_t *pIn2 = pSrcB->pData;      /* input data matrix pointer B */
    float32_t *pInA = pSrcA->pData;      /* input data matrix pointer A */
    float32_t *pOut = pDst->pData;       /* output data matrix pointer */
    float32_t *px;                       /* Temporary output data matrix pointer */
    float32_t sum;                       /* Accumulator */
    uint16_t numRowsA = pSrcA->numRows;  /* number of rows of input matrix A */
    uint16_t numColsB = pSrcB->numCols;  /* number of columns of input matrix B */
    uint16_t numColsA = pSrcA->numCols;  /* number of columns of input matrix A */

    uint16_t col, i = 0U, j, row = numRowsA, rowCnt, colCnt;  /* loop counters */
    arm_status status;                   /* status of matrix multiplication */

    float32x4_t a0V, a1V, a2V, a3V, a4V, a5V, a6V, a7V;
    float32x4_t acc0, acc1, acc2, acc3, acc4, acc5, acc6, acc7, temp;
    float32x2_t accum = vdup_n_f32(0);

    float32_t *pIn1B = pSrcA->pData;
    float32_t *pIn1C = pSrcA->pData;
    float32_t *pIn1D = pSrcA->pData;
    float32_t *pIn1E = pSrcA->pData;
    float32_t *pIn1F = pSrcA->pData;
    float32_t *pIn1G = pSrcA->pData;
    float32_t *pIn1H = pSrcA->pData;

    float32_t *pxB, *pxC, *pxD, *pxE, *pxF, *pxG, *pxH;  /* Temporary output data matrix pointers */
    float32_t sum0, sum1, sum2, sum3, sum4, sum5, sum6, sum7;

#ifdef ARM_MATH_MATRIX_CHECK
    /* Check for matrix mismatch condition */
    if ((pSrcA->numCols != pSrcB->numRows) ||
        (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
    {
        /* Set status as ARM_MATH_SIZE_MISMATCH */
        status = ARM_MATH_SIZE_MISMATCH;
    }
    else
#endif /* #ifdef ARM_MATH_MATRIX_CHECK */
    {
        /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */

        /* Row loop */
        rowCnt = row >> 3;
        while (rowCnt > 0)
        {
            /* Output pointer is set to starting address of the row being processed */
            px  = pOut + GROUPOFROWS * i;
            pxB = px + numColsB;
            pxC = px + 2 * numColsB;
            pxD = px + 3 * numColsB;
            pxE = px + 4 * numColsB;
            pxF = px + 5 * numColsB;
            pxG = px + 6 * numColsB;
            pxH = px + 7 * numColsB;

            /* For every row-wise pass, initialize the column loop counter */
            col = numColsB;

            /* For every row-wise pass, set the pIn2 pointer to the starting address of the pSrcB data */
            pIn2 = pSrcB->pData;

            j = 0U;
            /* Column loop */
            do
            {
                /* Set the variables sum0..sum7, which act as accumulators, to zero */
                sum0 = 0.0f;
                sum1 = 0.0f;
                sum2 = 0.0f;
                sum3 = 0.0f;
                sum4 = 0.0f;
                sum5 = 0.0f;
                sum6 = 0.0f;
                sum7 = 0.0f;

                /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
                pIn1  = pInA;
                pIn1B = pIn1 + numColsA;
                pIn1C = pIn1 + 2 * numColsA;
                pIn1D = pIn1 + 3 * numColsA;
                pIn1E = pIn1 + 4 * numColsA;
                pIn1F = pIn1 + 5 * numColsA;
                pIn1G = pIn1 + 6 * numColsA;
                pIn1H = pIn1 + 7 * numColsA;

                acc0 = vdupq_n_f32(0.0);
                acc1 = vdupq_n_f32(0.0);
                acc2 = vdupq_n_f32(0.0);
                acc3 = vdupq_n_f32(0.0);
                acc4 = vdupq_n_f32(0.0);
                acc5 = vdupq_n_f32(0.0);
                acc6 = vdupq_n_f32(0.0);
                acc7 = vdupq_n_f32(0.0);

                /* Compute 4 MACs simultaneously. */
                colCnt = numColsA >> 2U;

                /* Matrix multiplication */
                while (colCnt > 0U)
                {
                    /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
                    a0V = vld1q_f32(pIn1);
                    a1V = vld1q_f32(pIn1B);
                    a2V = vld1q_f32(pIn1C);
                    a3V = vld1q_f32(pIn1D);
                    a4V = vld1q_f32(pIn1E);
                    a5V = vld1q_f32(pIn1F);
                    a6V = vld1q_f32(pIn1G);
                    a7V = vld1q_f32(pIn1H);
                    pIn1  += 4;
                    pIn1B += 4;
                    pIn1C += 4;
                    pIn1D += 4;
                    pIn1E += 4;
                    pIn1F += 4;
                    pIn1G += 4;
                    pIn1H += 4;

                    temp = vsetq_lane_f32(*pIn2, temp, 0);
                    pIn2 += numColsB;
                    temp = vsetq_lane_f32(*pIn2, temp, 1);
                    pIn2 += numColsB;
                    temp = vsetq_lane_f32(*pIn2, temp, 2);
                    pIn2 += numColsB;
                    temp = vsetq_lane_f32(*pIn2, temp, 3);
                    pIn2 += numColsB;

                    acc0 = vmlaq_f32(acc0, a0V, temp);
                    acc1 = vmlaq_f32(acc1, a1V, temp);
                    acc2 = vmlaq_f32(acc2, a2V, temp);
                    acc3 = vmlaq_f32(acc3, a3V, temp);
                    acc4 = vmlaq_f32(acc4, a4V, temp);
                    acc5 = vmlaq_f32(acc5, a5V, temp);
                    acc6 = vmlaq_f32(acc6, a6V, temp);
                    acc7 = vmlaq_f32(acc7, a7V, temp);

                    /* Decrement the loop count */
                    colCnt--;
                }

                accum = vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
                sum0 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

                accum = vpadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
                sum1 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

                accum = vpadd_f32(vget_low_f32(acc2), vget_high_f32(acc2));
                sum2 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

                accum = vpadd_f32(vget_low_f32(acc3), vget_high_f32(acc3));
                sum3 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

                accum = vpadd_f32(vget_low_f32(acc4), vget_high_f32(acc4));
                sum4 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

                accum = vpadd_f32(vget_low_f32(acc5), vget_high_f32(acc5));
                sum5 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

                accum = vpadd_f32(vget_low_f32(acc6), vget_high_f32(acc6));
                sum6 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

                accum = vpadd_f32(vget_low_f32(acc7), vget_high_f32(acc7));
                sum7 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
                /* If the number of columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
                ** No loop unrolling is used. */
                colCnt = numColsA & 3;

                while (colCnt > 0U)
                {
                    /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
                    sum0 += *pIn1++  * (*pIn2);
                    sum1 += *pIn1B++ * (*pIn2);
                    sum2 += *pIn1C++ * (*pIn2);
                    sum3 += *pIn1D++ * (*pIn2);
                    sum4 += *pIn1E++ * (*pIn2);
                    sum5 += *pIn1F++ * (*pIn2);
                    sum6 += *pIn1G++ * (*pIn2);
                    sum7 += *pIn1H++ * (*pIn2);
                    pIn2 += numColsB;

                    /* Decrement the loop counter */
                    colCnt--;
                }

                /* Store the result in the destination buffer */
                *px++  = sum0;
                *pxB++ = sum1;
                *pxC++ = sum2;
                *pxD++ = sum3;
                *pxE++ = sum4;
                *pxF++ = sum5;
                *pxG++ = sum6;
                *pxH++ = sum7;

                /* Update the pointer pIn2 to point to the starting address of the next column */
                j++;
                pIn2 = pSrcB->pData + j;

                /* Decrement the column loop counter */
                col--;

            } while (col > 0U);

            /* Update the pointer pInA to point to the starting address of the next row */
            i = i + numColsB;
            pInA = pInA + GROUPOFROWS * numColsA;

            /* Decrement the row loop counter */
            rowCnt--;
        }
        /*
         * i was the index of a group of rows computed by the previous loop.
         * From here on, i indexes a single row, since the code below processes
         * one row at a time rather than a group of rows.
         */
        i = GROUPOFROWS * i;
        rowCnt = row & 7;
        while (rowCnt > 0)
        {
            /* Output pointer is set to starting address of the row being processed */
            px = pOut + i;

            /* For every row-wise pass, initialize the column loop counter */
            col = numColsB;

            /* For every row-wise pass, set the pIn2 pointer to the starting address of the pSrcB data */
            pIn2 = pSrcB->pData;

            j = 0U;

            /* Column loop */
            do
            {
                /* Set the variable sum, which acts as accumulator, to zero */
                sum = 0.0f;

                /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
                pIn1 = pInA;

                acc0 = vdupq_n_f32(0.0);

                /* Compute 4 MACs simultaneously. */
                colCnt = numColsA >> 2U;

                /* Matrix multiplication */
                while (colCnt > 0U)
                {
                    /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
                    a0V = vld1q_f32(pIn1);  /* load 4 consecutive elements of the A row */
                    pIn1 += 4;

                    temp = vsetq_lane_f32(*pIn2, temp, 0);
                    pIn2 += numColsB;
                    temp = vsetq_lane_f32(*pIn2, temp, 1);
                    pIn2 += numColsB;
                    temp = vsetq_lane_f32(*pIn2, temp, 2);
                    pIn2 += numColsB;
                    temp = vsetq_lane_f32(*pIn2, temp, 3);
                    pIn2 += numColsB;

                    acc0 = vmlaq_f32(acc0, a0V, temp);

                    /* Decrement the loop count */
                    colCnt--;
                }

                accum = vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
                sum += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);

                /* If the number of columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
                ** No loop unrolling is used. */
                colCnt = numColsA % 0x4U;

                while (colCnt > 0U)
                {
                    /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
                    sum += *pIn1++ * (*pIn2);
                    pIn2 += numColsB;

                    /* Decrement the loop counter */
                    colCnt--;
                }

                /* Store the result in the destination buffer */
                *px++ = sum;

                /* Update the pointer pIn2 to point to the starting address of the next column */
                j++;
                pIn2 = pSrcB->pData + j;

                /* Decrement the column loop counter */
                col--;

            } while (col > 0U);

            /* Update the pointer pInA to point to the starting address of the next row */
            i = i + numColsB;
            pInA = pInA + numColsA;

            /* Decrement the row loop counter */
            rowCnt--;
        }

        /* Set status as ARM_MATH_SUCCESS */
        status = ARM_MATH_SUCCESS;
    }

    /* Return to application */
    return (status);
}
#else

arm_status arm_mat_mult_f32(
    const arm_matrix_instance_f32 * pSrcA,
    const arm_matrix_instance_f32 * pSrcB,
    arm_matrix_instance_f32 * pDst)
{
    float32_t *pIn1 = pSrcA->pData;      /* Input data matrix pointer A */
    float32_t *pIn2 = pSrcB->pData;      /* Input data matrix pointer B */
    float32_t *pInA = pSrcA->pData;      /* Input data matrix pointer A */
    float32_t *pInB = pSrcB->pData;      /* Input data matrix pointer B */
    float32_t *pOut = pDst->pData;       /* Output data matrix pointer */
    float32_t *px;                       /* Temporary output data matrix pointer */
    float32_t sum;                       /* Accumulator */
    uint16_t numRowsA = pSrcA->numRows;  /* Number of rows of input matrix A */
    uint16_t numColsB = pSrcB->numCols;  /* Number of columns of input matrix B */
    uint16_t numColsA = pSrcA->numCols;  /* Number of columns of input matrix A */
    uint32_t col, i = 0U, row = numRowsA, colCnt;  /* Loop counters */
    arm_status status;                   /* Status of matrix multiplication */

#ifdef ARM_MATH_MATRIX_CHECK
    /* Check for matrix mismatch condition */
    if ((pSrcA->numCols != pSrcB->numRows) ||
        (pSrcA->numRows != pDst->numRows) ||
        (pSrcB->numCols != pDst->numCols))
    {
        /* Set status as ARM_MATH_SIZE_MISMATCH */
        status = ARM_MATH_SIZE_MISMATCH;
    }
    else
#endif /* #ifdef ARM_MATH_MATRIX_CHECK */
    {
        /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */

        /* row loop */
        do
        {
            /* Output pointer is set to starting address of row being processed */
            px = pOut + i;

            /* For every row-wise pass, initialize the column loop counter */
            col = numColsB;

            /* For every row-wise pass, set pIn2 to the starting address of the pSrcB data */
            pIn2 = pSrcB->pData;

            /* column loop */
            do
            {
                /* Set the variable sum, which acts as accumulator, to zero */
                sum = 0.0f;

                /* Initialize pointer pIn1 to point to starting address of column being processed */
                pIn1 = pInA;
#if defined (ARM_MATH_LOOPUNROLL)

                /* Loop unrolling: Compute 4 MACs at a time. */
                colCnt = numColsA >> 2U;

                /* matrix multiplication */
                while (colCnt > 0U)
                {
                    /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */

                    /* Perform the multiply-accumulates */
                    sum += *pIn1++ * *pIn2;
                    pIn2 += numColsB;

                    sum += *pIn1++ * *pIn2;
                    pIn2 += numColsB;

                    sum += *pIn1++ * *pIn2;
                    pIn2 += numColsB;

                    sum += *pIn1++ * *pIn2;
                    pIn2 += numColsB;

                    /* Decrement loop counter */
                    colCnt--;
                }

                /* Loop unrolling: Compute remaining MACs */
                colCnt = numColsA % 0x4U;

#else

                /* Initialize colCnt with number of columns */
                colCnt = numColsA;

#endif /* #if defined (ARM_MATH_LOOPUNROLL) */
                while (colCnt > 0U)
                {
                    /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */

                    /* Perform the multiply-accumulates */
                    sum += *pIn1++ * *pIn2;
                    pIn2 += numColsB;

                    /* Decrement loop counter */
                    colCnt--;
                }

                /* Store result in destination buffer */
                *px++ = sum;

                /* Decrement column loop counter */
                col--;

                /* Update pointer pIn2 to point to starting address of next column */
                pIn2 = pInB + (numColsB - col);

            } while (col > 0U);

            /* Update pointer pInA to point to starting address of next row */
            i = i + numColsB;
            pInA = pInA + numColsA;

            /* Decrement row loop counter */
            row--;

        } while (row > 0U);

        /* Set status as ARM_MATH_SUCCESS */
        status = ARM_MATH_SUCCESS;
    }

    /* Return to application */
    return (status);
}
#endif /* #if defined(ARM_MATH_NEON) */
#endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */

/**
 * @} end of MatrixMult group
 */