ソースを参照

CMSIS-DSP: Improvements to matrix inversion.

Partial pivoting added for better numerical stability.
Christophe Favergeon 3 年 前
コミット
cb0960577d

+ 615 - 0
Include/dsp/matrix_utils.h

@@ -0,0 +1,615 @@
+/******************************************************************************
+ * @file     matrix_utils.h
+ * @brief    Public header file for CMSIS DSP Library
+ * @version  V1.11.0
+ * @date     30 May 2022
+ * Target Processor: Cortex-M and Cortex-A cores
+ ******************************************************************************/
+/*
+ * Copyright (c) 2010-2022 Arm Limited or its affiliates. All rights reserved.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the License); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ 
+#ifndef _MATRIX_UTILS_H_
+#define _MATRIX_UTILS_H_
+
+#include "arm_math_types.h"
+#include "arm_math_memory.h"
+
+#include "dsp/none.h"
+#include "dsp/utils.h"
+
+#ifdef   __cplusplus
+extern "C"
+{
+#endif
+
+#define ELEM(A,ROW,COL) &((A)->pData[(A)->numCols* (ROW) + (COL)])
+
+#define SCALE_COL_T(T,CAST,A,ROW,v,i)        \
+{                                       \
+  int32_t w;                            \
+  T *data = (A)->pData;                 \
+  const int32_t numCols = (A)->numCols; \
+  const int32_t nb = (A)->numRows - ROW;\
+                                        \
+  data += i + numCols * (ROW);          \
+                                        \
+  for(w=0;w < nb; w++)                  \
+  {                                     \
+     *data *= CAST v;                   \
+     data += numCols;                   \
+  }                                     \
+}
+
+#if defined(ARM_FLOAT16_SUPPORTED)
+#if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
+
+#define SWAP_ROWS_F16(A,COL,i,j)                  \
+  {                                               \
+    int cnt = ((A)->numCols)-(COL);               \
+    int32_t w;                                   \
+    float16_t *data = (A)->pData;                 \
+    const int32_t numCols = (A)->numCols;        \
+                                                  \
+    for(w=(COL);w < numCols; w+=8)                \
+    {                                             \
+       f16x8_t tmpa,tmpb;                         \
+       mve_pred16_t p0 = vctp16q(cnt);            \
+                                                  \
+       tmpa=vldrhq_z_f16(&data[i*numCols + w],p0);\
+       tmpb=vldrhq_z_f16(&data[j*numCols + w],p0);\
+                                                  \
+       vstrhq_p(&data[i*numCols + w], tmpb, p0);  \
+       vstrhq_p(&data[j*numCols + w], tmpa, p0);  \
+                                                  \
+       cnt -= 8;                                  \
+    }                                             \
+  }
+
+#define SCALE_ROW_F16(A,COL,v,i)                   \
+{                                                   \
+  int cnt = ((A)->numCols)-(COL);                   \
+  int32_t w;                                       \
+  float16_t *data = (A)->pData;                     \
+  const int32_t numCols = (A)->numCols;            \
+                                                    \
+  for(w=(COL);w < numCols; w+=8)                    \
+  {                                                 \
+       f16x8_t tmpa;                                \
+       mve_pred16_t p0 = vctp16q(cnt);              \
+       tmpa = vldrhq_z_f16(&data[i*numCols + w],p0);\
+       tmpa = vmulq_n_f16(tmpa,(_Float16)v);                  \
+       vstrhq_p(&data[i*numCols + w], tmpa, p0);    \
+       cnt -= 8;                                    \
+  }                                                 \
+                                                    \
+}
+
+#define MAC_ROW_F16(COL,A,i,v,B,j)                   \
+{                                                    \
+  int cnt = ((A)->numCols)-(COL);                    \
+  int32_t w;                                        \
+  float16_t *dataA = (A)->pData;                     \
+  float16_t *dataB = (B)->pData;                     \
+  const int32_t numCols = (A)->numCols;             \
+                                                     \
+  for(w=(COL);w < numCols; w+=8)                     \
+  {                                                  \
+       f16x8_t tmpa,tmpb;                            \
+       mve_pred16_t p0 = vctp16q(cnt);               \
+       tmpa = vldrhq_z_f16(&dataA[i*numCols + w],p0);\
+       tmpb = vldrhq_z_f16(&dataB[j*numCols + w],p0);\
+       tmpa = vfmaq_n_f16(tmpa,tmpb,v);              \
+       vstrhq_p(&dataA[i*numCols + w], tmpa, p0);    \
+       cnt -= 8;                                     \
+  }                                                  \
+                                                     \
+}
+
+#define MAS_ROW_F16(COL,A,i,v,B,j)                   \
+{                                                    \
+  int cnt = ((A)->numCols)-(COL);                    \
+  int32_t w;                                        \
+  float16_t *dataA = (A)->pData;                     \
+  float16_t *dataB = (B)->pData;                     \
+  const int32_t numCols = (A)->numCols;             \
+  f16x8_t vec=vdupq_n_f16(v);                        \
+                                                     \
+  for(w=(COL);w < numCols; w+=8)                     \
+  {                                                  \
+       f16x8_t tmpa,tmpb;                            \
+       mve_pred16_t p0 = vctp16q(cnt);               \
+       tmpa = vldrhq_z_f16(&dataA[i*numCols + w],p0);\
+       tmpb = vldrhq_z_f16(&dataB[j*numCols + w],p0);\
+       tmpa = vfmsq_f16(tmpa,tmpb,vec);              \
+       vstrhq_p(&dataA[i*numCols + w], tmpa, p0);    \
+       cnt -= 8;                                     \
+  }                                                  \
+                                                     \
+}
+
+#else
+
+#define SWAP_ROWS_F16(A,COL,i,j)       \
+{                                      \
+  int32_t w;                           \
+  float16_t *dataI = (A)->pData;       \
+  float16_t *dataJ = (A)->pData;       \
+  const int32_t numCols = (A)->numCols;\
+  const int32_t nb = numCols-(COL);    \
+                                       \
+  dataI += i*numCols + (COL);          \
+  dataJ += j*numCols + (COL);          \
+                                       \
+  for(w=0;w < nb; w++)                 \
+  {                                    \
+     float16_t tmp;                    \
+     tmp = *dataI;                     \
+     *dataI++ = *dataJ;                \
+     *dataJ++ = tmp;                   \
+  }                                    \
+}
+
+#define SCALE_ROW_F16(A,COL,v,i)       \
+{                                      \
+  int32_t w;                           \
+  float16_t *data = (A)->pData;        \
+  const int32_t numCols = (A)->numCols;\
+  const int32_t nb = numCols-(COL);    \
+                                       \
+  data += i*numCols + (COL);           \
+                                       \
+  for(w=0;w < nb; w++)                 \
+  {                                    \
+     *data++ *= (_Float16)v;           \
+  }                                    \
+}
+
+
+#define MAC_ROW_F16(COL,A,i,v,B,j)                \
+{                                                 \
+  int32_t w;                                      \
+  float16_t *dataA = (A)->pData;                  \
+  float16_t *dataB = (B)->pData;                  \
+  const int32_t numCols = (A)->numCols;           \
+  const int32_t nb = numCols-(COL);               \
+                                                  \
+  dataA += i*numCols + (COL);                     \
+  dataB += j*numCols + (COL);                     \
+                                                  \
+  for(w=0;w < nb; w++)                            \
+  {                                               \
+     *dataA++ += (_Float16)v * (_Float16)*dataB++;\
+  }                                               \
+}
+
+#define MAS_ROW_F16(COL,A,i,v,B,j)                \
+{                                                 \
+  int32_t w;                                      \
+  float16_t *dataA = (A)->pData;                  \
+  float16_t *dataB = (B)->pData;                  \
+  const int32_t numCols = (A)->numCols;           \
+  const int32_t nb = numCols-(COL);               \
+                                                  \
+  dataA += i*numCols + (COL);                     \
+  dataB += j*numCols + (COL);                     \
+                                                  \
+  for(w=0;w < nb; w++)                            \
+  {                                               \
+     *dataA++ -= (_Float16)v * (_Float16)*dataB++;\
+  }                                               \
+}
+
+#endif /*defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)*/
+
+
+#define SCALE_COL_F16(A,ROW,v,i)        \
+  SCALE_COL_T(float16_t,(_Float16),A,ROW,v,i)
+  
+#endif /* defined(ARM_FLOAT16_SUPPORTED)*/
+
+#if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
+
+#define SWAP_ROWS_F32(A,COL,i,j)                  \
+  {                                               \
+    int cnt = ((A)->numCols)-(COL);               \
+    float32_t *data = (A)->pData;                 \
+    const int32_t numCols = (A)->numCols;        \
+    int32_t w;                                   \
+                                                  \
+    for(w=(COL);w < numCols; w+=4)                \
+    {                                             \
+       f32x4_t tmpa,tmpb;                         \
+       mve_pred16_t p0 = vctp32q(cnt);            \
+                                                  \
+       tmpa=vldrwq_z_f32(&data[i*numCols + w],p0);\
+       tmpb=vldrwq_z_f32(&data[j*numCols + w],p0);\
+                                                  \
+       vstrwq_p(&data[i*numCols + w], tmpb, p0);  \
+       vstrwq_p(&data[j*numCols + w], tmpa, p0);  \
+                                                  \
+       cnt -= 4;                                  \
+    }                                             \
+  }
+
+#define MAC_ROW_F32(COL,A,i,v,B,j)                   \
+{                                                    \
+  int cnt = ((A)->numCols)-(COL);                    \
+  float32_t *dataA = (A)->pData;                     \
+  float32_t *dataB = (B)->pData;                     \
+  const int32_t numCols = (A)->numCols;             \
+  int32_t w;                                        \
+                                                     \
+  for(w=(COL);w < numCols; w+=4)                     \
+  {                                                  \
+       f32x4_t tmpa,tmpb;                            \
+       mve_pred16_t p0 = vctp32q(cnt);               \
+       tmpa = vldrwq_z_f32(&dataA[i*numCols + w],p0);\
+       tmpb = vldrwq_z_f32(&dataB[j*numCols + w],p0);\
+       tmpa = vfmaq_n_f32(tmpa,tmpb,v);              \
+       vstrwq_p(&dataA[i*numCols + w], tmpa, p0);    \
+       cnt -= 4;                                     \
+  }                                                  \
+                                                     \
+}
+
+#define MAS_ROW_F32(COL,A,i,v,B,j)                   \
+{                                                    \
+  int cnt = ((A)->numCols)-(COL);                    \
+  float32_t *dataA = (A)->pData;                     \
+  float32_t *dataB = (B)->pData;                     \
+  const int32_t numCols = (A)->numCols;             \
+  int32_t w;                                        \
+  f32x4_t vec=vdupq_n_f32(v);                        \
+                                                     \
+  for(w=(COL);w < numCols; w+=4)                     \
+  {                                                  \
+       f32x4_t tmpa,tmpb;                            \
+       mve_pred16_t p0 = vctp32q(cnt);               \
+       tmpa = vldrwq_z_f32(&dataA[i*numCols + w],p0);\
+       tmpb = vldrwq_z_f32(&dataB[j*numCols + w],p0);\
+       tmpa = vfmsq_f32(tmpa,tmpb,vec);              \
+       vstrwq_p(&dataA[i*numCols + w], tmpa, p0);    \
+       cnt -= 4;                                     \
+  }                                                  \
+                                                     \
+}
+
+#define SCALE_ROW_F32(A,COL,v,i)                    \
+{                                                   \
+  int cnt = ((A)->numCols)-(COL);                   \
+  float32_t *data = (A)->pData;                     \
+  const int32_t numCols = (A)->numCols;            \
+  int32_t w;                                       \
+                                                    \
+  for(w=(COL);w < numCols; w+=4)                    \
+  {                                                 \
+       f32x4_t tmpa;                                \
+       mve_pred16_t p0 = vctp32q(cnt);              \
+       tmpa = vldrwq_z_f32(&data[i*numCols + w],p0);\
+       tmpa = vmulq_n_f32(tmpa,v);                  \
+       vstrwq_p(&data[i*numCols + w], tmpa, p0);    \
+       cnt -= 4;                                    \
+  }                                                 \
+                                                    \
+}
+
+#elif defined(ARM_MATH_NEON) && !defined(ARM_MATH_AUTOVECTORIZE)
+
+#define SWAP_ROWS_F32(A,COL,i,j)       \
+{                                      \
+  int32_t w;                           \
+  float32_t *dataI = (A)->pData;       \
+  float32_t *dataJ = (A)->pData;       \
+  const int32_t numCols = (A)->numCols;\
+  const int32_t nb = numCols - COL;    \
+                                       \
+  dataI += i*numCols + (COL);          \
+  dataJ += j*numCols + (COL);          \
+                                       \
+  float32_t tmp;                       \
+                                       \
+  for(w=0;w < nb; w++)                 \
+  {                                    \
+     tmp = *dataI;                     \
+     *dataI++ = *dataJ;                \
+     *dataJ++ = tmp;                   \
+  }                                    \
+}
+
+#define MAC_ROW_F32(COL,A,i,v,B,j)     \
+{                                      \
+  float32_t *dataA = (A)->pData;       \
+  float32_t *dataB = (B)->pData;       \
+  const int32_t numCols = (A)->numCols;\
+  const int32_t nb = numCols - (COL);  \
+  int32_t nbElems;                     \
+  f32x4_t vec = vdupq_n_f32(v);        \
+                                       \
+  nbElems = nb >> 2;                   \
+                                       \
+  dataA += i*numCols + (COL);          \
+  dataB += j*numCols + (COL);          \
+                                       \
+  while(nbElems>0)                     \
+  {                                    \
+       f32x4_t tmpa,tmpb;              \
+       tmpa = vld1q_f32(dataA,p0);     \
+       tmpb = vld1q_f32(dataB,p0);     \
+       tmpa = vmlaq_f32(tmpa,tmpb,vec);\
+       vst1q_f32(dataA, tmpa, p0);     \
+       nbElems--;                      \
+       dataA += 4;                     \
+       dataB += 4;                     \
+  }                                    \
+                                       \
+  nbElems = nb & 3;                    \
+  while(nbElems > 0)                   \
+  {                                    \
+     *dataA++ += v* *dataB++;          \
+     nbElems--;                        \
+  }                                    \
+}
+
+#define MAS_ROW_F32(COL,A,i,v,B,j)     \
+{                                      \
+  float32_t *dataA = (A)->pData;       \
+  float32_t *dataB = (B)->pData;       \
+  const int32_t numCols = (A)->numCols;\
+  const int32_t nb = numCols - (COL);  \
+  int32_t nbElems;                     \
+  f32x4_t vec = vdupq_n_f32(v);        \
+                                       \
+  nbElems = nb >> 2;                   \
+                                       \
+  dataA += i*numCols + (COL);          \
+  dataB += j*numCols + (COL);          \
+                                       \
+  while(nbElems>0)                     \
+  {                                    \
+       f32x4_t tmpa,tmpb;              \
+       tmpa = vld1q_f32(dataA);        \
+       tmpb = vld1q_f32(dataB);        \
+       tmpa = vmlsq_f32(tmpa,tmpb,vec);\
+       vst1q_f32(dataA, tmpa);         \
+       nbElems--;                      \
+       dataA += 4;                     \
+       dataB += 4;                     \
+  }                                    \
+                                       \
+  nbElems = nb & 3;                    \
+  while(nbElems > 0)                   \
+  {                                    \
+     *dataA++ -= v* *dataB++;          \
+     nbElems--;                        \
+  }                                    \
+}
+
+#define SCALE_ROW_F32(A,COL,v,i)        \
+{                                       \
+  float32_t *data = (A)->pData;         \
+  const int32_t numCols = (A)->numCols; \
+  const int32_t nb = numCols - (COL);   \
+  int32_t nbElems;                      \
+  f32x4_t vec = vdupq_n_f32(v);         \
+                                        \
+  nbElems = nb >> 2;                    \
+                                        \
+  data += i*numCols + (COL);            \
+  while(nbElems>0)                      \
+  {                                     \
+       f32x4_t tmpa;                    \
+       tmpa = vld1q_f32(data);          \
+       tmpa = vmulq_f32(tmpa,vec);      \
+       vst1q_f32(data, tmpa);           \
+       data += 4;                       \
+       nbElems --;                      \
+  }                                     \
+                                        \
+  nbElems = nb & 3;                     \
+  while(nbElems > 0)                    \
+  {                                     \
+     *data++ *= v;                      \
+     nbElems--;                         \
+  }                                     \
+                                        \
+}
+
+#else
+
+#define SWAP_ROWS_F32(A,COL,i,j)       \
+{                                      \
+  int32_t w;                           \
+  float32_t tmp;                       \
+  float32_t *dataI = (A)->pData;       \
+  float32_t *dataJ = (A)->pData;       \
+  const int32_t numCols = (A)->numCols;\
+  const int32_t nb = numCols - COL;    \
+                                       \
+  dataI += i*numCols + (COL);          \
+  dataJ += j*numCols + (COL);          \
+                                       \
+                                       \
+  for(w=0;w < nb; w++)                 \
+  {                                    \
+     tmp = *dataI;                     \
+     *dataI++ = *dataJ;                \
+     *dataJ++ = tmp;                   \
+  }                                    \
+}
+
+#define SCALE_ROW_F32(A,COL,v,i)       \
+{                                      \
+  int32_t w;                           \
+  float32_t *data = (A)->pData;        \
+  const int32_t numCols = (A)->numCols;\
+  const int32_t nb = numCols - COL;    \
+                                       \
+  data += i*numCols + (COL);           \
+                                       \
+  for(w=0;w < nb; w++)                 \
+  {                                    \
+     *data++ *= v;                     \
+  }                                    \
+}
+
+
+#define MAC_ROW_F32(COL,A,i,v,B,j)     \
+{                                      \
+  int32_t w;                           \
+  float32_t *dataA = (A)->pData;       \
+  float32_t *dataB = (B)->pData;       \
+  const int32_t numCols = (A)->numCols;\
+  const int32_t nb = numCols-(COL);    \
+                                       \
+  dataA = dataA + i*numCols + (COL);   \
+  dataB = dataB + j*numCols + (COL);   \
+                                       \
+  for(w=0;w < nb; w++)                 \
+  {                                    \
+     *dataA++ += v* *dataB++;          \
+  }                                    \
+}
+
+#define MAS_ROW_F32(COL,A,i,v,B,j)     \
+{                                      \
+  int32_t w;                           \
+  float32_t *dataA = (A)->pData;       \
+  float32_t *dataB = (B)->pData;       \
+  const int32_t numCols = (A)->numCols;\
+  const int32_t nb = numCols-(COL);    \
+                                       \
+  dataA = dataA + i*numCols + (COL);   \
+  dataB = dataB + j*numCols + (COL);   \
+                                       \
+  for(w=0;w < nb; w++)                 \
+  {                                    \
+     *dataA++ -= v* *dataB++;          \
+  }                                    \
+}
+
+#endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
+
+#define SWAP_COLS_F32(A,COL,i,j)               \
+{                                              \
+  int32_t w;                                  \
+  float32_t *data = (A)->pData;                \
+  const int32_t numCols = (A)->numCols;       \
+  for(w=(COL);w < numCols; w++)                \
+  {                                            \
+     float32_t tmp;                            \
+     tmp = data[w*numCols + i];                \
+     data[w*numCols + i] = data[w*numCols + j];\
+     data[w*numCols + j] = tmp;                \
+  }                                            \
+}
+
+#define SCALE_COL_F32(A,ROW,v,i)        \
+  SCALE_COL_T(float32_t,,A,ROW,v,i)
+
+#define SWAP_ROWS_F64(A,COL,i,j)       \
+{                                      \
+  int32_t w;                           \
+  float64_t *dataI = (A)->pData;       \
+  float64_t *dataJ = (A)->pData;       \
+  const int32_t numCols = (A)->numCols;\
+  const int32_t nb = numCols-(COL);    \
+                                       \
+  dataI += i*numCols + (COL);          \
+  dataJ += j*numCols + (COL);          \
+                                       \
+  for(w=0;w < nb; w++)                 \
+  {                                    \
+     float64_t tmp;                    \
+     tmp = *dataI;                     \
+     *dataI++ = *dataJ;                \
+     *dataJ++ = tmp;                   \
+  }                                    \
+}
+
+#define SWAP_COLS_F64(A,COL,i,j)               \
+{                                              \
+  int32_t w;                                  \
+  float64_t *data = (A)->pData;                \
+  const int32_t numCols = (A)->numCols;       \
+  for(w=(COL);w < numCols; w++)                \
+  {                                            \
+     float64_t tmp;                            \
+     tmp = data[w*numCols + i];                \
+     data[w*numCols + i] = data[w*numCols + j];\
+     data[w*numCols + j] = tmp;                \
+  }                                            \
+}
+
+#define SCALE_ROW_F64(A,COL,v,i)       \
+{                                      \
+  int32_t w;                           \
+  float64_t *data = (A)->pData;        \
+  const int32_t numCols = (A)->numCols;\
+  const int32_t nb = numCols-(COL);    \
+                                       \
+  data += i*numCols + (COL);           \
+                                       \
+  for(w=0;w < nb; w++)                 \
+  {                                    \
+     *data++ *= v;                     \
+  }                                    \
+}
+
+#define SCALE_COL_F64(A,ROW,v,i)        \
+  SCALE_COL_T(float64_t,,A,ROW,v,i)
+
+#define MAC_ROW_F64(COL,A,i,v,B,j)      \
+{                                       \
+  int32_t w;                           \
+  float64_t *dataA = (A)->pData;        \
+  float64_t *dataB = (B)->pData;        \
+  const int32_t numCols = (A)->numCols;\
+  const int32_t nb = numCols-(COL);     \
+                                        \
+  dataA += i*numCols + (COL);           \
+  dataB += j*numCols + (COL);           \
+                                        \
+  for(w=0;w < nb; w++)                  \
+  {                                     \
+     *dataA++ += v* *dataB++;           \
+  }                                     \
+}
+
+#define MAS_ROW_F64(COL,A,i,v,B,j)      \
+{                                       \
+  int32_t w;                           \
+  float64_t *dataA = (A)->pData;        \
+  float64_t *dataB = (B)->pData;        \
+  const int32_t numCols = (A)->numCols;\
+  const int32_t nb = numCols-(COL);     \
+                                        \
+  dataA += i*numCols + (COL);           \
+  dataB += j*numCols + (COL);           \
+                                        \
+  for(w=0;w < nb; w++)                  \
+  {                                     \
+     *dataA++ -= v* *dataB++;           \
+  }                                     \
+}
+
+#ifdef   __cplusplus
+}
+#endif
+
+#endif /* ifndef _MATRIX_UTILS_H_ */

+ 5 - 9
Source/MatrixFunctions/arm_mat_cholesky_f16.c

@@ -27,6 +27,7 @@
  */
 
 #include "dsp/matrix_functions_f16.h"
+#include "dsp/matrix_utils.h"
 
 #if defined(ARM_FLOAT16_SUPPORTED)
 
@@ -50,7 +51,7 @@
                    - \ref ARM_MATH_DECOMPOSITION_FAILURE      : Input matrix cannot be decomposed
    * @par
    * If the matrix is ill conditioned or only semi-definite, then it is better using the LDL^t decomposition.
-   * The decomposition of A is returning a lower triangular matrix U such that A = U U^t
+   * The decomposition of A is returning a lower triangular matrix U such that A = L L^t
    */
 
 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
@@ -164,10 +165,7 @@ arm_status arm_mat_cholesky_f16(
        }
 
        invSqrtVj = 1.0f16/(_Float16)sqrtf((float32_t)pG[i * n + i]);
-       for(j=i; j < n ; j++)
-       {
-         pG[j * n + i] = (_Float16)pG[j * n + i] * (_Float16)invSqrtVj ;
-       }
+       SCALE_COL_F16(pDst,i,invSqrtVj,i);
     }
 
     status = ARM_MATH_SUCCESS;
@@ -233,10 +231,8 @@ arm_status arm_mat_cholesky_f16(
        because doing it in f16 would not have any impact on the performances.
        */
        invSqrtVj = 1.0f/sqrtf((float32_t)pG[i * n + i]);
-       for(j=i ; j < n ; j++)
-       {
-         pG[j * n + i] = (_Float16)pG[j * n + i] * (_Float16)invSqrtVj ;
-       }
+       SCALE_COL_F16(pDst,i,invSqrtVj,i);
+
     }
 
     status = ARM_MATH_SUCCESS;

+ 7 - 14
Source/MatrixFunctions/arm_mat_cholesky_f32.c

@@ -27,6 +27,7 @@
  */
 
 #include "dsp/matrix_functions.h"
+#include "dsp/matrix_utils.h"
 
 /**
   @ingroup groupMatrix
@@ -35,7 +36,7 @@
 /**
   @defgroup MatrixChol Cholesky and LDLT decompositions
 
-  Computes the Cholesky or LDL^t decomposition of a matrix.
+  Computes the Cholesky or LL^t decomposition of a matrix.
 
 
   If the input matrix does not have a decomposition, then the 
@@ -58,7 +59,7 @@
                    - \ref ARM_MATH_DECOMPOSITION_FAILURE      : Input matrix cannot be decomposed
    * @par
    * If the matrix is ill conditioned or only semi-definite, then it is better using the LDL^t decomposition.
-   * The decomposition of A is returning a lower triangular matrix U such that A = U U^t
+   * The decomposition of A is returning a lower triangular matrix L such that A = L L^t
    */
 
 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
@@ -170,10 +171,7 @@ arm_status arm_mat_cholesky_f32(
        }
 
        invSqrtVj = 1.0f/sqrtf(pG[i * n + i]);
-       for(j=i; j < n ; j++)
-       {
-         pG[j * n + i] = pG[j * n + i] * invSqrtVj ;
-       }
+       SCALE_COL_F32(pDst,i,invSqrtVj,i);
     }
 
     status = ARM_MATH_SUCCESS;
@@ -350,10 +348,7 @@ arm_status arm_mat_cholesky_f32(
        }
 
        invSqrtVj = 1.0f/sqrtf(pG[i * n + i]);
-       for(j=i; j < n ; j++)
-       {
-         pG[j * n + i] = pG[j * n + i] * invSqrtVj ;
-       }
+       SCALE_COL_F32(pDst,i,invSqrtVj,i);
     }
 
     status = ARM_MATH_SUCCESS;
@@ -416,10 +411,8 @@ arm_status arm_mat_cholesky_f32(
        }
 
        invSqrtVj = 1.0f/sqrtf(pG[i * n + i]);
-       for(j=i ; j < n ; j++)
-       {
-         pG[j * n + i] = pG[j * n + i] * invSqrtVj ;
-       }
+       SCALE_COL_F32(pDst,i,invSqrtVj,i);
+      
     }
 
     status = ARM_MATH_SUCCESS;

+ 4 - 5
Source/MatrixFunctions/arm_mat_cholesky_f64.c

@@ -27,6 +27,7 @@
  */
 
 #include "dsp/matrix_functions.h"
+#include "dsp/matrix_utils.h"
 
 /**
   @ingroup groupMatrix
@@ -48,7 +49,7 @@
                    - \ref ARM_MATH_DECOMPOSITION_FAILURE      : Input matrix cannot be decomposed
    * @par
    * If the matrix is ill conditioned or only semi-definite, then it is better using the LDL^t decomposition.
-   * The decomposition of A is returning a lower triangular matrix U such that A = U U^t
+   * The decomposition of A is returning a lower triangular matrix L such that A = L L^t
    */
 
 
@@ -102,10 +103,8 @@ arm_status arm_mat_cholesky_f64(
        }
 
        invSqrtVj = 1.0/sqrt(pG[i * n + i]);
-       for(j=i ; j < n ; j++)
-       {
-         pG[j * n + i] = pG[j * n + i] * invSqrtVj ;
-       }
+       SCALE_COL_F64(pDst,i,invSqrtVj,i);
+
     }
 
     status = ARM_MATH_SUCCESS;

+ 53 - 675
Source/MatrixFunctions/arm_mat_inverse_f16.c

@@ -27,6 +27,7 @@
  */
 
 #include "dsp/matrix_functions_f16.h"
+#include "dsp/matrix_utils.h"
 
 #if defined(ARM_FLOAT16_SUPPORTED)
 
@@ -50,520 +51,20 @@
                    - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
                    - \ref ARM_MATH_SINGULAR      : Input matrix is found to be singular (non-invertible)
  */
-#if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
-
-arm_status arm_mat_inverse_f16(
-  const arm_matrix_instance_f16 * pSrc,
-  arm_matrix_instance_f16 * pDst)
-{
-    float16_t *pIn = pSrc->pData;   /* input data matrix pointer */
-    float16_t *pOut = pDst->pData;  /* output data matrix pointer */
-    float16_t *pInT1, *pInT2;   /* Temporary input data matrix pointer */
-    float16_t *pOutT1, *pOutT2; /* Temporary output data matrix pointer */
-    float16_t *pPivotRowIn, *pPRT_in, *pPivotRowDst, *pPRT_pDst;    /* Temporary input and output data matrix pointer */
-
-    uint32_t  numRows = pSrc->numRows;  /* Number of rows in the matrix  */
-    uint32_t  numCols = pSrc->numCols;  /* Number of Cols in the matrix  */
-    float16_t *pTmpA, *pTmpB;
-
-    _Float16 in = 0.0f16;        /* Temporary input values  */
-    uint32_t  i, rowCnt, flag = 0U, j, loopCnt, l;   /* loop counters */
-    arm_status status;          /* status of matrix inverse */
-    uint32_t  blkCnt;
-
-#ifdef ARM_MATH_MATRIX_CHECK
-   /* Check for matrix mismatch condition */
-  if ((pSrc->numRows != pSrc->numCols) || (pDst->numRows != pDst->numCols)
-     || (pSrc->numRows != pDst->numRows))
-  {
-    /* Set status as ARM_MATH_SIZE_MISMATCH */
-    status = ARM_MATH_SIZE_MISMATCH;
-  }
-  else
-#endif /*    #ifdef ARM_MATH_MATRIX_CHECK    */
-  {
-
-    /*--------------------------------------------------------------------------------------------------------------
-     * Matrix Inverse can be solved using elementary row operations.
-     *
-     *  Gauss-Jordan Method:
-     *
-     *     1. First combine the identity matrix and the input matrix separated by a bar to form an
-     *        augmented matrix as follows:
-     *                      _  _          _     _      _   _         _         _
-     *                     |  |  a11  a12  | | | 1   0  |   |       |  X11 X12  |
-     *                     |  |            | | |        |   |   =   |           |
-     *                     |_ |_ a21  a22 _| | |_0   1 _|  _|       |_ X21 X21 _|
-     *
-     *      2. In our implementation, pDst Matrix is used as identity matrix.
-     *
-     *      3. Begin with the first row. Let i = 1.
-     *
-     *      4. Check to see if the pivot for row i is zero.
-     *         The pivot is the element of the main diagonal that is on the current row.
-     *         For instance, if working with row i, then the pivot element is aii.
-     *         If the pivot is zero, exchange that row with a row below it that does not
-     *         contain a zero in column i. If this is not possible, then an inverse
-     *         to that matrix does not exist.
-     *
-     *      5. Divide every element of row i by the pivot.
-     *
-     *      6. For every row below and  row i, replace that row with the sum of that row and
-     *         a multiple of row i so that each new element in column i below row i is zero.
-     *
-     *      7. Move to the next row and column and repeat steps 2 through 5 until you have zeros
-     *         for every element below and above the main diagonal.
-     *
-     *      8. Now an identical matrix is formed to the left of the bar(input matrix, src).
-     *         Therefore, the matrix to the right of the bar is our solution(dst matrix, dst).
-     *----------------------------------------------------------------------------------------------------------------*/
-
-        /*
-         * Working pointer for destination matrix
-         */
-        pOutT1 = pOut;
-        /*
-         * Loop over the number of rows
-         */
-        rowCnt = numRows;
-        /*
-         * Making the destination matrix as identity matrix
-         */
-        while (rowCnt > 0U)
-        {
-            /*
-             * Writing all zeroes in lower triangle of the destination matrix
-             */
-            j = numRows - rowCnt;
-            while (j > 0U)
-            {
-                *pOutT1++ = 0.0f16;
-                j--;
-            }
-            /*
-             * Writing all ones in the diagonal of the destination matrix
-             */
-            *pOutT1++ = 1.0f16;
-            /*
-             * Writing all zeroes in upper triangle of the destination matrix
-             */
-            j = rowCnt - 1U;
-            while (j > 0U)
-            {
-                *pOutT1++ = 0.0f16;
-                j--;
-            }
-            /*
-             * Decrement the loop counter
-             */
-            rowCnt--;
-        }
-
-        /*
-         * Loop over the number of columns of the input matrix.
-         * All the elements in each column are processed by the row operations
-         */
-        loopCnt = numCols;
-        /*
-         * Index modifier to navigate through the columns
-         */
-        l = 0U;
-        while (loopCnt > 0U)
-        {
-            /*
-             * Check if the pivot element is zero..
-             * If it is zero then interchange the row with non zero row below.
-             * If there is no non zero element to replace in the rows below,
-             * then the matrix is Singular.
-             */
-
-            /*
-             * Working pointer for the input matrix that points
-             * * to the pivot element of the particular row
-             */
-            pInT1 = pIn + (l * numCols);
-            /*
-             * Working pointer for the destination matrix that points
-             * * to the pivot element of the particular row
-             */
-            pOutT1 = pOut + (l * numCols);
-            /*
-             * Temporary variable to hold the pivot value
-             */
-            in = *pInT1;
-            
-
-            /*
-             * Check if the pivot element is zero
-             */
-            if ((_Float16)*pInT1 == 0.0f16)
-            {
-                /*
-                 * Loop over the number rows present below
-                 */
-                for (i = 1U; i < numRows-l; i++)
-                {
-                    /*
-                     * Update the input and destination pointers
-                     */
-                    pInT2 = pInT1 + (numCols * i);
-                    pOutT2 = pOutT1 + (numCols * i);
-                    /*
-                     * Check if there is a non zero pivot element to
-                     * * replace in the rows below
-                     */
-                    if ((_Float16)*pInT2 != 0.0f16)
-                    {
-                        f16x8_t vecA, vecB;
-                        /*
-                         * Loop over number of columns
-                         * * to the right of the pilot element
-                         */
-                        pTmpA = pInT1;
-                        pTmpB = pInT2;
-                        blkCnt = (numCols - l) >> 3;
-                        while (blkCnt > 0U)
-                        {
-                            
-                            vecA = vldrhq_f16(pTmpA);
-                            vecB = vldrhq_f16(pTmpB);
-                            vstrhq_f16(pTmpB, vecA);
-                            vstrhq_f16(pTmpA, vecB);
-
-                            pTmpA += 8;
-                            pTmpB += 8;
-                            /*
-                             * Decrement the blockSize loop counter
-                             */
-                            blkCnt--;
-                        }
-                        /*
-                         * tail
-                         * (will be merged thru tail predication)
-                         */
-                        blkCnt = (numCols - l) & 7;
-                        if (blkCnt > 0U)
-                        {
-                            mve_pred16_t p0 = vctp16q(blkCnt);
-
-                            vecA = vldrhq_f16(pTmpA);
-                            vecB = vldrhq_f16(pTmpB);
-                            vstrhq_p_f16(pTmpB, vecA, p0);
-                            vstrhq_p_f16(pTmpA, vecB, p0);
-                        }
-
-                        pInT1 += numCols - l;
-                        pInT2 += numCols - l;
-                        pTmpA = pOutT1;
-                        pTmpB = pOutT2;
-                        blkCnt = numCols >> 3;
-                        while (blkCnt > 0U)
-                        {
-
-                            vecA = vldrhq_f16(pTmpA);
-                            vecB = vldrhq_f16(pTmpB);
-                            vstrhq_f16(pTmpB, vecA);
-                            vstrhq_f16(pTmpA, vecB);
-                            pTmpA += 8;
-                            pTmpB += 8;
-                            /*
-                             * Decrement the blockSize loop counter
-                             */
-                            blkCnt--;
-                        }
-                        /*
-                         * tail
-                         */
-                        blkCnt = numCols & 7;
-                        if (blkCnt > 0U)
-                        {
-                            mve_pred16_t p0 = vctp16q(blkCnt);
-
-                            vecA = vldrhq_f16(pTmpA);
-                            vecB = vldrhq_f16(pTmpB);
-                            vstrhq_p_f16(pTmpB, vecA, p0);
-                            vstrhq_p_f16(pTmpA, vecB, p0);
-                        }
-
-                        pOutT1 += numCols;
-                        pOutT2 += numCols;
-                        /*
-                         * Flag to indicate whether exchange is done or not
-                         */
-                        flag = 1U;
-
-                        /*
-                         * Break after exchange is done
-                         */
-                        break;
-                    }
-              
-                }
-            }
-
-            /*
-             * Update the status if the matrix is singular
-             */
-            if ((flag != 1U) && (in == 0.0f16))
-            {
-                return ARM_MATH_SINGULAR;
-            }
-
-            /*
-             * Points to the pivot row of input and destination matrices
-             */
-            pPivotRowIn = pIn + (l * numCols);
-            pPivotRowDst = pOut + (l * numCols);
-
-            /*
-             * Temporary pointers to the pivot row pointers
-             */
-            pInT1 = pPivotRowIn;
-            pOutT1 = pPivotRowDst;
-
-            /*
-             * Pivot element of the row
-             */
-            in = *(pIn + (l * numCols));
-
-            pTmpA = pInT1;
-
-            f16x8_t invIn = vdupq_n_f16(1.0f16 / in);
-
-            blkCnt = (numCols - l) >> 3;
-            f16x8_t vecA;
-            while (blkCnt > 0U)
-            {
-                *(f16x8_t *) pTmpA = *(f16x8_t *) pTmpA * invIn;
-                pTmpA += 8;
-                /*
-                 * Decrement the blockSize loop counter
-                 */
-                blkCnt--;
-            }
-            /*
-             * tail
-             */
-            blkCnt = (numCols - l) & 7;
-            if (blkCnt > 0U)
-            {
-                mve_pred16_t p0 = vctp16q(blkCnt);
-                
-
-                vecA = vldrhq_f16(pTmpA);
-                vecA = vecA * invIn;
-                vstrhq_p_f16(pTmpA, vecA, p0);
-            }
-
-            pInT1 += numCols - l;
-            /*
-             * Loop over number of columns
-             * * to the right of the pilot element
-             */
-
-            pTmpA = pOutT1;
-            blkCnt = numCols >> 3;
-            while (blkCnt > 0U)
-            {
-                *(f16x8_t *) pTmpA = *(f16x8_t *) pTmpA *invIn;
-                pTmpA += 8;
-                /*
-                 * Decrement the blockSize loop counter
-                 */
-                blkCnt--;
-            }
-            /*
-             * tail
-             * (will be merged thru tail predication)
-             */
-            blkCnt = numCols & 7;
-            if (blkCnt > 0U)
-            {
-                mve_pred16_t p0 = vctp16q(blkCnt);
-
-                vecA = vldrhq_f16(pTmpA);
-                vecA = vecA * invIn;
-                vstrhq_p_f16(pTmpA, vecA, p0);
-            }
-
-            pOutT1 += numCols;
-
-            /*
-             * Replace the rows with the sum of that row and a multiple of row i
-             * * so that each new element in column i above row i is zero.
-             */
-
-            /*
-             * Temporary pointers for input and destination matrices
-             */
-            pInT1 = pIn;
-            pOutT1 = pOut;
-
-            for (i = 0U; i < numRows; i++)
-            {
-                /*
-                 * Check for the pivot element
-                 */
-                if (i == l)
-                {
-                    /*
-                     * If the processing element is the pivot element,
-                     * only the columns to the right are to be processed
-                     */
-                    pInT1 += numCols - l;
-                    pOutT1 += numCols;
-                }
-                else
-                {
-                    /*
-                     * Element of the reference row
-                     */
-
-                    /*
-                     * Working pointers for input and destination pivot rows
-                     */
-                    pPRT_in = pPivotRowIn;
-                    pPRT_pDst = pPivotRowDst;
-                    /*
-                     * Loop over the number of columns to the right of the pivot element,
-                     * to replace the elements in the input matrix
-                     */
-
-                    in = *pInT1;
-                    f16x8_t tmpV = vdupq_n_f16(in);
-
-                    blkCnt = (numCols - l) >> 3;
-                    while (blkCnt > 0U)
-                    {
-                        f16x8_t vec1, vec2;
-                        /*
-                         * Replace the element by the sum of that row
-                         * and a multiple of the reference row
-                         */
-                        vec1 = vldrhq_f16(pInT1);
-                        vec2 = vldrhq_f16(pPRT_in);
-                        vec1 = vfmsq_f16(vec1, tmpV, vec2);
-                        vstrhq_f16(pInT1, vec1);
-                        pPRT_in += 8;
-                        pInT1 += 8;
-                        /*
-                         * Decrement the blockSize loop counter
-                         */
-                        blkCnt--;
-                    }
-                    /*
-                     * tail
-                     * (will be merged thru tail predication)
-                     */
-                    blkCnt = (numCols - l) & 7;
-                    if (blkCnt > 0U)
-                    {
-                        f16x8_t vec1, vec2;
-                        mve_pred16_t p0 = vctp16q(blkCnt);
-
-                        vec1 = vldrhq_f16(pInT1);
-                        vec2 = vldrhq_f16(pPRT_in);
-                        vec1 = vfmsq_f16(vec1, tmpV, vec2);
-                        vstrhq_p_f16(pInT1, vec1, p0);
-                        pInT1 += blkCnt;
-                    }
-
-                    blkCnt = numCols >> 3;
-                    while (blkCnt > 0U)
-                    {
-                        f16x8_t vec1, vec2;
-
-                        /*
-                         * Replace the element by the sum of that row
-                         * and a multiple of the reference row
-                         */
-                        vec1 = vldrhq_f16(pOutT1);
-                        vec2 = vldrhq_f16(pPRT_pDst);
-                        vec1 = vfmsq_f16(vec1, tmpV, vec2);
-                        vstrhq_f16(pOutT1, vec1);
-                        pPRT_pDst += 8;
-                        pOutT1 += 8;
-                        /*
-                         * Decrement the blockSize loop counter
-                         */
-                        blkCnt--;
-                    }
-                    /*
-                     * tail
-                     * (will be merged thru tail predication)
-                     */
-                    blkCnt = numCols & 7;
-                    if (blkCnt > 0U)
-                    {
-                        f16x8_t vec1, vec2;
-                        mve_pred16_t p0 = vctp16q(blkCnt);
-
-                        vec1 = vldrhq_f16(pOutT1);
-                        vec2 = vldrhq_f16(pPRT_pDst);
-                        vec1 = vfmsq_f16(vec1, tmpV, vec2);
-                        vstrhq_p_f16(pOutT1, vec1, p0);
-
-                        pInT2 += blkCnt;
-                        pOutT1 += blkCnt;
-                    }
-                }
-                /*
-                 * Increment the temporary input pointer
-                 */
-                pInT1 = pInT1 + l;
-            }
-            /*
-             * Increment the input pointer
-             */
-            pIn++;
-            /*
-             * Decrement the loop counter
-             */
-            loopCnt--;
-            /*
-             * Increment the index modifier
-             */
-            l++;
-        }
-
-        /*
-         * Set status as ARM_MATH_SUCCESS
-         */
-        status = ARM_MATH_SUCCESS;
-
-        if ((flag != 1U) && (in == 0.0f16))
-        {
-            pIn = pSrc->pData;
-            for (i = 0; i < numRows * numCols; i++)
-            {
-                if ((_Float16)pIn[i] != 0.0f16)
-                    break;
-            }
-
-            if (i == numRows * numCols)
-                status = ARM_MATH_SINGULAR;
-        }
-  }
-  /* Return to application */
-  return (status);
-}
-
-#else
-
 arm_status arm_mat_inverse_f16(
   const arm_matrix_instance_f16 * pSrc,
         arm_matrix_instance_f16 * pDst)
 {
   float16_t *pIn = pSrc->pData;                  /* input data matrix pointer */
   float16_t *pOut = pDst->pData;                 /* output data matrix pointer */
-  float16_t *pInT1, *pInT2;                      /* Temporary input data matrix pointer */
-  float16_t *pOutT1, *pOutT2;                    /* Temporary output data matrix pointer */
-  float16_t *pPivotRowIn, *pPRT_in, *pPivotRowDst, *pPRT_pDst;  /* Temporary input and output data matrix pointer */
+  
+  float16_t *pTmp;
   uint32_t numRows = pSrc->numRows;              /* Number of rows in the matrix  */
   uint32_t numCols = pSrc->numCols;              /* Number of Cols in the matrix  */
 
-  _Float16 Xchg, in = 0.0f16, in1;                /* Temporary input values  */
-  uint32_t i, rowCnt, flag = 0U, j, loopCnt, k,l;      /* loop counters */
+
+  float16_t pivot = 0.0f16, newPivot=0.0f16;                /* Temporary input values  */
+  uint32_t selectedRow,pivotRow,i, rowNb, rowCnt, flag = 0U, j,column;      /* loop counters */
   arm_status status;                             /* status of matrix inverse */
 
 #ifdef ARM_MATH_MATRIX_CHECK
@@ -581,7 +82,6 @@ arm_status arm_mat_inverse_f16(
 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
 
   {
-
     /*--------------------------------------------------------------------------------------------------------------
      * Matrix Inverse can be solved using elementary row operations.
      *
@@ -618,7 +118,7 @@ arm_status arm_mat_inverse_f16(
      *----------------------------------------------------------------------------------------------------------------*/
 
     /* Working pointer for destination matrix */
-    pOutT1 = pOut;
+    pTmp = pOut;
 
     /* Loop over the number of rows */
     rowCnt = numRows;
@@ -630,18 +130,18 @@ arm_status arm_mat_inverse_f16(
       j = numRows - rowCnt;
       while (j > 0U)
       {
-        *pOutT1++ = 0.0f16;
+        *pTmp++ = 0.0f16;
         j--;
       }
 
       /* Writing all ones in the diagonal of the destination matrix */
-      *pOutT1++ = 1.0f16;
+      *pTmp++ = 1.0f16;
 
       /* Writing all zeroes in upper triangle of the destination matrix */
       j = rowCnt - 1U;
       while (j > 0U)
       {
-        *pOutT1++ = 0.0f16;
+        *pTmp++ = 0.0f16;
         j--;
       }
 
@@ -651,220 +151,100 @@ arm_status arm_mat_inverse_f16(
 
     /* Loop over the number of columns of the input matrix.
        All the elements in each column are processed by the row operations */
-    loopCnt = numCols;
 
     /* Index modifier to navigate through the columns */
-    l = 0U;
-
-    while (loopCnt > 0U)
+    for(column = 0U; column < numCols; column++)
     {
       /* Check if the pivot element is zero..
        * If it is zero then interchange the row with non zero row below.
        * If there is no non zero element to replace in the rows below,
        * then the matrix is Singular. */
 
-      /* Working pointer for the input matrix that points
-       * to the pivot element of the particular row  */
-      pInT1 = pIn + (l * numCols);
-
-      /* Working pointer for the destination matrix that points
-       * to the pivot element of the particular row  */
-      pOutT1 = pOut + (l * numCols);
+      pivotRow = column;
 
       /* Temporary variable to hold the pivot value */
-      in = *pInT1;
-
+      pTmp = ELEM(pSrc,column,column) ;
+      pivot = *pTmp;
+      selectedRow = column;
 
-      /* Check if the pivot element is zero */
-      if ((_Float16)*pInT1 == 0.0f16)
-      {
+     
         /* Loop over the number rows present below */
 
-        for (i = 1U; i < numRows-l; i++)
-        {
+      for (rowNb = column+1; rowNb < numRows; rowNb++)
+      {
           /* Update the input and destination pointers */
-          pInT2 = pInT1 + (numCols * i);
-          pOutT2 = pOutT1 + (numCols * i);
+          pTmp = ELEM(pSrc,rowNb,column);
+          newPivot = *pTmp;
+          if (fabsf((float32_t)newPivot) > fabsf((float32_t)pivot))
+          {
+            selectedRow = rowNb; 
+            pivot = newPivot;
+          }
+
+      }
 
           /* Check if there is a non zero pivot element to
            * replace in the rows below */
-          if ((_Float16)*pInT2 != 0.0f16)
-          {
+      if (((_Float16)pivot != 0.0f16) && (selectedRow != column))
+      {
             /* Loop over number of columns
              * to the right of the pilot element */
-            j = numCols - l;
-
-            while (j > 0U)
-            {
-              /* Exchange the row elements of the input matrix */
-              Xchg = *pInT2;
-              *pInT2++ = *pInT1;
-              *pInT1++ = Xchg;
 
-              /* Decrement the loop counter */
-              j--;
-            }
-
-            /* Loop over number of columns of the destination matrix */
-            j = numCols;
-
-            while (j > 0U)
-            {
-              /* Exchange the row elements of the destination matrix */
-              Xchg = *pOutT2;
-              *pOutT2++ = *pOutT1;
-              *pOutT1++ = Xchg;
-
-              /* Decrement loop counter */
-              j--;
-            }
+            SWAP_ROWS_F16(pSrc,column, pivotRow,selectedRow);
+            SWAP_ROWS_F16(pDst,0, pivotRow,selectedRow);
 
+    
             /* Flag to indicate whether exchange is done or not */
             flag = 1U;
 
-            /* Break after exchange is done */
-            break;
-          }
-
-        }
       }
 
+
       /* Update the status if the matrix is singular */
-      if ((flag != 1U) && (in == 0.0f16))
+      if ((flag != 1U) && ((_Float16)pivot == 0.0f16))
       {
         return ARM_MATH_SINGULAR;
       }
 
-      /* Points to the pivot row of input and destination matrices */
-      pPivotRowIn = pIn + (l * numCols);
-      pPivotRowDst = pOut + (l * numCols);
-
-      /* Temporary pointers to the pivot row pointers */
-      pInT1 = pPivotRowIn;
-      pInT2 = pPivotRowDst;
-
+     
       /* Pivot element of the row */
-      in = *pPivotRowIn;
+      pivot = 1.0f16 / (_Float16)pivot;
 
-      /* Loop over number of columns
-       * to the right of the pilot element */
-      j = (numCols - l);
-
-      while (j > 0U)
-      {
-        /* Divide each element of the row of the input matrix
-         * by the pivot element */
-        in1 = *pInT1;
-        *pInT1++ = in1 / in;
-
-        /* Decrement the loop counter */
-        j--;
-      }
-
-      /* Loop over number of columns of the destination matrix */
-      j = numCols;
-
-      while (j > 0U)
-      {
-        /* Divide each element of the row of the destination matrix
-         * by the pivot element */
-        in1 = *pInT2;
-        *pInT2++ = in1 / in;
-
-        /* Decrement the loop counter */
-        j--;
-      }
+      SCALE_ROW_F16(pSrc,column,pivot,pivotRow);
+      SCALE_ROW_F16(pDst,0,pivot,pivotRow);
 
+      
       /* Replace the rows with the sum of that row and a multiple of row i
        * so that each new element in column i above row i is zero.*/
 
-      /* Temporary pointers for input and destination matrices */
-      pInT1 = pIn;
-      pInT2 = pOut;
-
-      /* index used to check for pivot element */
-      i = 0U;
-
-      /* Loop over number of rows */
-      /*  to be replaced by the sum of that row and a multiple of row i */
-      k = numRows;
-
-      while (k > 0U)
+      rowNb = 0;
+      for (;rowNb < pivotRow; rowNb++)
       {
-        /* Check for the pivot element */
-        if (i == l)
-        {
-          /* If the processing element is the pivot element,
-             only the columns to the right are to be processed */
-          pInT1 += numCols - l;
-
-          pInT2 += numCols;
-        }
-        else
-        {
-          /* Element of the reference row */
-          in = *pInT1;
-
-          /* Working pointers for input and destination pivot rows */
-          pPRT_in = pPivotRowIn;
-          pPRT_pDst = pPivotRowDst;
-
-          /* Loop over the number of columns to the right of the pivot element,
-             to replace the elements in the input matrix */
-          j = (numCols - l);
-
-          while (j > 0U)
-          {
-            /* Replace the element by the sum of that row
-               and a multiple of the reference row  */
-            in1 = *pInT1;
-            *pInT1++ = (_Float16)in1 - ((_Float16)in * (_Float16)*pPRT_in++);
-
-            /* Decrement the loop counter */
-            j--;
-          }
-
-          /* Loop over the number of columns to
-             replace the elements in the destination matrix */
-          j = numCols;
+           pTmp = ELEM(pSrc,rowNb,column) ;
+           pivot = *pTmp;
 
-          while (j > 0U)
-          {
-            /* Replace the element by the sum of that row
-               and a multiple of the reference row  */
-            in1 = *pInT2;
-            *pInT2++ = (_Float16)in1 - ((_Float16)in * (_Float16)*pPRT_pDst++);
+           MAS_ROW_F16(column,pSrc,rowNb,pivot,pSrc,pivotRow);
+           MAS_ROW_F16(0     ,pDst,rowNb,pivot,pDst,pivotRow);
 
-            /* Decrement loop counter */
-            j--;
-          }
 
-        }
+      }
 
-        /* Increment temporary input pointer */
-        pInT1 = pInT1 + l;
+      for (rowNb = pivotRow + 1; rowNb < numRows; rowNb++)
+      {
+           pTmp = ELEM(pSrc,rowNb,column) ;
+           pivot = *pTmp;
 
-        /* Decrement loop counter */
-        k--;
+           MAS_ROW_F16(column,pSrc,rowNb,pivot,pSrc,pivotRow);
+           MAS_ROW_F16(0     ,pDst,rowNb,pivot,pDst,pivotRow);
 
-        /* Increment pivot index */
-        i++;
       }
 
-      /* Increment the input pointer */
-      pIn++;
-
-      /* Decrement the loop counter */
-      loopCnt--;
-
-      /* Increment the index modifier */
-      l++;
     }
 
     /* Set status as ARM_MATH_SUCCESS */
     status = ARM_MATH_SUCCESS;
 
-    if ((flag != 1U) && ((_Float16)in == 0.0f16))
+    if ((flag != 1U) && ((_Float16)pivot == 0.0f16))
     {
       pIn = pSrc->pData;
       for (i = 0; i < numRows * numCols; i++)
@@ -881,8 +261,6 @@ arm_status arm_mat_inverse_f16(
   /* Return to application */
   return (status);
 }
-#endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
-
 /**
   @} end of MatrixInv group
  */

ファイルの差分が大きいため隠しています
+ 55 - 1338
Source/MatrixFunctions/arm_mat_inverse_f32.c


+ 49 - 430
Source/MatrixFunctions/arm_mat_inverse_f64.c

@@ -27,6 +27,7 @@
  */
 
 #include "dsp/matrix_functions.h"
+#include "dsp/matrix_utils.h"
 
 /**
   @ingroup groupMatrix
@@ -54,16 +55,14 @@ arm_status arm_mat_inverse_f64(
 {
   float64_t *pIn = pSrc->pData;                  /* input data matrix pointer */
   float64_t *pOut = pDst->pData;                 /* output data matrix pointer */
-  float64_t *pInT1, *pInT2;                      /* Temporary input data matrix pointer */
-  float64_t *pOutT1, *pOutT2;                    /* Temporary output data matrix pointer */
-  float64_t *pPivotRowIn, *pPRT_in, *pPivotRowDst, *pPRT_pDst;  /* Temporary input and output data matrix pointer */
+  
+  float64_t *pTmp;
   uint32_t numRows = pSrc->numRows;              /* Number of rows in the matrix  */
   uint32_t numCols = pSrc->numCols;              /* Number of Cols in the matrix  */
 
-#if defined (ARM_MATH_DSP)
 
-  float64_t Xchg, in = 0.0, in1;                /* Temporary input values  */
-  uint32_t i, rowCnt, flag = 0U, j, loopCnt, k,l;      /* loop counters */
+  float64_t pivot = 0.0, newPivot=0.0;                /* Temporary input values  */
+  uint32_t selectedRow,pivotRow,i, rowNb, rowCnt, flag = 0U, j,column;      /* loop counters */
   arm_status status;                             /* status of matrix inverse */
 
 #ifdef ARM_MATH_MATRIX_CHECK
@@ -81,7 +80,6 @@ arm_status arm_mat_inverse_f64(
 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
 
   {
-
     /*--------------------------------------------------------------------------------------------------------------
      * Matrix Inverse can be solved using elementary row operations.
      *
@@ -118,7 +116,7 @@ arm_status arm_mat_inverse_f64(
      *----------------------------------------------------------------------------------------------------------------*/
 
     /* Working pointer for destination matrix */
-    pOutT1 = pOut;
+    pTmp = pOut;
 
     /* Loop over the number of rows */
     rowCnt = numRows;
@@ -130,18 +128,18 @@ arm_status arm_mat_inverse_f64(
       j = numRows - rowCnt;
       while (j > 0U)
       {
-        *pOutT1++ = 0.0;
+        *pTmp++ = 0.0;
         j--;
       }
 
       /* Writing all ones in the diagonal of the destination matrix */
-      *pOutT1++ = 1.0;
+      *pTmp++ = 1.0;
 
       /* Writing all zeroes in upper triangle of the destination matrix */
       j = rowCnt - 1U;
       while (j > 0U)
       {
-        *pOutT1++ = 0.0;
+        *pTmp++ = 0.0;
         j--;
       }
 
@@ -151,477 +149,99 @@ arm_status arm_mat_inverse_f64(
 
     /* Loop over the number of columns of the input matrix.
        All the elements in each column are processed by the row operations */
-    loopCnt = numCols;
 
     /* Index modifier to navigate through the columns */
-    l = 0U;
-
-    while (loopCnt > 0U)
+    for(column = 0U; column < numCols; column++)
     {
       /* Check if the pivot element is zero..
        * If it is zero then interchange the row with non zero row below.
        * If there is no non zero element to replace in the rows below,
        * then the matrix is Singular. */
 
-      /* Working pointer for the input matrix that points
-       * to the pivot element of the particular row  */
-      pInT1 = pIn + (l * numCols);
-
-      /* Working pointer for the destination matrix that points
-       * to the pivot element of the particular row  */
-      pOutT1 = pOut + (l * numCols);
+      pivotRow = column;
 
       /* Temporary variable to hold the pivot value */
-      in = *pInT1;
+      pTmp = ELEM(pSrc,column,column) ;
+      pivot = *pTmp;
+      selectedRow = column;
 
-    
-
-      /* Check if the pivot element is zero */
-      if (*pInT1 == 0.0)
-      {
+      
         /* Loop over the number rows present below */
 
-        for (i = 1U; i < numRows - l; i++)
-        {
-          /* Update the input and destination pointers */
-          pInT2 = pInT1 + (numCols * i);
-          pOutT2 = pOutT1 + (numCols * i);
-
-          /* Check if there is a non zero pivot element to
-           * replace in the rows below */
-          if (*pInT2 != 0.0)
-          {
-            /* Loop over number of columns
-             * to the right of the pilot element */
-            j = numCols - l;
-
-            while (j > 0U)
-            {
-              /* Exchange the row elements of the input matrix */
-              Xchg = *pInT2;
-              *pInT2++ = *pInT1;
-              *pInT1++ = Xchg;
-
-              /* Decrement the loop counter */
-              j--;
-            }
-
-            /* Loop over number of columns of the destination matrix */
-            j = numCols;
-
-            while (j > 0U)
-            {
-              /* Exchange the row elements of the destination matrix */
-              Xchg = *pOutT2;
-              *pOutT2++ = *pOutT1;
-              *pOutT1++ = Xchg;
-
-              /* Decrement loop counter */
-              j--;
-            }
-
-            /* Flag to indicate whether exchange is done or not */
-            flag = 1U;
-
-            /* Break after exchange is done */
-            break;
-          }
-
-
-          /* Decrement loop counter */
-        }
-      }
-
-      /* Update the status if the matrix is singular */
-      if ((flag != 1U) && (in == 0.0))
-      {
-        return ARM_MATH_SINGULAR;
-      }
-
-      /* Points to the pivot row of input and destination matrices */
-      pPivotRowIn = pIn + (l * numCols);
-      pPivotRowDst = pOut + (l * numCols);
-
-      /* Temporary pointers to the pivot row pointers */
-      pInT1 = pPivotRowIn;
-      pInT2 = pPivotRowDst;
-
-      /* Pivot element of the row */
-      in = *pPivotRowIn;
-
-      /* Loop over number of columns
-       * to the right of the pilot element */
-      j = (numCols - l);
-
-      while (j > 0U)
-      {
-        /* Divide each element of the row of the input matrix
-         * by the pivot element */
-        in1 = *pInT1;
-        *pInT1++ = in1 / in;
-
-        /* Decrement the loop counter */
-        j--;
-      }
-
-      /* Loop over number of columns of the destination matrix */
-      j = numCols;
-
-      while (j > 0U)
+      for (rowNb = column+1; rowNb < numRows; rowNb++)
       {
-        /* Divide each element of the row of the destination matrix
-         * by the pivot element */
-        in1 = *pInT2;
-        *pInT2++ = in1 / in;
-
-        /* Decrement the loop counter */
-        j--;
-      }
-
-      /* Replace the rows with the sum of that row and a multiple of row i
-       * so that each new element in column i above row i is zero.*/
-
-      /* Temporary pointers for input and destination matrices */
-      pInT1 = pIn;
-      pInT2 = pOut;
-
-      /* index used to check for pivot element */
-      i = 0U;
-
-      /* Loop over number of rows */
-      /*  to be replaced by the sum of that row and a multiple of row i */
-      k = numRows;
-
-      while (k > 0U)
-      {
-        /* Check for the pivot element */
-        if (i == l)
-        {
-          /* If the processing element is the pivot element,
-             only the columns to the right are to be processed */
-          pInT1 += numCols - l;
-
-          pInT2 += numCols;
-        }
-        else
-        {
-          /* Element of the reference row */
-          in = *pInT1;
-
-          /* Working pointers for input and destination pivot rows */
-          pPRT_in = pPivotRowIn;
-          pPRT_pDst = pPivotRowDst;
-
-          /* Loop over the number of columns to the right of the pivot element,
-             to replace the elements in the input matrix */
-          j = (numCols - l);
-
-          while (j > 0U)
-          {
-            /* Replace the element by the sum of that row
-               and a multiple of the reference row  */
-            in1 = *pInT1;
-            *pInT1++ = in1 - (in * *pPRT_in++);
-
-            /* Decrement the loop counter */
-            j--;
-          }
-
-          /* Loop over the number of columns to
-             replace the elements in the destination matrix */
-          j = numCols;
-
-          while (j > 0U)
+          /* Update the input and destination pointers */
+          pTmp = ELEM(pSrc,rowNb,column);
+          newPivot = *pTmp;
+          if (fabs(newPivot) > fabs(pivot))
           {
-            /* Replace the element by the sum of that row
-               and a multiple of the reference row  */
-            in1 = *pInT2;
-            *pInT2++ = in1 - (in * *pPRT_pDst++);
-
-            /* Decrement loop counter */
-            j--;
+            selectedRow = rowNb; 
+            pivot = newPivot;
           }
-
-        }
-
-        /* Increment temporary input pointer */
-        pInT1 = pInT1 + l;
-
-        /* Decrement loop counter */
-        k--;
-
-        /* Increment pivot index */
-        i++;
-      }
-
-      /* Increment the input pointer */
-      pIn++;
-
-      /* Decrement the loop counter */
-      loopCnt--;
-
-      /* Increment the index modifier */
-      l++;
-    }
-
-
-#else
-
-  float64_t Xchg, in = 0.0;                     /* Temporary input values  */
-  uint32_t i, rowCnt, flag = 0U, j, loopCnt, l;      /* loop counters */
-  arm_status status;                             /* status of matrix inverse */
-
-#ifdef ARM_MATH_MATRIX_CHECK
-
-  /* Check for matrix mismatch condition */
-  if ((pSrc->numRows != pSrc->numCols) ||
-      (pDst->numRows != pDst->numCols) ||
-      (pSrc->numRows != pDst->numRows)   )
-  {
-    /* Set status as ARM_MATH_SIZE_MISMATCH */
-    status = ARM_MATH_SIZE_MISMATCH;
-  }
-  else
-
-#endif /* #ifdef ARM_MATH_MATRIX_CHECK */
-
-  {
-
-    /*--------------------------------------------------------------------------------------------------------------
-     * Matrix Inverse can be solved using elementary row operations.
-     *
-     *  Gauss-Jordan Method:
-     *
-     *      1. First combine the identity matrix and the input matrix separated by a bar to form an
-     *        augmented matrix as follows:
-     *                      _  _          _     _      _   _         _         _
-     *                     |  |  a11  a12  | | | 1   0  |   |       |  X11 X12  |
-     *                     |  |            | | |        |   |   =   |           |
-     *                     |_ |_ a21  a22 _| | |_0   1 _|  _|       |_ X21 X21 _|
-     *
-     *      2. In our implementation, pDst Matrix is used as identity matrix.
-     *
-     *      3. Begin with the first row. Let i = 1.
-     *
-     *      4. Check to see if the pivot for row i is zero.
-     *         The pivot is the element of the main diagonal that is on the current row.
-     *         For instance, if working with row i, then the pivot element is aii.
-     *         If the pivot is zero, exchange that row with a row below it that does not
-     *         contain a zero in column i. If this is not possible, then an inverse
-     *         to that matrix does not exist.
-     *
-     *      5. Divide every element of row i by the pivot.
-     *
-     *      6. For every row below and  row i, replace that row with the sum of that row and
-     *         a multiple of row i so that each new element in column i below row i is zero.
-     *
-     *      7. Move to the next row and column and repeat steps 2 through 5 until you have zeros
-     *         for every element below and above the main diagonal.
-     *
-     *      8. Now an identical matrix is formed to the left of the bar(input matrix, src).
-     *         Therefore, the matrix to the right of the bar is our solution(dst matrix, dst).
-     *----------------------------------------------------------------------------------------------------------------*/
-
-    /* Working pointer for destination matrix */
-    pOutT1 = pOut;
-
-    /* Loop over the number of rows */
-    rowCnt = numRows;
-
-    /* Making the destination matrix as identity matrix */
-    while (rowCnt > 0U)
-    {
-      /* Writing all zeroes in lower triangle of the destination matrix */
-      j = numRows - rowCnt;
-      while (j > 0U)
-      {
-        *pOutT1++ = 0.0;
-        j--;
       }
 
-      /* Writing all ones in the diagonal of the destination matrix */
-      *pOutT1++ = 1.0;
-
-      /* Writing all zeroes in upper triangle of the destination matrix */
-      j = rowCnt - 1U;
-      while (j > 0U)
-      {
-        *pOutT1++ = 0.0;
-        j--;
-      }
-
-      /* Decrement loop counter */
-      rowCnt--;
-    }
-
-    /* Loop over the number of columns of the input matrix.
-       All the elements in each column are processed by the row operations */
-    loopCnt = numCols;
-
-    /* Index modifier to navigate through the columns */
-    l = 0U;
-
-    while (loopCnt > 0U)
-    {
-      /* Check if the pivot element is zero..
-       * If it is zero then interchange the row with non zero row below.
-       * If there is no non zero element to replace in the rows below,
-       * then the matrix is Singular. */
-
-      /* Working pointer for the input matrix that points
-       * to the pivot element of the particular row  */
-      pInT1 = pIn + (l * numCols);
-
-      /* Working pointer for the destination matrix that points
-       * to the pivot element of the particular row  */
-      pOutT1 = pOut + (l * numCols);
-
-      /* Temporary variable to hold the pivot value */
-      in = *pInT1;
-
-      /* Check if the pivot element is zero */
-      if (*pInT1 == 0.0)
-      {
-        /* Loop over the number rows present below */
-        for (i = 1U; i < numRows-l; i++)
-        {
-          /* Update the input and destination pointers */
-          pInT2 = pInT1 + (numCols * i);
-          pOutT2 = pOutT1 + (numCols * i);
-
           /* Check if there is a non zero pivot element to
            * replace in the rows below */
-          if (*pInT2 != 0.0)
-          {
+      if ((pivot != 0.0) && (selectedRow != column))
+      {
             /* Loop over number of columns
              * to the right of the pilot element */
-            for (j = 0U; j < (numCols - l); j++)
-            {
-              /* Exchange the row elements of the input matrix */
-              Xchg = *pInT2;
-              *pInT2++ = *pInT1;
-              *pInT1++ = Xchg;
-            }
-
-            for (j = 0U; j < numCols; j++)
-            {
-              Xchg = *pOutT2;
-              *pOutT2++ = *pOutT1;
-              *pOutT1++ = Xchg;
-            }
 
+            SWAP_ROWS_F64(pSrc,column, pivotRow,selectedRow);
+            SWAP_ROWS_F64(pDst,0, pivotRow,selectedRow);
+
+    
             /* Flag to indicate whether exchange is done or not */
             flag = 1U;
 
-            /* Break after exchange is done */
-            break;
-          }
-        }
       }
 
 
       /* Update the status if the matrix is singular */
-      if ((flag != 1U) && (in == 0.0))
+      if ((flag != 1U) && (pivot == 0.0))
       {
         return ARM_MATH_SINGULAR;
       }
 
-      /* Points to the pivot row of input and destination matrices */
-      pPivotRowIn = pIn + (l * numCols);
-      pPivotRowDst = pOut + (l * numCols);
-
-      /* Temporary pointers to the pivot row pointers */
-      pInT1 = pPivotRowIn;
-      pOutT1 = pPivotRowDst;
-
+     
       /* Pivot element of the row */
-      in = *(pIn + (l * numCols));
+      pivot = 1.0 / pivot;
 
-      /* Loop over number of columns
-       * to the right of the pilot element */
-      for (j = 0U; j < (numCols - l); j++)
-      {
-        /* Divide each element of the row of the input matrix
-         * by the pivot element */
-        *pInT1 = *pInT1 / in;
-        pInT1++;
-      }
-      for (j = 0U; j < numCols; j++)
-      {
-        /* Divide each element of the row of the destination matrix
-         * by the pivot element */
-        *pOutT1 = *pOutT1 / in;
-        pOutT1++;
-      }
+      SCALE_ROW_F64(pSrc,column,pivot,pivotRow);
+      SCALE_ROW_F64(pDst,0,pivot,pivotRow);
 
+      
       /* Replace the rows with the sum of that row and a multiple of row i
        * so that each new element in column i above row i is zero.*/
 
-      /* Temporary pointers for input and destination matrices */
-      pInT1 = pIn;
-      pOutT1 = pOut;
-
-      for (i = 0U; i < numRows; i++)
+      rowNb = 0;
+      for (;rowNb < pivotRow; rowNb++)
       {
-        /* Check for the pivot element */
-        if (i == l)
-        {
-          /* If the processing element is the pivot element,
-             only the columns to the right are to be processed */
-          pInT1 += numCols - l;
-          pOutT1 += numCols;
-        }
-        else
-        {
-          /* Element of the reference row */
-          in = *pInT1;
-
-          /* Working pointers for input and destination pivot rows */
-          pPRT_in = pPivotRowIn;
-          pPRT_pDst = pPivotRowDst;
-
-          /* Loop over the number of columns to the right of the pivot element,
-             to replace the elements in the input matrix */
-          for (j = 0U; j < (numCols - l); j++)
-          {
-            /* Replace the element by the sum of that row
-               and a multiple of the reference row  */
-            *pInT1 = *pInT1 - (in * *pPRT_in++);
-            pInT1++;
-          }
+           pTmp = ELEM(pSrc,rowNb,column) ;
+           pivot = *pTmp;
 
-          /* Loop over the number of columns to
-             replace the elements in the destination matrix */
-          for (j = 0U; j < numCols; j++)
-          {
-            /* Replace the element by the sum of that row
-               and a multiple of the reference row  */
-            *pOutT1 = *pOutT1 - (in * *pPRT_pDst++);
-            pOutT1++;
-          }
+           MAS_ROW_F64(column,pSrc,rowNb,pivot,pSrc,pivotRow);
+           MAS_ROW_F64(0     ,pDst,rowNb,pivot,pDst,pivotRow);
 
-        }
 
-        /* Increment temporary input pointer */
-        pInT1 = pInT1 + l;
       }
 
-      /* Increment the input pointer */
-      pIn++;
+      for (rowNb = pivotRow + 1; rowNb < numRows; rowNb++)
+      {
+           pTmp = ELEM(pSrc,rowNb,column) ;
+           pivot = *pTmp;
 
-      /* Decrement the loop counter */
-      loopCnt--;
+           MAS_ROW_F64(column,pSrc,rowNb,pivot,pSrc,pivotRow);
+           MAS_ROW_F64(0     ,pDst,rowNb,pivot,pDst,pivotRow);
 
-      /* Increment the index modifier */
-      l++;
-    }
+      }
 
-#endif /* #if defined (ARM_MATH_DSP) */
+    }
 
     /* Set status as ARM_MATH_SUCCESS */
     status = ARM_MATH_SUCCESS;
 
-    if ((flag != 1U) && (in == 0.0))
+    if ((flag != 1U) && (pivot == 0.0))
     {
       pIn = pSrc->pData;
       for (i = 0; i < numRows * numCols; i++)
@@ -638,7 +258,6 @@ arm_status arm_mat_inverse_f64(
   /* Return to application */
   return (status);
 }
-
 /**
   @} end of MatrixInv group
  */

+ 5 - 56
Source/MatrixFunctions/arm_mat_ldlt_f32.c

@@ -27,44 +27,12 @@
  */
 
 #include "dsp/matrix_functions.h"
-
-
+#include "dsp/matrix_utils.h"
 
 
 
 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
 
-
-/// @private
-#define SWAP_ROWS_F32(A,i,j)                 \
-  {                                      \
-    int cnt = n;                         \
-                                         \
-    for(int w=0;w < n; w+=4)             \
-    {                                    \
-       f32x4_t tmpa,tmpb;                \
-       mve_pred16_t p0 = vctp32q(cnt);   \
-                                         \
-       tmpa=vldrwq_z_f32(&A[i*n + w],p0);\
-       tmpb=vldrwq_z_f32(&A[j*n + w],p0);\
-                                         \
-       vstrwq_p(&A[i*n + w], tmpb, p0);  \
-       vstrwq_p(&A[j*n + w], tmpa, p0);  \
-                                         \
-       cnt -= 4;                         \
-    }                                    \
-  }
-
-/// @private
-#define SWAP_COLS_F32(A,i,j)     \
-  for(int w=0;w < n; w++)    \
-  {                          \
-     float32_t tmp;          \
-     tmp = A[w*n + i];       \
-     A[w*n + i] = A[w*n + j];\
-     A[w*n + j] = tmp;       \
-  }
-
 /**
   @ingroup groupMatrix
  */
@@ -157,8 +125,8 @@ arm_status arm_mat_ldlt_f32(
 
         if(j != k)
         {
-          SWAP_ROWS_F32(pA,k,j);
-          SWAP_COLS_F32(pA,k,j);
+          SWAP_ROWS_F32(pl,0,k,j);
+          SWAP_COLS_F32(pl,0,k,j);
         }
 
 
@@ -323,25 +291,6 @@ arm_status arm_mat_ldlt_f32(
 }
 #else
 
-/// @private
-#define SWAP_ROWS_F32(A,i,j)     \
-  for(w=0;w < n; w++)    \
-  {                          \
-     float32_t tmp;          \
-     tmp = A[i*n + w];       \
-     A[i*n + w] = A[j*n + w];\
-     A[j*n + w] = tmp;       \
-  }
-
-/// @private
-#define SWAP_COLS_F32(A,i,j)     \
-  for(w=0;w < n; w++)    \
-  {                          \
-     float32_t tmp;          \
-     tmp = A[w*n + i];       \
-     A[w*n + i] = A[w*n + j];\
-     A[w*n + j] = tmp;       \
-  }
 
 /**
   @ingroup groupMatrix
@@ -429,8 +378,8 @@ arm_status arm_mat_ldlt_f32(
 
         if(j != k)
         {
-          SWAP_ROWS_F32(pA,k,j);
-          SWAP_COLS_F32(pA,k,j);
+          SWAP_ROWS_F32(pl,0,k,j);
+          SWAP_COLS_F32(pl,0,k,j);
         }
 
 

+ 4 - 29
Source/MatrixFunctions/arm_mat_ldlt_f64.c

@@ -27,35 +27,10 @@
  */
 
 #include "dsp/matrix_functions.h"
-#include <math.h>
-
-
+#include "dsp/matrix_utils.h"
 
-/// @private
-#define SWAP_ROWS_F64(A,i,j) \
-{                            \
-  int w;                     \
-  for(w=0;w < n; w++)        \
-  {                          \
-     float64_t tmp;          \
-     tmp = A[i*n + w];       \
-     A[i*n + w] = A[j*n + w];\
-     A[j*n + w] = tmp;       \
-  }                          \
-}
+#include <math.h>
 
-/// @private
-#define SWAP_COLS_F64(A,i,j) \
-{                            \
-  int w;                     \
-  for(w=0;w < n; w++)        \
-  {                          \
-     float64_t tmp;          \
-     tmp = A[w*n + i];       \
-     A[w*n + i] = A[w*n + j];\
-     A[w*n + j] = tmp;       \
-  }                          \
-}
 
 /**
   @ingroup groupMatrix
@@ -141,8 +116,8 @@ arm_status arm_mat_ldlt_f64(
 
         if(j != k)
         {
-          SWAP_ROWS_F64(pA,k,j);
-          SWAP_COLS_F64(pA,k,j);
+          SWAP_ROWS_F64(pl,0,k,j);
+          SWAP_COLS_F64(pl,0,k,j);
         }
 
 

+ 1 - 1
Source/MatrixFunctions/arm_mat_vec_mult_f32.c

@@ -165,7 +165,7 @@ void arm_mat_vec_mult_f32(
     }
 
     /*
-     * compute 2 rows in parrallel
+     * compute 2 rows in parallel
      */
     if (row >= 2)
     {

+ 3 - 3
Testing/Source/Tests/UnaryTestsF16.cpp

@@ -22,9 +22,9 @@ Comparisons for inverse
 /* Not very accurate for big matrix.
 But big matrix needed for checking the vectorized code */
 
-#define SNR_THRESHOLD_INV 45
-#define REL_ERROR_INV (3.0e-2)
-#define ABS_ERROR_INV (3.0e-2)
+#define SNR_THRESHOLD_INV 52
+#define REL_ERROR_INV (3.0e-3)
+#define ABS_ERROR_INV (2.0e-2)
 
 #define REL_ERROR_SOLVE (6.0e-2)
 #define ABS_ERROR_SOLVE (2.0e-2)

+ 3 - 3
Testing/Source/Tests/UnaryTestsF32.cpp

@@ -22,9 +22,9 @@ Comparisons for inverse
 /* Not very accurate for big matrix.
 But big matrix needed for checking the vectorized code */
 
-#define SNR_THRESHOLD_INV 67
-#define REL_ERROR_INV (1.0e-3)
-#define ABS_ERROR_INV (1.0e-3)
+#define SNR_THRESHOLD_INV 100
+#define REL_ERROR_INV (3.0e-5)
+#define ABS_ERROR_INV (1.0e-5)
 
 /*
 

この差分においてかなりの量のファイルが変更されているため、一部のファイルを表示していません