arm_mat_vec_mult_f16.c 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. /* ----------------------------------------------------------------------
  2. * Project: CMSIS DSP Library
  3. * Title: arm_mat_vec_mult_f16.c
  4. * Description: Floating-point matrix and vector multiplication
  5. *
  6. * $Date: 23 April 2021
  7. *
  8. * $Revision: V1.9.0
  9. *
  10. * Target Processor: Cortex-M and Cortex-A cores
  11. * -------------------------------------------------------------------- */
  12. /*
  13. * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
  14. *
  15. * SPDX-License-Identifier: Apache-2.0
  16. *
  17. * Licensed under the Apache License, Version 2.0 (the License); you may
  18. * not use this file except in compliance with the License.
  19. * You may obtain a copy of the License at
  20. *
  21. * www.apache.org/licenses/LICENSE-2.0
  22. *
  23. * Unless required by applicable law or agreed to in writing, software
  24. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  25. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  26. * See the License for the specific language governing permissions and
  27. * limitations under the License.
  28. */
  29. #include "dsp/matrix_functions_f16.h"
  30. #if defined(ARM_FLOAT16_SUPPORTED)
  31. /**
  32. * @ingroup groupMatrix
  33. */
  34. /**
  35. * @addtogroup MatrixVectMult
  36. * @{
  37. */
  38. /**
  39. * @brief Floating-point matrix and vector multiplication.
  40. * @param[in] *pSrcMat points to the input matrix structure
  41. * @param[in] *pVec points to the input vector (numCols elements)
  42. * @param[out] *pDst points to the output vector (numRows elements)
  43. */
  44. #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
  45. #include "arm_helium_utils.h"
  46. void arm_mat_vec_mult_f16(
  47. const arm_matrix_instance_f16 *pSrcMat,
  48. const float16_t *pSrcVec,
  49. float16_t *pDstVec)
  50. {
  51. uint32_t numRows = pSrcMat->numRows;
  52. uint32_t numCols = pSrcMat->numCols;
  53. const float16_t *pSrcA = pSrcMat->pData;
  54. const float16_t *pInA0;
  55. const float16_t *pInA1;
  56. float16_t *px;
  57. int32_t row;
  58. uint32_t blkCnt; /* loop counters */
  59. row = numRows;
  60. px = pDstVec;
  61. /*
  62. * compute 4 rows in parallel
  63. */
  64. while (row >= 4)
  65. {
  66. const float16_t *pInA2, *pInA3;
  67. float16_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec, *pInVec;
  68. f16x8_t vecIn, acc0, acc1, acc2, acc3;
  69. float16_t const *pSrcVecPtr = pSrcVec;
  70. /*
  71. * Initialize the pointers to 4 consecutive MatrixA rows
  72. */
  73. pInA0 = pSrcA;
  74. pInA1 = pInA0 + numCols;
  75. pInA2 = pInA1 + numCols;
  76. pInA3 = pInA2 + numCols;
  77. /*
  78. * Initialize the vector pointer
  79. */
  80. pInVec = pSrcVecPtr;
  81. /*
  82. * reset accumulators
  83. */
  84. acc0 = vdupq_n_f16(0.0f);
  85. acc1 = vdupq_n_f16(0.0f);
  86. acc2 = vdupq_n_f16(0.0f);
  87. acc3 = vdupq_n_f16(0.0f);
  88. pSrcA0Vec = pInA0;
  89. pSrcA1Vec = pInA1;
  90. pSrcA2Vec = pInA2;
  91. pSrcA3Vec = pInA3;
  92. blkCnt = numCols >> 3;
  93. while (blkCnt > 0U)
  94. {
  95. f16x8_t vecA;
  96. vecIn = vld1q(pInVec);
  97. pInVec += 8;
  98. vecA = vld1q(pSrcA0Vec);
  99. pSrcA0Vec += 8;
  100. acc0 = vfmaq(acc0, vecIn, vecA);
  101. vecA = vld1q(pSrcA1Vec);
  102. pSrcA1Vec += 8;
  103. acc1 = vfmaq(acc1, vecIn, vecA);
  104. vecA = vld1q(pSrcA2Vec);
  105. pSrcA2Vec += 8;
  106. acc2 = vfmaq(acc2, vecIn, vecA);
  107. vecA = vld1q(pSrcA3Vec);
  108. pSrcA3Vec += 8;
  109. acc3 = vfmaq(acc3, vecIn, vecA);
  110. blkCnt--;
  111. }
  112. /*
  113. * tail
  114. * (will be merged thru tail predication)
  115. */
  116. blkCnt = numCols & 7;
  117. if (blkCnt > 0U)
  118. {
  119. mve_pred16_t p0 = vctp16q(blkCnt);
  120. f16x8_t vecA;
  121. vecIn = vldrhq_z_f16(pInVec, p0);
  122. vecA = vld1q(pSrcA0Vec);
  123. acc0 = vfmaq(acc0, vecIn, vecA);
  124. vecA = vld1q(pSrcA1Vec);
  125. acc1 = vfmaq(acc1, vecIn, vecA);
  126. vecA = vld1q(pSrcA2Vec);
  127. acc2 = vfmaq(acc2, vecIn, vecA);
  128. vecA = vld1q(pSrcA3Vec);
  129. acc3 = vfmaq(acc3, vecIn, vecA);
  130. }
  131. /*
  132. * Sum the partial parts
  133. */
  134. *px++ = vecAddAcrossF16Mve(acc0);
  135. *px++ = vecAddAcrossF16Mve(acc1);
  136. *px++ = vecAddAcrossF16Mve(acc2);
  137. *px++ = vecAddAcrossF16Mve(acc3);
  138. pSrcA += numCols * 4;
  139. /*
  140. * Decrement the row loop counter
  141. */
  142. row -= 4;
  143. }
  144. /*
  145. * compute 2 rows in parrallel
  146. */
  147. if (row >= 2)
  148. {
  149. float16_t const *pSrcA0Vec, *pSrcA1Vec, *pInVec;
  150. f16x8_t vecIn, acc0, acc1;
  151. float16_t const *pSrcVecPtr = pSrcVec;
  152. /*
  153. * Initialize the pointers to 2 consecutive MatrixA rows
  154. */
  155. pInA0 = pSrcA;
  156. pInA1 = pInA0 + numCols;
  157. /*
  158. * Initialize the vector pointer
  159. */
  160. pInVec = pSrcVecPtr;
  161. /*
  162. * reset accumulators
  163. */
  164. acc0 = vdupq_n_f16(0.0f);
  165. acc1 = vdupq_n_f16(0.0f);
  166. pSrcA0Vec = pInA0;
  167. pSrcA1Vec = pInA1;
  168. blkCnt = numCols >> 3;
  169. while (blkCnt > 0U)
  170. {
  171. f16x8_t vecA;
  172. vecIn = vld1q(pInVec);
  173. pInVec += 8;
  174. vecA = vld1q(pSrcA0Vec);
  175. pSrcA0Vec += 8;
  176. acc0 = vfmaq(acc0, vecIn, vecA);
  177. vecA = vld1q(pSrcA1Vec);
  178. pSrcA1Vec += 8;
  179. acc1 = vfmaq(acc1, vecIn, vecA);
  180. blkCnt--;
  181. }
  182. /*
  183. * tail
  184. * (will be merged thru tail predication)
  185. */
  186. blkCnt = numCols & 7;
  187. if (blkCnt > 0U)
  188. {
  189. mve_pred16_t p0 = vctp16q(blkCnt);
  190. f16x8_t vecA;
  191. vecIn = vldrhq_z_f16(pInVec, p0);
  192. vecA = vld1q(pSrcA0Vec);
  193. acc0 = vfmaq(acc0, vecIn, vecA);
  194. vecA = vld1q(pSrcA1Vec);
  195. acc1 = vfmaq(acc1, vecIn, vecA);
  196. }
  197. /*
  198. * Sum the partial parts
  199. */
  200. *px++ = vecAddAcrossF16Mve(acc0);
  201. *px++ = vecAddAcrossF16Mve(acc1);
  202. pSrcA += numCols * 2;
  203. row -= 2;
  204. }
  205. if (row >= 1)
  206. {
  207. f16x8_t vecIn, acc0;
  208. float16_t const *pSrcA0Vec, *pInVec;
  209. float16_t const *pSrcVecPtr = pSrcVec;
  210. /*
  211. * Initialize the pointers to last MatrixA row
  212. */
  213. pInA0 = pSrcA;
  214. /*
  215. * Initialize the vector pointer
  216. */
  217. pInVec = pSrcVecPtr;
  218. /*
  219. * reset accumulators
  220. */
  221. acc0 = vdupq_n_f16(0.0f);
  222. pSrcA0Vec = pInA0;
  223. blkCnt = numCols >> 3;
  224. while (blkCnt > 0U)
  225. {
  226. f16x8_t vecA;
  227. vecIn = vld1q(pInVec);
  228. pInVec += 8;
  229. vecA = vld1q(pSrcA0Vec);
  230. pSrcA0Vec += 8;
  231. acc0 = vfmaq(acc0, vecIn, vecA);
  232. blkCnt--;
  233. }
  234. /*
  235. * tail
  236. * (will be merged thru tail predication)
  237. */
  238. blkCnt = numCols & 7;
  239. if (blkCnt > 0U)
  240. {
  241. mve_pred16_t p0 = vctp16q(blkCnt);
  242. f16x8_t vecA;
  243. vecIn = vldrhq_z_f16(pInVec, p0);
  244. vecA = vld1q(pSrcA0Vec);
  245. acc0 = vfmaq(acc0, vecIn, vecA);
  246. }
  247. /*
  248. * Sum the partial parts
  249. */
  250. *px++ = vecAddAcrossF16Mve(acc0);
  251. }
  252. }
  253. #else
  254. void arm_mat_vec_mult_f16(const arm_matrix_instance_f16 *pSrcMat, const float16_t *pVec, float16_t *pDst)
  255. {
  256. uint32_t numRows = pSrcMat->numRows;
  257. uint32_t numCols = pSrcMat->numCols;
  258. const float16_t *pSrcA = pSrcMat->pData;
  259. const float16_t *pInA1; /* input data matrix pointer A of Q31 type */
  260. const float16_t *pInA2; /* input data matrix pointer A of Q31 type */
  261. const float16_t *pInA3; /* input data matrix pointer A of Q31 type */
  262. const float16_t *pInA4; /* input data matrix pointer A of Q31 type */
  263. const float16_t *pInVec; /* input data matrix pointer B of Q31 type */
  264. float16_t *px; /* Temporary output data matrix pointer */
  265. uint16_t i, row, colCnt; /* loop counters */
  266. float16_t matData, matData2, vecData, vecData2;
  267. /* Process 4 rows at a time */
  268. row = numRows >> 2;
  269. i = 0u;
  270. px = pDst;
  271. /* The following loop performs the dot-product of each row in pSrcA with the vector */
  272. /* row loop */
  273. while (row > 0) {
  274. /* For every row wise process, the pInVec pointer is set
  275. ** to the starting address of the vector */
  276. pInVec = pVec;
  277. /* Initialize accumulators */
  278. float16_t sum1 = 0.0f16;
  279. float16_t sum2 = 0.0f16;
  280. float16_t sum3 = 0.0f16;
  281. float16_t sum4 = 0.0f16;
  282. /* Loop unrolling: process 2 columns per iteration */
  283. colCnt = numCols;
  284. /* Initialize pointers to the starting address of the column being processed */
  285. pInA1 = pSrcA + i;
  286. pInA2 = pInA1 + numCols;
  287. pInA3 = pInA2 + numCols;
  288. pInA4 = pInA3 + numCols;
  289. // Main loop: matrix-vector multiplication
  290. while (colCnt > 0u) {
  291. // Read 2 values from vector
  292. vecData = *(pInVec)++;
  293. // Read 8 values from the matrix - 2 values from each of 4 rows, and do multiply accumulate
  294. matData = *(pInA1)++;
  295. sum1 += (_Float16)matData * (_Float16)vecData;
  296. matData = *(pInA2)++;
  297. sum2 += (_Float16)matData * (_Float16)vecData;
  298. matData = *(pInA3)++;
  299. sum3 += (_Float16)matData * (_Float16)vecData;
  300. matData = *(pInA4)++;
  301. sum4 += (_Float16)matData * (_Float16)vecData;
  302. // Decrement the loop counter
  303. colCnt--;
  304. }
  305. /* Saturate and store the result in the destination buffer */
  306. *px++ = sum1;
  307. *px++ = sum2;
  308. *px++ = sum3;
  309. *px++ = sum4;
  310. i = i + numCols * 4;
  311. /* Decrement the row loop counter */
  312. row--;
  313. }
  314. /* process any remaining rows */
  315. row = numRows & 3u;
  316. while (row > 0) {
  317. float16_t sum = 0.0f16;
  318. pInVec = pVec;
  319. pInA1 = pSrcA + i;
  320. colCnt = numCols >> 1;
  321. while (colCnt > 0) {
  322. vecData = *(pInVec)++;
  323. vecData2 = *(pInVec)++;
  324. matData = *(pInA1)++;
  325. matData2 = *(pInA1)++;
  326. sum += (_Float16)matData * (_Float16)vecData;
  327. sum += (_Float16)matData2 * (_Float16)vecData2;
  328. colCnt--;
  329. }
  330. // process remainder of row
  331. colCnt = numCols & 1u;
  332. while (colCnt > 0) {
  333. sum += (_Float16)*pInA1++ * (_Float16)*pInVec++;
  334. colCnt--;
  335. }
  336. *px++ = sum;
  337. i = i + numCols;
  338. row--;
  339. }
  340. }
  341. #endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
  342. /**
  343. * @} end of MatrixVectMult group
  344. */
  345. #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */