matrix_utils.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. /******************************************************************************
  2. * @file matrix_utils.h
  3. * @brief Public header file for NMSIS DSP Library
  4. * @version V1.11.0
  5. * @date 30 May 2022
  6. * Target Processor: RISC-V cores
  7. ******************************************************************************/
  8. /*
  9. * Copyright (c) 2010-2022 Arm Limited or its affiliates. All rights reserved.
  10. * Copyright (c) 2019 Nuclei Limited. All rights reserved.
  11. *
  12. * SPDX-License-Identifier: Apache-2.0
  13. *
  14. * Licensed under the Apache License, Version 2.0 (the License); you may
  15. * not use this file except in compliance with the License.
  16. * You may obtain a copy of the License at
  17. *
  18. * www.apache.org/licenses/LICENSE-2.0
  19. *
  20. * Unless required by applicable law or agreed to in writing, software
  21. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  22. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  23. * See the License for the specific language governing permissions and
  24. * limitations under the License.
  25. */
  26. #ifndef MATRIX_UTILS_H_
  27. #define MATRIX_UTILS_H_
  28. #include "riscv_math_types.h"
  29. #include "riscv_math_memory.h"
  30. #include "dsp/none.h"
  31. #include "dsp/utils.h"
  32. #ifdef __cplusplus
  33. extern "C"
  34. {
  35. #endif
  36. #define ELEM(A,ROW,COL) &((A)->pData[(A)->numCols* (ROW) + (COL)])
  37. #define SCALE_COL_T(T,CAST,A,ROW,v,i) \
  38. { \
  39. int32_t _w; \
  40. T *data = (A)->pData; \
  41. const int32_t _numCols = (A)->numCols; \
  42. const int32_t nb = (A)->numRows - ROW;\
  43. \
  44. data += i + _numCols * (ROW); \
  45. \
  46. for(_w=0;_w < nb; _w++) \
  47. { \
  48. *data *= CAST v; \
  49. data += _numCols; \
  50. } \
  51. }
  52. #define COPY_COL_T(T,A,ROW,COL,DST) \
  53. { \
  54. uint32_t _row; \
  55. T *_pb=DST; \
  56. T *_pa = (A)->pData + ROW * (A)->numCols + COL;\
  57. for(_row = ROW; _row < (A)->numRows; _row ++) \
  58. { \
  59. *_pb++ = *_pa; \
  60. _pa += (A)->numCols; \
  61. } \
  62. }
  63. #if defined(RISCV_FLOAT16_SUPPORTED)
  64. #define SWAP_ROWS_F16(A,COL,i,j) \
  65. { \
  66. int32_t _w; \
  67. float16_t *dataI = (A)->pData; \
  68. float16_t *dataJ = (A)->pData; \
  69. const int32_t _numCols = (A)->numCols;\
  70. const int32_t nb = _numCols-(COL); \
  71. \
  72. dataI += i*_numCols + (COL); \
  73. dataJ += j*_numCols + (COL); \
  74. \
  75. for(_w=0;_w < nb; _w++) \
  76. { \
  77. float16_t tmp; \
  78. tmp = *dataI; \
  79. *dataI++ = *dataJ; \
  80. *dataJ++ = tmp; \
  81. } \
  82. }
  83. #define SCALE_ROW_F16(A,COL,v,i) \
  84. { \
  85. int32_t _w; \
  86. float16_t *data = (A)->pData; \
  87. const int32_t _numCols = (A)->numCols;\
  88. const int32_t nb = _numCols-(COL); \
  89. \
  90. data += i*_numCols + (COL); \
  91. \
  92. for(_w=0;_w < nb; _w++) \
  93. { \
  94. *data++ *= (_Float16)v; \
  95. } \
  96. }
  97. #define MAC_ROW_F16(COL,A,i,v,B,j) \
  98. { \
  99. int32_t _w; \
  100. float16_t *dataA = (A)->pData; \
  101. float16_t *dataB = (B)->pData; \
  102. const int32_t _numCols = (A)->numCols; \
  103. const int32_t nb = _numCols-(COL); \
  104. \
  105. dataA += i*_numCols + (COL); \
  106. dataB += j*_numCols + (COL); \
  107. \
  108. for(_w=0;_w < nb; _w++) \
  109. { \
  110. *dataA++ += (_Float16)v * (_Float16)*dataB++;\
  111. } \
  112. }
  113. #define MAS_ROW_F16(COL,A,i,v,B,j) \
  114. { \
  115. int32_t _w; \
  116. float16_t *dataA = (A)->pData; \
  117. float16_t *dataB = (B)->pData; \
  118. const int32_t _numCols = (A)->numCols; \
  119. const int32_t nb = _numCols-(COL); \
  120. \
  121. dataA += i*_numCols + (COL); \
  122. dataB += j*_numCols + (COL); \
  123. \
  124. for(_w=0;_w < nb; _w++) \
  125. { \
  126. *dataA++ -= (_Float16)v * (_Float16)*dataB++;\
  127. } \
  128. }
  129. /* Functions with only a scalar version */
  130. #define COPY_COL_F16(A,ROW,COL,DST) \
  131. COPY_COL_T(float16_t,A,ROW,COL,DST)
  132. #define SCALE_COL_F16(A,ROW,v,i) \
  133. SCALE_COL_T(float16_t,(_Float16),A,ROW,v,i)
  134. #endif /* defined(RISCV_FLOAT16_SUPPORTED)*/
  135. #define SWAP_ROWS_F32(A,COL,i,j) \
  136. { \
  137. int32_t _w; \
  138. float32_t tmp; \
  139. float32_t *dataI = (A)->pData; \
  140. float32_t *dataJ = (A)->pData; \
  141. const int32_t _numCols = (A)->numCols;\
  142. const int32_t nb = _numCols - COL; \
  143. \
  144. dataI += i*_numCols + (COL); \
  145. dataJ += j*_numCols + (COL); \
  146. \
  147. \
  148. for(_w=0;_w < nb; _w++) \
  149. { \
  150. tmp = *dataI; \
  151. *dataI++ = *dataJ; \
  152. *dataJ++ = tmp; \
  153. } \
  154. }
  155. #define SCALE_ROW_F32(A,COL,v,i) \
  156. { \
  157. int32_t _w; \
  158. float32_t *data = (A)->pData; \
  159. const int32_t _numCols = (A)->numCols;\
  160. const int32_t nb = _numCols - COL; \
  161. \
  162. data += i*_numCols + (COL); \
  163. \
  164. for(_w=0;_w < nb; _w++) \
  165. { \
  166. *data++ *= v; \
  167. } \
  168. }
  169. #define MAC_ROW_F32(COL,A,i,v,B,j) \
  170. { \
  171. int32_t _w; \
  172. float32_t *dataA = (A)->pData; \
  173. float32_t *dataB = (B)->pData; \
  174. const int32_t _numCols = (A)->numCols;\
  175. const int32_t nb = _numCols-(COL); \
  176. \
  177. dataA = dataA + i*_numCols + (COL); \
  178. dataB = dataB + j*_numCols + (COL); \
  179. \
  180. for(_w=0;_w < nb; _w++) \
  181. { \
  182. *dataA++ += v* *dataB++; \
  183. } \
  184. }
  185. #define MAS_ROW_F32(COL,A,i,v,B,j) \
  186. { \
  187. int32_t _w; \
  188. float32_t *dataA = (A)->pData; \
  189. float32_t *dataB = (B)->pData; \
  190. const int32_t _numCols = (A)->numCols;\
  191. const int32_t nb = _numCols-(COL); \
  192. \
  193. dataA = dataA + i*_numCols + (COL); \
  194. dataB = dataB + j*_numCols + (COL); \
  195. \
  196. for(_w=0;_w < nb; _w++) \
  197. { \
  198. *dataA++ -= v* *dataB++; \
  199. } \
  200. }
/* Functions with only a scalar version */
  202. #define COPY_COL_F32(A,ROW,COL,DST) \
  203. COPY_COL_T(float32_t,A,ROW,COL,DST)
  204. #define COPY_COL_F64(A,ROW,COL,DST) \
  205. COPY_COL_T(float64_t,A,ROW,COL,DST)
  206. #define SWAP_COLS_F32(A,COL,i,j) \
  207. { \
  208. int32_t _w; \
  209. float32_t *data = (A)->pData; \
  210. const int32_t _numCols = (A)->numCols; \
  211. for(_w=(COL);_w < _numCols; _w++) \
  212. { \
  213. float32_t tmp; \
  214. tmp = data[_w*_numCols + i]; \
  215. data[_w*_numCols + i] = data[_w*_numCols + j];\
  216. data[_w*_numCols + j] = tmp; \
  217. } \
  218. }
  219. #define SCALE_COL_F32(A,ROW,v,i) \
  220. SCALE_COL_T(float32_t,,A,ROW,v,i)
  221. #define SWAP_ROWS_F64(A,COL,i,j) \
  222. { \
  223. int32_t _w; \
  224. float64_t *dataI = (A)->pData; \
  225. float64_t *dataJ = (A)->pData; \
  226. const int32_t _numCols = (A)->numCols;\
  227. const int32_t nb = _numCols-(COL); \
  228. \
  229. dataI += i*_numCols + (COL); \
  230. dataJ += j*_numCols + (COL); \
  231. \
  232. for(_w=0;_w < nb; _w++) \
  233. { \
  234. float64_t tmp; \
  235. tmp = *dataI; \
  236. *dataI++ = *dataJ; \
  237. *dataJ++ = tmp; \
  238. } \
  239. }
  240. #define SWAP_COLS_F64(A,COL,i,j) \
  241. { \
  242. int32_t _w; \
  243. float64_t *data = (A)->pData; \
  244. const int32_t _numCols = (A)->numCols; \
  245. for(_w=(COL);_w < _numCols; _w++) \
  246. { \
  247. float64_t tmp; \
  248. tmp = data[_w*_numCols + i]; \
  249. data[_w*_numCols + i] = data[_w*_numCols + j];\
  250. data[_w*_numCols + j] = tmp; \
  251. } \
  252. }
  253. #define SCALE_ROW_F64(A,COL,v,i) \
  254. { \
  255. int32_t _w; \
  256. float64_t *data = (A)->pData; \
  257. const int32_t _numCols = (A)->numCols;\
  258. const int32_t nb = _numCols-(COL); \
  259. \
  260. data += i*_numCols + (COL); \
  261. \
  262. for(_w=0;_w < nb; _w++) \
  263. { \
  264. *data++ *= v; \
  265. } \
  266. }
  267. #define SCALE_COL_F64(A,ROW,v,i) \
  268. SCALE_COL_T(float64_t,,A,ROW,v,i)
  269. #define MAC_ROW_F64(COL,A,i,v,B,j) \
  270. { \
  271. int32_t _w; \
  272. float64_t *dataA = (A)->pData; \
  273. float64_t *dataB = (B)->pData; \
  274. const int32_t _numCols = (A)->numCols;\
  275. const int32_t nb = _numCols-(COL); \
  276. \
  277. dataA += i*_numCols + (COL); \
  278. dataB += j*_numCols + (COL); \
  279. \
  280. for(_w=0;_w < nb; _w++) \
  281. { \
  282. *dataA++ += v* *dataB++; \
  283. } \
  284. }
  285. #define MAS_ROW_F64(COL,A,i,v,B,j) \
  286. { \
  287. int32_t _w; \
  288. float64_t *dataA = (A)->pData; \
  289. float64_t *dataB = (B)->pData; \
  290. const int32_t _numCols = (A)->numCols;\
  291. const int32_t nb = _numCols-(COL); \
  292. \
  293. dataA += i*_numCols + (COL); \
  294. dataB += j*_numCols + (COL); \
  295. \
  296. for(_w=0;_w < nb; _w++) \
  297. { \
  298. *dataA++ -= v* *dataB++; \
  299. } \
  300. }
  301. #ifdef __cplusplus
  302. }
  303. #endif
#endif /* ifndef MATRIX_UTILS_H_ */