/* ----------------------------------------------------------------------
 * Project:      CMSIS DSP Library
 * Title:        arm_mat_mult_f16.c
 * Description:  Floating-point matrix multiplication
 *
 * $Date:        23 April 2021
 * $Revision:    V1.9.0
 *
 * Target Processor: Cortex-M and Cortex-A cores
 * -------------------------------------------------------------------- */
/*
 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dsp/matrix_functions_f16.h"

#if defined(ARM_FLOAT16_SUPPORTED)

/**
 * @ingroup groupMatrix
 */

/**
 * @addtogroup MatrixMult
 * @{
 */

/**
 * @brief       Floating-point matrix multiplication.
 * @param[in]   pSrcA  points to the first input matrix structure
 * @param[in]   pSrcB  points to the second input matrix structure
 * @param[out]  pDst   points to the output matrix structure
 * @return The function returns either
 *         <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code>
 *         based on the outcome of the size check.
 */
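/*
 * Usage sketch (illustrative only, not part of the library): multiplying a
 * 2x3 matrix by a 3x2 matrix. The buffer names and values below are
 * assumptions made for the example; the arm_matrix_instance_f16 fields
 * (numRows, numCols, pData) are those used throughout this file.
 *
 *   float16_t dataA[2 * 3] = { 1.0f16, 2.0f16, 3.0f16,
 *                              4.0f16, 5.0f16, 6.0f16 };
 *   float16_t dataB[3 * 2] = { 1.0f16, 0.0f16,
 *                              0.0f16, 1.0f16,
 *                              1.0f16, 1.0f16 };
 *   float16_t dataC[2 * 2];
 *
 *   arm_matrix_instance_f16 A = { .numRows = 2, .numCols = 3, .pData = dataA };
 *   arm_matrix_instance_f16 B = { .numRows = 3, .numCols = 2, .pData = dataB };
 *   arm_matrix_instance_f16 C = { .numRows = 2, .numCols = 2, .pData = dataC };
 *
 *   if (arm_mat_mult_f16(&A, &B, &C) != ARM_MATH_SUCCESS)
 *   {
 *       // inner dimensions did not match (ARM_MATH_SIZE_MISMATCH,
 *       // reported when ARM_MATH_MATRIX_CHECK is defined)
 *   }
 */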
#if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)

__STATIC_FORCEINLINE arm_status arm_mat_mult_f16_2x2_mve(
    const arm_matrix_instance_f16 *pSrcA,
    const arm_matrix_instance_f16 *pSrcB,
    arm_matrix_instance_f16 *pDst)
{
    static const uint16_t offsetA[8] = { 0, 0, 2, 2, 0, 0, 2, 2 };
    /* offsetB allows to read and duplicate 1 row of B */
    static const uint16_t offsetB[8] = { 0, 1, 0, 1, 0, 1, 0, 1 };
    uint16x8_t vecOffsA, vecOffsB;
    f16x8_t    vecInA, vecInB, vecDst;
    float16_t *pOut = pDst->pData;  /* output data matrix pointer */

    /*
     * load initial offsets
     */
    vecOffsA = vldrhq_u16((uint16_t const *) offsetA);
    vecOffsB = vldrhq_u16((uint16_t const *) offsetB);
    /*
     * load {a00 a00 a10 a10 x x x x}
     */
    vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
    /*
     * load {b00 b01 b00 b01 x x x x}
     */
    vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
    /*
     * { a00 b00   a00 b01
     *   a10 b00   a10 b01
     *   x   x
     *   x   x }
     */
    vecDst = vmulq(vecInA, vecInB);
    /*
     * move to 2nd column of matrix A
     */
    vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 1);
    /*
     * load {a01 a01 a11 a11 x x x x}
     */
    vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
    /*
     * move to next B row
     */
    vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 2);
    /*
     * load {b10 b11 b10 b11 x x x x}
     */
    vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
    /*
     * { a00 b00 + a01 b10   a00 b01 + a01 b11
     *   a10 b00 + a11 b10   a10 b01 + a11 b11
     *   x   x
     *   x   x }
     */
    vecDst = vfmaq(vecDst, vecInA, vecInB);

    mve_pred16_t p0 = vctp16q(2 * 2);
    /*
     * Store the result in the destination buffer
     * (lower half of the vector)
     */
    vstrhq_p(pOut, vecDst, p0);

    return (ARM_MATH_SUCCESS);
}
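/*
 * For reference, a scalar sketch of what the 2x2 kernel above computes
 * (row-major storage, as assumed throughout this file):
 *
 *   const _Float16 *a = (const _Float16 *) pSrcA->pData;
 *   const _Float16 *b = (const _Float16 *) pSrcB->pData;
 *   pOut[0] = a[0] * b[0] + a[1] * b[2];   // c00 = a00 b00 + a01 b10
 *   pOut[1] = a[0] * b[1] + a[1] * b[3];   // c01 = a00 b01 + a01 b11
 *   pOut[2] = a[2] * b[0] + a[3] * b[2];   // c10 = a10 b00 + a11 b10
 *   pOut[3] = a[2] * b[1] + a[3] * b[3];   // c11 = a10 b01 + a11 b11
 *
 * The two gather loads build {a00 a00 a10 a10} and {b00 b01 b00 b01}, so a
 * single vmulq/vfmaq pair produces all four dot products in the low half of
 * the vector; the predicated store then writes just those four lanes.
 */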
__STATIC_FORCEINLINE arm_status arm_mat_mult_f16_3x3_mve(
    const arm_matrix_instance_f16 *pSrcA,
    const arm_matrix_instance_f16 *pSrcB,
    arm_matrix_instance_f16 *pDst)
{
    static const uint16_t offsetA[8] = { 0, 0, 0, 3, 3, 3, 6, 6 };
    /* offsetB allows to read and duplicate 1 row of B */
    static const uint16_t offsetB[8] = { 0, 1, 2, 0, 1, 2, 0, 1 };
    uint16x8_t vecOffsA, vecOffsB;
    f16x8_t    vecInA, vecInB, vecDst;
    float16_t *pOut = pDst->pData;  /* output data matrix pointer */

    /*
     * load initial offsets
     */
    vecOffsA = vldrhq_u16((uint16_t const *) offsetA);
    vecOffsB = vldrhq_u16((uint16_t const *) offsetB);
    /*
     * load {a00 a00 a00 a10 a10 a10 a20 a20}
     */
    vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
    /*
     * load {b00 b01 b02 b00 b01 b02 b00 b01}
     */
    vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
    /*
     * { a00 b00   a00 b01   a00 b02
     *   a10 b00   a10 b01   a10 b02
     *   a20 b00   a20 b01 }
     */
    vecDst = vmulq(vecInA, vecInB);
    /*
     * move to 2nd column of matrix A
     */
    vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 1);
    /*
     * load {a01 a01 a01 a11 a11 a11 a21 a21}
     */
    vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
    /*
     * move to next B row
     */
    vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 3);
    /*
     * load {b10 b11 b12 b10 b11 b12 b10 b11}
     */
    vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
    /*
     * { a00 b00 + a01 b10   a00 b01 + a01 b11   a00 b02 + a01 b12
     *   a10 b00 + a11 b10   a10 b01 + a11 b11   a10 b02 + a11 b12
     *   a20 b00 + a21 b10   a20 b01 + a21 b11 }
     */
    vecDst = vfmaq(vecDst, vecInA, vecInB);
    /*
     * move to 3rd column of matrix A
     */
    vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 1);
    /*
     * load {a02 a02 a02 a12 a12 a12 a22 a22}
     */
    vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
    /*
     * move to next B row
     */
    vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 3);
    /*
     * load {b20 b21 b22 b20 b21 b22 b20 b21}
     */
    vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
    /*
     * { a00 b00 + a01 b10 + a02 b20   a00 b01 + a01 b11 + a02 b21   a00 b02 + a01 b12 + a02 b22
     *   a10 b00 + a11 b10 + a12 b20   a10 b01 + a11 b11 + a12 b21   a10 b02 + a11 b12 + a12 b22
     *   a20 b00 + a21 b10 + a22 b20   a20 b01 + a21 b11 + a22 b21 }
     */
    vecDst = vfmaq(vecDst, vecInA, vecInB);
    /*
     * Store the result in the destination buffer
     */
    vst1q(pOut, vecDst);  pOut += 8;

    /* last element computed in scalar mode
     * a20 b02 + a21 b12 + a22 b22
     */
    _Float16 *pA = (_Float16 *) pSrcA->pData;
    _Float16 *pB = (_Float16 *) pSrcB->pData;
    *pOut = pA[2 * 3] * pB[2] + pA[2 * 3 + 1] * pB[3 + 2] + pA[2 * 3 + 2] * pB[2 * 3 + 2];

    return (ARM_MATH_SUCCESS);
}
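/*
 * Note: a 3x3 product yields 9 outputs while one f16x8_t holds only 8 lanes,
 * which is why the kernel above stores the first 8 results with a full
 * vector store and computes the final element (c22) in scalar code.
 */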
__STATIC_FORCEINLINE arm_status arm_mat_mult_f16_4x4_mve(
    const arm_matrix_instance_f16 *pSrcA,
    const arm_matrix_instance_f16 *pSrcB,
    arm_matrix_instance_f16 *pDst)
{
    /* offsetA allows to read and duplicate 2 successive column elements of A */
    static const uint16_t offsetA[8] = { 0, 0, 0, 0, 4, 4, 4, 4 };
    /* offsetB allows to read and duplicate 1 row of B */
    static const uint16_t offsetB[8] = { 0, 1, 2, 3, 0, 1, 2, 3 };
    uint16x8_t vecOffsA, vecOffsB;
    f16x8_t    vecInA, vecInB, vecDst0, vecDst1;
    float16_t *pOut = pDst->pData;  /* output data matrix pointer */

    /*
     * load initial offsets
     */
    vecOffsA = vldrhq_u16((uint16_t const *) offsetA);
    vecOffsB = vldrhq_u16((uint16_t const *) offsetB);
    /*
     * load {a00 a00 a00 a00 a10 a10 a10 a10}
     */
    vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
    /*
     * load {b00 b01 b02 b03 b00 b01 b02 b03}
     */
    vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
    /*
     * { a00 b00   a00 b01   a00 b02   a00 b03
     *   a10 b00   a10 b01   a10 b02   a10 b03 }
     */
    vecDst0 = vmulq(vecInA, vecInB);
    /*
     * jump 2 x A rows (2nd half of matrix)
     */
    vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 8);
    /*
     * load {a20 a20 a20 a20 a30 a30 a30 a30}
     */
    vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
    /*
     * { a20 b00   a20 b01   a20 b02   a20 b03
     *   a30 b00   a30 b01   a30 b02   a30 b03 }
     */
    vecDst1 = vmulq(vecInA, vecInB);
    /*
     * rewind back to top half of the A matrix (2nd column)
     */
    vecOffsA = vsubq(vecOffsA, (uint16_t) 7);
    /*
     * load {a01 a01 a01 a01 a11 a11 a11 a11}
     */
    vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
    /*
     * move to next B row
     */
    vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 4);
    /*
     * load {b10 b11 b12 b13 b10 b11 b12 b13}
     */
    vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
    /*
     * { a00 b00 + a01 b10   a00 b01 + a01 b11   a00 b02 + a01 b12   a00 b03 + a01 b13
     *   a10 b00 + a11 b10   a10 b01 + a11 b11   a10 b02 + a11 b12   a10 b03 + a11 b13 }
     */
    vecDst0 = vfmaq(vecDst0, vecInA, vecInB);
    /*
     * jump 2 x A rows (2nd half of matrix)
     */
    vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 8);
    /*
     * load {a21 a21 a21 a21 a31 a31 a31 a31}
     */
    vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
    /*
     * { a20 b00 + a21 b10   a20 b01 + a21 b11   a20 b02 + a21 b12   a20 b03 + a21 b13
     *   a30 b00 + a31 b10   a30 b01 + a31 b11   a30 b02 + a31 b12   a30 b03 + a31 b13 }
     */
    vecDst1 = vfmaq(vecDst1, vecInA, vecInB);
    /*
     * rewind back to top half of the A matrix (3rd column)
     */
    vecOffsA = vsubq(vecOffsA, (uint16_t) 7);
    /*
     * load {a02 a02 a02 a02 a12 a12 a12 a12}
     */
    vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
    /*
     * move to next B row
     */
    vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 4);
    /*
     * load {b20 b21 b22 b23 b20 b21 b22 b23}
     */
    vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
    /*
     * { a00 b00 + a01 b10 + a02 b20   a00 b01 + a01 b11 + a02 b21   a00 b02 + a01 b12 + a02 b22   a00 b03 + a01 b13 + a02 b23
     *   a10 b00 + a11 b10 + a12 b20   a10 b01 + a11 b11 + a12 b21   a10 b02 + a11 b12 + a12 b22   a10 b03 + a11 b13 + a12 b23 }
     */
    vecDst0 = vfmaq(vecDst0, vecInA, vecInB);
    /*
     * jump 2 x A rows
     */
    vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 8);
    /*
     * load {a22 a22 a22 a22 a32 a32 a32 a32}
     */
    vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
    /*
     * { a20 b00 + a21 b10 + a22 b20   a20 b01 + a21 b11 + a22 b21   a20 b02 + a21 b12 + a22 b22   a20 b03 + a21 b13 + a22 b23
     *   a30 b00 + a31 b10 + a32 b20   a30 b01 + a31 b11 + a32 b21   a30 b02 + a31 b12 + a32 b22   a30 b03 + a31 b13 + a32 b23 }
     */
    vecDst1 = vfmaq(vecDst1, vecInA, vecInB);
    /*
     * rewind back to top half of the A matrix (4th column)
     */
    vecOffsA = vsubq(vecOffsA, (uint16_t) 7);
    /*
     * load {a03 a03 a03 a03 a13 a13 a13 a13}
     */
    vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
    /*
     * move to next B row
     */
    vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 4);
    /*
     * load {b30 b31 b32 b33 b30 b31 b32 b33}
     */
    vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
    /*
     * { a00 b00 + ... + a03 b30   a00 b01 + ... + a03 b31   a00 b02 + ... + a03 b32   a00 b03 + ... + a03 b33
     *   a10 b00 + ... + a13 b30   a10 b01 + ... + a13 b31   a10 b02 + ... + a13 b32   a10 b03 + ... + a13 b33 }
     */
    vecDst0 = vfmaq(vecDst0, vecInA, vecInB);
    /*
     * jump 2 x A rows
     */
    vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 8);
    /*
     * load {a23 a23 a23 a23 a33 a33 a33 a33}
     */
    vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
    /*
     * { a20 b00 + ... + a23 b30   a20 b01 + ... + a23 b31   a20 b02 + ... + a23 b32   a20 b03 + ... + a23 b33
     *   a30 b00 + ... + a33 b30   a30 b01 + ... + a33 b31   a30 b02 + ... + a33 b32   a30 b03 + ... + a33 b33 }
     */
    vecDst1 = vfmaq(vecDst1, vecInA, vecInB);
    /*
     * Store the result in the destination buffer
     */
    vst1q(pOut, vecDst0);  pOut += 8;
    vst1q(pOut, vecDst1);

    return (ARM_MATH_SUCCESS);
}
arm_status arm_mat_mult_f16(
    const arm_matrix_instance_f16 *pSrcA,
    const arm_matrix_instance_f16 *pSrcB,
    arm_matrix_instance_f16 *pDst)
{
    float16_t *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float16_t *pInA = pSrcA->pData;  /* input data matrix pointer A */
    float16_t *pOut = pDst->pData;   /* output data matrix pointer */
    int numRowsA = pSrcA->numRows;   /* number of rows of input matrix A */
    int numColsB = pSrcB->numCols;   /* number of columns of input matrix B */
    int numColsA = pSrcA->numCols;   /* number of columns of input matrix A */
    uint32_t blkCnt;                 /* loop counters */
    int i;

#ifdef ARM_MATH_MATRIX_CHECK
    /* Check for matrix mismatch condition */
    if ((pSrcA->numCols != pSrcB->numRows) ||
        (pSrcA->numRows != pDst->numRows)  ||
        (pSrcB->numCols != pDst->numCols))
    {
        /* Set status as ARM_MATH_SIZE_MISMATCH */
        return (ARM_MATH_SIZE_MISMATCH);
    }
    else
#endif /* #ifdef ARM_MATH_MATRIX_CHECK */
    {
        /* specialized routines for small square matrices */
        if (numRowsA == numColsB && numColsB == numColsA)
        {
            if (numRowsA == 2)
                return arm_mat_mult_f16_2x2_mve(pSrcA, pSrcB, pDst);
            else if (numRowsA == 3)
                return arm_mat_mult_f16_3x3_mve(pSrcA, pSrcB, pDst);
            else if (numRowsA == 4)
                return arm_mat_mult_f16_4x4_mve(pSrcA, pSrcB, pDst);
        }
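        /*
         * Blocking scheme (summary, added for orientation): the main loop
         * below walks the output in 4-row by 8-column tiles. Each inner-loop
         * iteration loads one 8-wide row fragment of B and multiply-
         * accumulates it against four broadcast scalars of A (vfmaq with a
         * scalar operand), producing four output rows per pass. Column
         * leftovers are handled with tail predication; row leftovers fall
         * through to the single-row loop further below.
         */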
        /* main loop: process 4 rows at a time */
        i = numRowsA / 4;
        while (i > 0)
        {
            float16_t *pInA0, *pInA1, *pInA2, *pInA3;
            float16_t *pInB0;
            float16_t *pOut0, *pOut1, *pOut2, *pOut3;
            f16x8_t vecMac0, vecMac1, vecMac2, vecMac3;
            f16x8_t vecInB;

            /* pointers to 4 consecutive output rows */
            pOut0 = pOut;
            pOut1 = pOut0 + numColsB;
            pOut2 = pOut1 + numColsB;
            pOut3 = pOut2 + numColsB;
            pInB0 = pInB;

            int k = numColsB >> 3;
            while (k > 0)
            {
                /* pointers to 4 consecutive matrix A rows */
                pInA0 = pInA;
                pInA1 = pInA0 + numColsA;
                pInA2 = pInA1 + numColsA;
                pInA3 = pInA2 + numColsA;

                vecMac0 = vdupq_n_f16(0.0f16);
                vecMac1 = vdupq_n_f16(0.0f16);
                vecMac2 = vdupq_n_f16(0.0f16);
                vecMac3 = vdupq_n_f16(0.0f16);

                blkCnt = numColsA;
                while (blkCnt > 0U)
                {
                    /*
                     * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3, ..., bi,4n+7}
                     */
                    vecInB = *(f16x8_t *) pInB0;  /* vldrhq_f16(pInB0, 0); */

                    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
                    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
                    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
                    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);

                    pInB0 = pInB0 + numColsB;
                    /*
                     * Decrement the blockSize loop counter
                     */
                    blkCnt--;
                }

                /* Store the results (4 x 8 block) in the destination buffer */
                vst1q(pOut0, vecMac0);  pOut0 += 8;
                vst1q(pOut1, vecMac1);  pOut1 += 8;
                vst1q(pOut2, vecMac2);  pOut2 += 8;
                vst1q(pOut3, vecMac3);  pOut3 += 8;
                /*
                 * rewind
                 */
                pInB0 -= (numColsB * numColsA) - 8;
                k--;
            }

            int colBLeft = numColsB & 7;
            if (colBLeft)
            {
                pInA0 = pInA;
                pInA1 = pInA0 + numColsA;
                pInA2 = pInA1 + numColsA;
                pInA3 = pInA2 + numColsA;
                mve_pred16_t p0 = vctp16q(colBLeft);

                vecMac0 = vdupq_n_f16(0.0f16);
                vecMac1 = vdupq_n_f16(0.0f16);
                vecMac2 = vdupq_n_f16(0.0f16);
                vecMac3 = vdupq_n_f16(0.0f16);

                blkCnt = numColsA;
                while (blkCnt > 0U)
                {
                    /*
                     * load {bi,4n+0, bi,4n+1, bi,4n+2, ..., bi,4n+colBLeft-1, 0, ...}
                     */
                    vecInB = vldrhq_z_f16(pInB0, p0);

                    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
                    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
                    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
                    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);

                    pInB0 = pInB0 + numColsB;
                    /*
                     * Decrement the blockSize loop counter
                     */
                    blkCnt--;
                }

                /* Store the results (4 x colBLeft block) in the destination buffer */
                vstrhq_p_f16(pOut0, vecMac0, p0);
                vstrhq_p_f16(pOut1, vecMac1, p0);
                vstrhq_p_f16(pOut2, vecMac2, p0);
                vstrhq_p_f16(pOut3, vecMac3, p0);
            }

            pInA += 4 * numColsA;
            pOut += 4 * numColsB;
            i--;
        }

        /*
         * non multiple of 4 rows for matrix A:
         * process a single row at a time
         */
        if (numRowsA & 3)
        {
            i = numRowsA & 3;
            do
            {
                float16_t *pInA0;
                float16_t *pInB0;
                float16_t *pOut0;
                f16x8_t vecInB;
                f16x8_t vecMac0;

                pOut0 = pOut;
                pInB0 = pInB;

                int k = numColsB >> 3;
                while (k > 0)
                {
                    pInA0 = pInA;

                    vecMac0 = vdupq_n_f16(0.0f16);
                    blkCnt = numColsA;
                    while (blkCnt > 0U)
                    {
                        /*
                         * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3, ..., bi,4n+7}
                         */
                        vecInB = *(f16x8_t *) pInB0;  /* vldrhq_f16(pInB0, 0); */

                        vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);

                        pInB0 = pInB0 + numColsB;
                        /*
                         * Decrement the blockSize loop counter
                         */
                        blkCnt--;
                    }

                    /* Store the results (1 x 8 block) in the destination buffer */
                    vst1q(pOut0, vecMac0);  pOut0 += 8;
                    /*
                     * rewind
                     */
                    pInB0 -= (numColsB * numColsA) - 8;
                    k--;
                }

                int colBLeft = numColsB & 7;
                if (colBLeft)
                {
                    pInA0 = pInA;
                    mve_pred16_t p0 = vctp16q(colBLeft);

                    vecMac0 = vdupq_n_f16(0.0f16);
                    blkCnt = numColsA;
                    while (blkCnt > 0U)
                    {
                        /*
                         * load {bi,4n+0, bi,4n+1, bi,4n+2, ..., bi,4n+colBLeft-1, 0, ...}
                         */
                        vecInB = vldrhq_z_f16(pInB0, p0);

                        vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);

                        pInB0 = pInB0 + numColsB;
                        /*
                         * Decrement the blockSize loop counter
                         */
                        blkCnt--;
                    }
                    /* Store the results (1 x colBLeft block) in the destination buffer */
                    vstrhq_p_f16(pOut0, vecMac0, p0);
                }

                pInA += 1 * numColsA;
                pOut += 1 * numColsB;
            }
            while (--i);
        }

        /*
         * Return to application
         */
        return (ARM_MATH_SUCCESS);
    }
}
#else

arm_status arm_mat_mult_f16(
    const arm_matrix_instance_f16 *pSrcA,
    const arm_matrix_instance_f16 *pSrcB,
    arm_matrix_instance_f16 *pDst)
{
    float16_t *pIn1 = pSrcA->pData;      /* Input data matrix pointer A */
    float16_t *pIn2 = pSrcB->pData;      /* Input data matrix pointer B */
    float16_t *pInA = pSrcA->pData;      /* Input data matrix pointer A */
    float16_t *pInB = pSrcB->pData;      /* Input data matrix pointer B */
    float16_t *pOut = pDst->pData;       /* Output data matrix pointer */
    float16_t *px;                       /* Temporary output data matrix pointer */
    _Float16 sum;                        /* Accumulator */
    uint16_t numRowsA = pSrcA->numRows;  /* Number of rows of input matrix A */
    uint16_t numColsB = pSrcB->numCols;  /* Number of columns of input matrix B */
    uint16_t numColsA = pSrcA->numCols;  /* Number of columns of input matrix A */
    uint32_t col, i = 0U, row = numRowsA, colCnt;  /* Loop counters */
    arm_status status;                   /* Status of matrix multiplication */

#ifdef ARM_MATH_MATRIX_CHECK
    /* Check for matrix mismatch condition */
    if ((pSrcA->numCols != pSrcB->numRows) ||
        (pSrcA->numRows != pDst->numRows)  ||
        (pSrcB->numCols != pDst->numCols))
    {
        /* Set status as ARM_MATH_SIZE_MISMATCH */
        status = ARM_MATH_SIZE_MISMATCH;
    }
    else
#endif /* #ifdef ARM_MATH_MATRIX_CHECK */
    {
        /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */

        /* row loop */
        do
        {
            /* Output pointer is set to the starting address of the row being processed */
            px = pOut + i;

            /* For every row-wise pass, the column loop counter is reinitialized */
            col = numColsB;

            /* For every row-wise pass, pIn2 is set to the starting address of the pSrcB data */
            pIn2 = pSrcB->pData;

            /* column loop */
            do
            {
                /* Set the variable sum, which acts as accumulator, to zero */
                sum = 0.0f16;

                /* Initialize pointer pIn1 to point to the starting address of the row being processed */
                pIn1 = pInA;

#if defined (ARM_MATH_LOOPUNROLL)

                /* Loop unrolling: Compute 4 MACs at a time. */
                colCnt = numColsA >> 2U;

                /* matrix multiplication */
                while (colCnt > 0U)
                {
                    /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */

                    /* Perform the multiply-accumulates */
                    sum += (_Float16) *pIn1++ * (_Float16) *pIn2;
                    pIn2 += numColsB;

                    sum += (_Float16) *pIn1++ * (_Float16) *pIn2;
                    pIn2 += numColsB;

                    sum += (_Float16) *pIn1++ * (_Float16) *pIn2;
                    pIn2 += numColsB;

                    sum += (_Float16) *pIn1++ * (_Float16) *pIn2;
                    pIn2 += numColsB;

                    /* Decrement loop counter */
                    colCnt--;
                }

                /* Loop unrolling: Compute remaining MACs */
                colCnt = numColsA % 0x4U;

#else

                /* Initialize colCnt with the number of columns */
                colCnt = numColsA;

#endif /* #if defined (ARM_MATH_LOOPUNROLL) */

                while (colCnt > 0U)
                {
                    /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */

                    /* Perform the multiply-accumulates */
                    sum += (_Float16) *pIn1++ * (_Float16) *pIn2;
                    pIn2 += numColsB;

                    /* Decrement loop counter */
                    colCnt--;
                }

                /* Store result in destination buffer */
                *px++ = sum;

                /* Decrement column loop counter */
                col--;

                /* Update pointer pIn2 to point to the starting address of the next column */
                pIn2 = pInB + (numColsB - col);

            } while (col > 0U);

            /* Update pointer pInA to point to the starting address of the next row */
            i = i + numColsB;
            pInA = pInA + numColsA;

            /* Decrement row loop counter */
            row--;

        } while (row > 0U);

        /* Set status as ARM_MATH_SUCCESS */
        status = ARM_MATH_SUCCESS;
    }

    /* Return to application */
    return (status);
}
#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
/**
 * @} end of MatrixMult group
 */

#endif /* #if defined(ARM_FLOAT16_SUPPORTED) */