arm_mat_ldlt_f32.c 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. /* ----------------------------------------------------------------------
  2. * Project: CMSIS DSP Library
  3. * Title: arm_mat_ldlt_f32.c
  4. * Description: Floating-point LDL^t decomposition
  5. *
  6. * $Date: 23 April 2021
  7. * $Revision: V1.9.0
  8. *
  9. * Target Processor: Cortex-M and Cortex-A cores
  10. * -------------------------------------------------------------------- */
  11. /*
  12. * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
  13. *
  14. * SPDX-License-Identifier: Apache-2.0
  15. *
  16. * Licensed under the Apache License, Version 2.0 (the License); you may
  17. * not use this file except in compliance with the License.
  18. * You may obtain a copy of the License at
  19. *
  20. * www.apache.org/licenses/LICENSE-2.0
  21. *
  22. * Unless required by applicable law or agreed to in writing, software
  23. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  24. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  25. * See the License for the specific language governing permissions and
  26. * limitations under the License.
  27. */
  28. #include "dsp/matrix_functions.h"
  29. #include "dsp/matrix_utils.h"
  30. #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
  31. /**
  32. @ingroup groupMatrix
  33. */
  34. /**
  35. @addtogroup MatrixChol
  36. @{
  37. */
  38. /**
  39. * @brief Floating-point LDL^t decomposition of positive semi-definite matrix.
  40. * @param[in] pSrc points to the instance of the input floating-point matrix structure.
  41. * @param[out] pl points to the instance of the output floating-point triangular matrix structure.
  42. * @param[out] pd points to the instance of the output floating-point diagonal matrix structure.
  43. * @param[out] pp points to the instance of the output floating-point permutation vector.
  44. * @return The function returns ARM_MATH_SIZE_MISMATCH, if the dimensions do not match.
  45. * @return execution status
  46. - \ref ARM_MATH_SUCCESS : Operation successful
  47. - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
  48. - \ref ARM_MATH_DECOMPOSITION_FAILURE : Input matrix cannot be decomposed
  49. * @par
  50. * Computes the LDL^t decomposition of a matrix A such that P A P^t = L D L^t.
  51. */
  52. arm_status arm_mat_ldlt_f32(
  53. const arm_matrix_instance_f32 * pSrc,
  54. arm_matrix_instance_f32 * pl,
  55. arm_matrix_instance_f32 * pd,
  56. uint16_t * pp)
  57. {
  58. arm_status status; /* status of matrix inverse */
  59. #ifdef ARM_MATH_MATRIX_CHECK
  60. /* Check for matrix mismatch condition */
  61. if ((pSrc->numRows != pSrc->numCols) ||
  62. (pl->numRows != pl->numCols) ||
  63. (pd->numRows != pd->numCols) ||
  64. (pl->numRows != pd->numRows) )
  65. {
  66. /* Set status as ARM_MATH_SIZE_MISMATCH */
  67. status = ARM_MATH_SIZE_MISMATCH;
  68. }
  69. else
  70. #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
  71. {
  72. const int n=pSrc->numRows;
  73. int fullRank = 1, diag,k;
  74. float32_t *pA;
  75. memset(pd->pData,0,sizeof(float32_t)*n*n);
  76. memcpy(pl->pData,pSrc->pData,n*n*sizeof(float32_t));
  77. pA = pl->pData;
  78. int cnt = n;
  79. uint16x8_t vecP;
  80. for(int k=0;k < n; k+=8)
  81. {
  82. mve_pred16_t p0;
  83. p0 = vctp16q(cnt);
  84. vecP = vidupq_u16((uint16_t)k, 1);
  85. vstrhq_p(&pp[k], vecP, p0);
  86. cnt -= 8;
  87. }
  88. for(k=0;k < n; k++)
  89. {
  90. /* Find pivot */
  91. float32_t m=F32_MIN,a;
  92. int j=k;
  93. for(int r=k;r<n;r++)
  94. {
  95. if (pA[r*n+r] > m)
  96. {
  97. m = pA[r*n+r];
  98. j = r;
  99. }
  100. }
  101. if(j != k)
  102. {
  103. SWAP_ROWS_F32(pl,0,k,j);
  104. SWAP_COLS_F32(pl,0,k,j);
  105. }
  106. pp[k] = j;
  107. a = pA[k*n+k];
  108. if (fabsf(a) < 1.0e-8f)
  109. {
  110. fullRank = 0;
  111. break;
  112. }
  113. float32_t invA;
  114. invA = 1.0f / a;
  115. int32x4_t vecOffs;
  116. int w;
  117. vecOffs = vidupq_u32((uint32_t)0, 1);
  118. vecOffs = vmulq_n_s32(vecOffs,n);
  119. for(w=k+1; w<n; w+=4)
  120. {
  121. int cnt = n - k - 1;
  122. f32x4_t vecX;
  123. f32x4_t vecA;
  124. f32x4_t vecW0,vecW1, vecW2, vecW3;
  125. mve_pred16_t p0;
  126. vecW0 = vdupq_n_f32(pA[(w + 0)*n+k]);
  127. vecW1 = vdupq_n_f32(pA[(w + 1)*n+k]);
  128. vecW2 = vdupq_n_f32(pA[(w + 2)*n+k]);
  129. vecW3 = vdupq_n_f32(pA[(w + 3)*n+k]);
  130. for(int x=k+1;x<n;x += 4)
  131. {
  132. p0 = vctp32q(cnt);
  133. //pA[w*n+x] = pA[w*n+x] - pA[w*n+k] * (pA[x*n+k] * invA);
  134. vecX = vldrwq_gather_shifted_offset_z_f32(&pA[x*n+k], (uint32x4_t)vecOffs, p0);
  135. vecX = vmulq_m_n_f32(vuninitializedq_f32(),vecX,invA,p0);
  136. vecA = vldrwq_z_f32(&pA[(w + 0)*n+x],p0);
  137. vecA = vfmsq_m(vecA, vecW0, vecX, p0);
  138. vstrwq_p(&pA[(w + 0)*n+x], vecA, p0);
  139. vecA = vldrwq_z_f32(&pA[(w + 1)*n+x],p0);
  140. vecA = vfmsq_m(vecA, vecW1, vecX, p0);
  141. vstrwq_p(&pA[(w + 1)*n+x], vecA, p0);
  142. vecA = vldrwq_z_f32(&pA[(w + 2)*n+x],p0);
  143. vecA = vfmsq_m(vecA, vecW2, vecX, p0);
  144. vstrwq_p(&pA[(w + 2)*n+x], vecA, p0);
  145. vecA = vldrwq_z_f32(&pA[(w + 3)*n+x],p0);
  146. vecA = vfmsq_m(vecA, vecW3, vecX, p0);
  147. vstrwq_p(&pA[(w + 3)*n+x], vecA, p0);
  148. cnt -= 4;
  149. }
  150. }
  151. for(; w<n; w++)
  152. {
  153. int cnt = n - k - 1;
  154. f32x4_t vecA,vecX,vecW;
  155. mve_pred16_t p0;
  156. vecW = vdupq_n_f32(pA[w*n+k]);
  157. for(int x=k+1;x<n;x += 4)
  158. {
  159. p0 = vctp32q(cnt);
  160. //pA[w*n+x] = pA[w*n+x] - pA[w*n+k] * (pA[x*n+k] * invA);
  161. vecA = vldrwq_z_f32(&pA[w*n+x],p0);
  162. vecX = vldrwq_gather_shifted_offset_z_f32(&pA[x*n+k], (uint32x4_t)vecOffs, p0);
  163. vecX = vmulq_m_n_f32(vuninitializedq_f32(),vecX,invA,p0);
  164. vecA = vfmsq_m(vecA, vecW, vecX, p0);
  165. vstrwq_p(&pA[w*n+x], vecA, p0);
  166. cnt -= 4;
  167. }
  168. }
  169. for(int w=k+1;w<n;w++)
  170. {
  171. pA[w*n+k] = pA[w*n+k] * invA;
  172. }
  173. }
  174. diag=k;
  175. if (!fullRank)
  176. {
  177. diag--;
  178. for(int row=0; row < n;row++)
  179. {
  180. mve_pred16_t p0;
  181. int cnt= n-k;
  182. f32x4_t zero=vdupq_n_f32(0.0f);
  183. for(int col=k; col < n;col += 4)
  184. {
  185. p0 = vctp32q(cnt);
  186. vstrwq_p(&pl->pData[row*n+col], zero, p0);
  187. cnt -= 4;
  188. }
  189. }
  190. }
  191. for(int row=0; row < n;row++)
  192. {
  193. mve_pred16_t p0;
  194. int cnt= n-row-1;
  195. f32x4_t zero=vdupq_n_f32(0.0f);
  196. for(int col=row+1; col < n;col+=4)
  197. {
  198. p0 = vctp32q(cnt);
  199. vstrwq_p(&pl->pData[row*n+col], zero, p0);
  200. cnt -= 4;
  201. }
  202. }
  203. for(int d=0; d < diag;d++)
  204. {
  205. pd->pData[d*n+d] = pl->pData[d*n+d];
  206. pl->pData[d*n+d] = 1.0;
  207. }
  208. status = ARM_MATH_SUCCESS;
  209. }
  210. /* Return to application */
  211. return (status);
  212. }
  213. #else
  214. /**
  215. @ingroup groupMatrix
  216. */
  217. /**
  218. @addtogroup MatrixChol
  219. @{
  220. */
  221. /**
  222. * @brief Floating-point LDL^t decomposition of positive semi-definite matrix.
  223. * @param[in] pSrc points to the instance of the input floating-point matrix structure.
  224. * @param[out] pl points to the instance of the output floating-point triangular matrix structure.
  225. * @param[out] pd points to the instance of the output floating-point diagonal matrix structure.
  226. * @param[out] pp points to the instance of the output floating-point permutation vector.
  227. * @return The function returns ARM_MATH_SIZE_MISMATCH, if the dimensions do not match.
  228. * @return execution status
  229. - \ref ARM_MATH_SUCCESS : Operation successful
  230. - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
  231. - \ref ARM_MATH_DECOMPOSITION_FAILURE : Input matrix cannot be decomposed
  232. * @par
  233. * Computes the LDL^t decomposition of a matrix A such that P A P^t = L D L^t.
  234. */
  235. arm_status arm_mat_ldlt_f32(
  236. const arm_matrix_instance_f32 * pSrc,
  237. arm_matrix_instance_f32 * pl,
  238. arm_matrix_instance_f32 * pd,
  239. uint16_t * pp)
  240. {
  241. arm_status status; /* status of matrix inverse */
  242. #ifdef ARM_MATH_MATRIX_CHECK
  243. /* Check for matrix mismatch condition */
  244. if ((pSrc->numRows != pSrc->numCols) ||
  245. (pl->numRows != pl->numCols) ||
  246. (pd->numRows != pd->numCols) ||
  247. (pl->numRows != pd->numRows) )
  248. {
  249. /* Set status as ARM_MATH_SIZE_MISMATCH */
  250. status = ARM_MATH_SIZE_MISMATCH;
  251. }
  252. else
  253. #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
  254. {
  255. const int n=pSrc->numRows;
  256. int fullRank = 1, diag,k;
  257. float32_t *pA;
  258. int row,d;
  259. memset(pd->pData,0,sizeof(float32_t)*n*n);
  260. memcpy(pl->pData,pSrc->pData,n*n*sizeof(float32_t));
  261. pA = pl->pData;
  262. for(k=0;k < n; k++)
  263. {
  264. pp[k] = k;
  265. }
  266. for(k=0;k < n; k++)
  267. {
  268. /* Find pivot */
  269. float32_t m=F32_MIN,a;
  270. int j=k;
  271. int r;
  272. for(r=k;r<n;r++)
  273. {
  274. if (pA[r*n+r] > m)
  275. {
  276. m = pA[r*n+r];
  277. j = r;
  278. }
  279. }
  280. if(j != k)
  281. {
  282. SWAP_ROWS_F32(pl,0,k,j);
  283. SWAP_COLS_F32(pl,0,k,j);
  284. }
  285. pp[k] = j;
  286. a = pA[k*n+k];
  287. if (fabsf(a) < 1.0e-8f)
  288. {
  289. fullRank = 0;
  290. break;
  291. }
  292. for(int w=k+1;w<n;w++)
  293. {
  294. int x;
  295. for(x=k+1;x<n;x++)
  296. {
  297. pA[w*n+x] = pA[w*n+x] - pA[w*n+k] * pA[x*n+k] / a;
  298. }
  299. }
  300. for(int w=k+1;w<n;w++)
  301. {
  302. pA[w*n+k] = pA[w*n+k] / a;
  303. }
  304. }
  305. diag=k;
  306. if (!fullRank)
  307. {
  308. diag--;
  309. for(row=0; row < n;row++)
  310. {
  311. int col;
  312. for(col=k; col < n;col++)
  313. {
  314. pl->pData[row*n+col]=0.0;
  315. }
  316. }
  317. }
  318. for(row=0; row < n;row++)
  319. {
  320. int col;
  321. for(col=row+1; col < n;col++)
  322. {
  323. pl->pData[row*n+col] = 0.0;
  324. }
  325. }
  326. for(d=0; d < diag;d++)
  327. {
  328. pd->pData[d*n+d] = pl->pData[d*n+d];
  329. pl->pData[d*n+d] = 1.0;
  330. }
  331. status = ARM_MATH_SUCCESS;
  332. }
  333. /* Return to application */
  334. return (status);
  335. }
  336. #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
  337. /**
  338. @} end of MatrixChol group
  339. */