arm_mat_solve_lower_triangular_f32.c 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. /* ----------------------------------------------------------------------
  2. * Project: CMSIS DSP Library
  3. * Title: arm_mat_solve_lower_triangular_f32.c
  4. * Description: Solve linear system LT X = A with LT lower triangular matrix
  5. *
  6. * $Date: 23 April 2021
  7. * $Revision: V1.9.0
  8. *
  9. * Target Processor: Cortex-M and Cortex-A cores
  10. * -------------------------------------------------------------------- */
  11. /*
  12. * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
  13. *
  14. * SPDX-License-Identifier: Apache-2.0
  15. *
  16. * Licensed under the Apache License, Version 2.0 (the License); you may
  17. * not use this file except in compliance with the License.
  18. * You may obtain a copy of the License at
  19. *
  20. * www.apache.org/licenses/LICENSE-2.0
  21. *
  22. * Unless required by applicable law or agreed to in writing, software
  23. * distributed under the License is distributed on an AS IS BASIS, WITHOUT
  24. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  25. * See the License for the specific language governing permissions and
  26. * limitations under the License.
  27. */
  28. #include "dsp/matrix_functions.h"
  29. /**
  30. @ingroup groupMatrix
  31. */
  32. /**
  33. @addtogroup MatrixInv
  34. @{
  35. */
  36. /**
  37. * @brief Solve LT . X = A where LT is a lower triangular matrix
  38. * @param[in] lt The lower triangular matrix
  39. * @param[in] a The matrix a
  40. * @param[out] dst The solution X of LT . X = A
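  41. * @return ARM_MATH_SUCCESS if the system was solved, ARM_MATH_SINGULAR if a diagonal element of LT is zero, or ARM_MATH_SIZE_MISMATCH (only when ARM_MATH_MATRIX_CHECK is defined) if the matrix dimensions are inconsistent.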
  42. */
  43. #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
  44. #include "arm_helium_utils.h"
  45. arm_status arm_mat_solve_lower_triangular_f32(
  46. const arm_matrix_instance_f32 * lt,
  47. const arm_matrix_instance_f32 * a,
  48. arm_matrix_instance_f32 * dst)
  49. {
  50. arm_status status; /* status of matrix inverse */
  51. #ifdef ARM_MATH_MATRIX_CHECK
  52. /* Check for matrix mismatch condition */
  53. if ((lt->numRows != lt->numCols) ||
  54. (lt->numRows != a->numRows) )
  55. {
  56. /* Set status as ARM_MATH_SIZE_MISMATCH */
  57. status = ARM_MATH_SIZE_MISMATCH;
  58. }
  59. else
  60. #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
  61. {
  62. /* a1 b1 c1 x1 = a1
  63. b2 c2 x2 a2
  64. c3 x3 a3
  65. x3 = a3 / c3
  66. x2 = (a2 - c2 x3) / b2
  67. */
  68. int i,j,k,n,cols;
  69. n = dst->numRows;
  70. cols = dst->numCols;
  71. float32_t *pX = dst->pData;
  72. float32_t *pLT = lt->pData;
  73. float32_t *pA = a->pData;
  74. float32_t *lt_row;
  75. float32_t *a_col;
  76. float32_t invLT;
  77. f32x4_t vecA;
  78. f32x4_t vecX;
  79. for(i=0; i < n ; i++)
  80. {
  81. for(j=0; j+3 < cols; j += 4)
  82. {
  83. vecA = vld1q_f32(&pA[i * cols + j]);
  84. for(k=0; k < i; k++)
  85. {
  86. vecX = vld1q_f32(&pX[cols*k+j]);
  87. vecA = vfmsq(vecA,vdupq_n_f32(pLT[n*i + k]),vecX);
  88. }
  89. if (pLT[n*i + i]==0.0f)
  90. {
  91. return(ARM_MATH_SINGULAR);
  92. }
  93. invLT = 1.0f / pLT[n*i + i];
  94. vecA = vmulq(vecA,vdupq_n_f32(invLT));
  95. vst1q(&pX[i*cols+j],vecA);
  96. }
  97. for(; j < cols; j ++)
  98. {
  99. a_col = &pA[j];
  100. lt_row = &pLT[n*i];
  101. float32_t tmp=a_col[i * cols];
  102. for(k=0; k < i; k++)
  103. {
  104. tmp -= lt_row[k] * pX[cols*k+j];
  105. }
  106. if (lt_row[i]==0.0f)
  107. {
  108. return(ARM_MATH_SINGULAR);
  109. }
  110. tmp = tmp / lt_row[i];
  111. pX[i*cols+j] = tmp;
  112. }
  113. }
  114. status = ARM_MATH_SUCCESS;
  115. }
  116. /* Return to application */
  117. return (status);
  118. }
  119. #else
  120. #if defined(ARM_MATH_NEON) && !defined(ARM_MATH_AUTOVECTORIZE)
  121. arm_status arm_mat_solve_lower_triangular_f32(
  122. const arm_matrix_instance_f32 * lt,
  123. const arm_matrix_instance_f32 * a,
  124. arm_matrix_instance_f32 * dst)
  125. {
  126. arm_status status; /* status of matrix inverse */
  127. #ifdef ARM_MATH_MATRIX_CHECK
  128. /* Check for matrix mismatch condition */
  129. if ((lt->numRows != lt->numCols) ||
  130. (lt->numRows != a->numRows) )
  131. {
  132. /* Set status as ARM_MATH_SIZE_MISMATCH */
  133. status = ARM_MATH_SIZE_MISMATCH;
  134. }
  135. else
  136. #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
  137. {
  138. /* a1 b1 c1 x1 = a1
  139. b2 c2 x2 a2
  140. c3 x3 a3
  141. x3 = a3 / c3
  142. x2 = (a2 - c2 x3) / b2
  143. */
  144. int i,j,k,n,cols;
  145. n = dst->numRows;
  146. cols = dst->numCols;
  147. float32_t *pX = dst->pData;
  148. float32_t *pLT = lt->pData;
  149. float32_t *pA = a->pData;
  150. float32_t *lt_row;
  151. float32_t *a_col;
  152. float32_t invLT;
  153. f32x4_t vecA;
  154. f32x4_t vecX;
  155. for(i=0; i < n ; i++)
  156. {
  157. for(j=0; j+3 < cols; j += 4)
  158. {
  159. vecA = vld1q_f32(&pA[i * cols + j]);
  160. for(k=0; k < i; k++)
  161. {
  162. vecX = vld1q_f32(&pX[cols*k+j]);
  163. vecA = vfmsq_f32(vecA,vdupq_n_f32(pLT[n*i + k]),vecX);
  164. }
  165. if (pLT[n*i + i]==0.0f)
  166. {
  167. return(ARM_MATH_SINGULAR);
  168. }
  169. invLT = 1.0f / pLT[n*i + i];
  170. vecA = vmulq_f32(vecA,vdupq_n_f32(invLT));
  171. vst1q_f32(&pX[i*cols+j],vecA);
  172. }
  173. for(; j < cols; j ++)
  174. {
  175. a_col = &pA[j];
  176. lt_row = &pLT[n*i];
  177. float32_t tmp=a_col[i * cols];
  178. for(k=0; k < i; k++)
  179. {
  180. tmp -= lt_row[k] * pX[cols*k+j];
  181. }
  182. if (lt_row[i]==0.0f)
  183. {
  184. return(ARM_MATH_SINGULAR);
  185. }
  186. tmp = tmp / lt_row[i];
  187. pX[i*cols+j] = tmp;
  188. }
  189. }
  190. status = ARM_MATH_SUCCESS;
  191. }
  192. /* Return to application */
  193. return (status);
  194. }
  195. #else
  196. arm_status arm_mat_solve_lower_triangular_f32(
  197. const arm_matrix_instance_f32 * lt,
  198. const arm_matrix_instance_f32 * a,
  199. arm_matrix_instance_f32 * dst)
  200. {
  201. arm_status status; /* status of matrix inverse */
  202. #ifdef ARM_MATH_MATRIX_CHECK
  203. /* Check for matrix mismatch condition */
  204. if ((lt->numRows != lt->numCols) ||
  205. (lt->numRows != a->numRows) )
  206. {
  207. /* Set status as ARM_MATH_SIZE_MISMATCH */
  208. status = ARM_MATH_SIZE_MISMATCH;
  209. }
  210. else
  211. #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
  212. {
  213. /* a1 b1 c1 x1 = a1
  214. b2 c2 x2 a2
  215. c3 x3 a3
  216. x3 = a3 / c3
  217. x2 = (a2 - c2 x3) / b2
  218. */
  219. int i,j,k,n,cols;
  220. float32_t *pX = dst->pData;
  221. float32_t *pLT = lt->pData;
  222. float32_t *pA = a->pData;
  223. float32_t *lt_row;
  224. float32_t *a_col;
  225. n = dst->numRows;
  226. cols = dst -> numCols;
  227. for(j=0; j < cols; j ++)
  228. {
  229. a_col = &pA[j];
  230. for(i=0; i < n ; i++)
  231. {
  232. float32_t tmp=a_col[i * cols];
  233. lt_row = &pLT[n*i];
  234. for(k=0; k < i; k++)
  235. {
  236. tmp -= lt_row[k] * pX[cols*k+j];
  237. }
  238. if (lt_row[i]==0.0f)
  239. {
  240. return(ARM_MATH_SINGULAR);
  241. }
  242. tmp = tmp / lt_row[i];
  243. pX[i*cols+j] = tmp;
  244. }
  245. }
  246. status = ARM_MATH_SUCCESS;
  247. }
  248. /* Return to application */
  249. return (status);
  250. }
  251. #endif /* #if defined(ARM_MATH_NEON) */
  252. #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
  253. /**
  254. @} end of MatrixInv group
  255. */