matrix_test.cpp 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863
  1. extern "C" {
  2. extern void matrix_test();
  3. }
  4. #include "allocator.h"
  5. #include <dsppp/arch.hpp>
  6. #include <dsppp/fixed_point.hpp>
  7. #include <dsppp/matrix.hpp>
  8. #include <dsppp/unroll.hpp>
  9. #include <iostream>
  10. #include <cmsis_tests.h>
  11. #include "boost/mp11.hpp"
  12. using namespace boost::mp11;
  13. extern "C" {
  14. #include "dsp/matrix_functions.h"
  15. #include "dsp/matrix_utils.h"
  16. }
  17. template<typename T>
  18. struct MatTestConstant;
  19. template<>
  20. struct MatTestConstant<double>
  21. {
  22. constexpr static double value = 0.001;
  23. constexpr static double half = 0.5;
  24. };
  25. template<>
  26. struct MatTestConstant<float32_t>
  27. {
  28. constexpr static float value = 0.001f;
  29. constexpr static float half = 0.5f;
  30. };
  31. #if !defined(DISABLEFLOAT16)
  32. template<>
  33. struct MatTestConstant<float16_t>
  34. {
  35. constexpr static float16_t value = (float16_t)0.001f;
  36. constexpr static float16_t half = (float16_t)0.5f;
  37. };
  38. #endif
  39. template<>
  40. struct MatTestConstant<Q7>
  41. {
  42. constexpr static Q7 value = 0.001_q7;
  43. constexpr static Q7 half = 0.5_q7;
  44. };
  45. template<>
  46. struct MatTestConstant<Q15>
  47. {
  48. constexpr static Q15 value = 0.001_q15;
  49. constexpr static Q15 half = 0.5_q15;
  50. };
  51. template<>
  52. struct MatTestConstant<Q31>
  53. {
  54. constexpr static Q31 value = 0.001_q31;
  55. constexpr static Q31 half = 0.5_q31;
  56. };
  57. template<typename T>
  58. struct ErrThreshold
  59. {
  60. constexpr static float abserr = 0;
  61. constexpr static float relerr = 0;
  62. constexpr static float abserr_cholesky = 0;
  63. constexpr static float relerr_cholesky = 0;
  64. constexpr static float abserr_householder = 0;
  65. constexpr static float relerr_householder = 0;
  66. constexpr static float abserr_qr = 0;
  67. constexpr static float relerr_qr = 0;
  68. constexpr static float abserr_inv = 0;
  69. constexpr static float relerr_inv = 0;
  70. };
  71. // Should be more accurate than F32 but right know
  72. // we only check there is no regression compared to f32
  73. template<>
  74. struct ErrThreshold<double>
  75. {
  76. constexpr static float abserr = ABS_ERROR;
  77. constexpr static float relerr = REL_ERROR;
  78. constexpr static float abserr_cholesky = 3e-4;
  79. constexpr static float relerr_cholesky = 1e-4;
  80. constexpr static float abserr_householder = ABS_ERROR;
  81. constexpr static float relerr_householder = REL_ERROR;
  82. constexpr static float abserr_qr = ABS_ERROR;
  83. constexpr static float relerr_qr = REL_ERROR;
  84. constexpr static float abserr_inv = ABS_ERROR;
  85. constexpr static float relerr_inv = REL_ERROR;
  86. };
  87. template<>
  88. struct ErrThreshold<float32_t>
  89. {
  90. constexpr static float abserr = ABS_ERROR;
  91. constexpr static float relerr = REL_ERROR;
  92. constexpr static float abserr_cholesky = 3e-4;
  93. constexpr static float relerr_cholesky = 1e-4;
  94. constexpr static float abserr_householder = ABS_ERROR;
  95. constexpr static float relerr_householder = REL_ERROR;
  96. constexpr static float abserr_qr = ABS_ERROR;
  97. constexpr static float relerr_qr = REL_ERROR;
  98. constexpr static float abserr_inv = 4.0e-6;
  99. constexpr static float relerr_inv = 5.0e-6;
  100. };
  101. #if !defined(DISABLEFLOAT16)
  102. template<>
  103. struct ErrThreshold<float16_t>
  104. {
  105. constexpr static float abserr = ABS_ERROR;
  106. constexpr static float relerr = REL_ERROR;
  107. constexpr static float abserr_cholesky = 2e-1;
  108. constexpr static float relerr_cholesky = 2e-1;
  109. constexpr static float abserr_householder = 2e-4;
  110. constexpr static float relerr_householder = 2e-3;
  111. // 32x32 is not numerically behaving well with
  112. // the matrix used as input
  113. constexpr static float abserr_qr = 2.0;
  114. constexpr static float relerr_qr = 1e-2;
  115. constexpr static float abserr_inv = 3e-2;
  116. constexpr static float relerr_inv = 3e-2;
  117. };
  118. #endif
  119. void cmsisdsp_mat_inv(float64_t *amod,
  120. float64_t* b,
  121. uint32_t r,uint32_t c)
  122. {
  123. arm_matrix_instance_f64 src;
  124. arm_matrix_instance_f64 dst;
  125. src.numRows = r;
  126. src.numCols = c;
  127. src.pData = amod;
  128. dst.numRows = r;
  129. dst.numCols = c;
  130. dst.pData = b;
  131. arm_status status = arm_mat_inverse_f64(&src,&dst);
  132. (void)status;
  133. };
  134. void cmsisdsp_mat_inv(float32_t *amod,
  135. float32_t* b,
  136. uint32_t r,uint32_t c)
  137. {
  138. arm_matrix_instance_f32 src;
  139. arm_matrix_instance_f32 dst;
  140. src.numRows = r;
  141. src.numCols = c;
  142. src.pData = amod;
  143. dst.numRows = r;
  144. dst.numCols = c;
  145. dst.pData = b;
  146. arm_status status = arm_mat_inverse_f32(&src,&dst);
  147. (void)status;
  148. };
  149. #if !defined(DISABLEFLOAT16)
  150. void cmsisdsp_mat_inv(float16_t *amod,
  151. float16_t* b,
  152. uint32_t r,uint32_t c)
  153. {
  154. arm_matrix_instance_f16 src;
  155. arm_matrix_instance_f16 dst;
  156. src.numRows = r;
  157. src.numCols = c;
  158. src.pData = amod;
  159. dst.numRows = r;
  160. dst.numCols = c;
  161. dst.pData = b;
  162. arm_status status = arm_mat_inverse_f16(&src,&dst);
  163. (void)status;
  164. };
  165. #endif
// Reference input matrices consumed by init_mat(). Values are stored as
// float32_t in the linear (row-major) order init_mat() copies them in, and
// are converted to the element type under test at copy time.

// 8 x 8 input (64 values); selected by init_mat() for 8 x 8 matrices.
const float32_t mat64[64] = {0.395744, 0.623798, 0.885422, 0.95415, 0.310384, 0.257541,
0.631426, 0.424491, 0.130945, 0.799959, 0.133693, 0.479455,
0.519254, 0.381039, 0.617455, 0.748273, 0.146944, 0.928945,
0.430936, 0.508207, 0.829023, 0.358027, 0.999501, 0.851953,
0.273895, 0.685898, 0.0436612, 0.295212, 0.467651, 0.0515567,
0.21037, 0.607475, 0.570295, 0.281109, 0.979219, 0.0947969,
0.319016, 0.398405, 0.349953, 0.710002, 0.431597, 0.447659,
0.0747669, 0.057063, 0.165648, 0.773106, 0.135765, 0.709327,
0.873836, 0.292361, 0.00202529, 0.392942, 0.520183, 0.0528055,
0.797982, 0.613497, 0.509682, 0.0435791, 0.780526, 0.960582,
0.535914, 0.216113, 0.134108, 0.225859};
// 4 x 4 input (16 values, a symmetric matrix); selected by init_mat() for
// 4 x 4 matrices.
const float32_t mat16[16] = {1.0, 2.0, 3.0, 4.0, 2.0, 4.0, 5.0, 6.0,
3.0, 5.0, 9.0, 10.0, 4.0, 6.0, 10.0, 16.0};
// 16 x 16 input (256 values); selected by init_mat() for 16 x 16 matrices.
// The trailing backslashes are plain line splices and have no effect here.
const float32_t mat256[256] = {0.97936, 0.498105, 0.452618, 0.299761, 0.688624, 0.247212, \
0.228337, 0.22905, 0.563815, 0.251998, 0.5238, 0.141223, 0.0980689, \
0.79112, 0.771182, 0.890995, 0.0256181, 0.0377277, 0.575629, \
0.648138, 0.926218, 0.803878, 0.620333, 0.325635, 0.587355, 0.041795, \
0.934271, 0.0690131, 0.0240136, 0.800828, 0.522999, 0.374706, \
0.266977, 0.208028, 0.112878, 0.0389899, 0.658311, 0.205067, \
0.244172, 0.0762778, 0.190575, 0.677312, 0.0682093, 0.367328, \
0.0191464, 0.988968, 0.437477, 0.130622, 0.907823, 0.0116559, \
0.614526, 0.447443, 0.0126975, 0.995496, 0.947676, 0.659996, \
0.321547, 0.725415, 0.658426, 0.0243924, 0.0843519, 0.351748, \
0.974332, 0.673381, 0.375012, 0.719626, 0.721219, 0.766905, \
0.17065, 0.648905, 0.770983, 0.360008, 0.344226, 0.179633, 0.347905, \
0.555561, 0.742615, 0.908389, 0.806959, 0.176078, 0.872167, \
0.321839, 0.098607, 0.954515, 0.627286, 0.235082, 0.746179, 0.163606, \
0.899323, 0.871471, 0.712448, 0.956971, 0.736687, 0.750702, 0.843348, \
0.302435, 0.444862, 0.0644597, 0.765519, 0.518397, 0.765541, \
0.900375, 0.201853, 0.490325, 0.721786, 0.893647, 0.774724, \
0.0983631, 0.339887, 0.526084, 0.0786152, 0.515697, 0.438801, \
0.226628, 0.125093, 0.886642, 0.617766, 0.71696, 0.473172, 0.640949, \
0.67688, 0.676214, 0.453662, 0.345796, 0.608999, 0.904448, 0.0965741, \
0.00461771, 0.467399, 0.292235, 0.0418646, 0.116632, 0.0766192, \
0.269051, 0.411649, 0.0538381, 0.973959, 0.667106, 0.301662, \
0.977206, 0.891751, 0.420267, 0.441334, 0.0896179, 0.249969, \
0.672614, 0.623966, 0.609733, 0.320772, 0.39723, 0.845196, 0.653877, \
0.0599186, 0.340188, 0.199787, 0.598104, 0.45664, 0.920485, 0.969439, \
0.446555, 0.0932837, 0.0247635, 0.747644, 0.438759, 0.639154, \
0.754049, 0.379433, 0.968655, 0.0452146, 0.208123, 0.252654, \
0.261898, 0.608665, 0.145211, 0.395368, 0.799111, 0.697823, \
0.382906, 0.456515, 0.262579, 0.284169, 0.881488, 0.860877, 0.155548, \
0.537387, 0.804235, 0.311383, 0.183216, 0.677692, 0.829542, 0.406049, \
0.860392, 0.467668, 0.385633, 0.654692, 0.841125, 0.178406, \
0.668945, 0.369609, 0.809711, 0.454593, 0.632028, 0.605791, 0.643851, \
0.787023, 0.285633, 0.832216, 0.30892, 0.303559, 0.704898, 0.61118, \
0.435547, 0.173678, 0.788689, 0.319511, 0.648378, 0.635417, 0.125127, \
0.310251, 0.800819, 0.4863, 0.924361, 0.308059, 0.952175, 0.449844, \
0.215496, 0.257826, 0.556383, 0.259735, 0.197234, 0.0509903, 0.21474, \
0.145085, 0.41288, 0.876758, 0.096721, 0.228955, 0.0152248, 0.126501, \
0.28899, 0.336668, 0.580015, 0.932761, 0.989783, 0.667379, \
0.798751, 0.587173, 0.445902, 0.041448, 0.311878, 0.0332857, \
0.401984, 0.795049, 0.8222, 0.678648, 0.807558};
  219. template<typename T,int R,int C, template<int> typename A>
  220. void init_mat(Matrix<T,R,C,A> &pDst,std::size_t r,std::size_t c)
  221. {
  222. const float32_t *p;
  223. if ((r==4) && (r==c))
  224. {
  225. p = mat16;
  226. }
  227. if ((r==8) && (r==c))
  228. {
  229. p = mat64;
  230. }
  231. if ((r==16) && (r==c))
  232. {
  233. p = mat256;
  234. }
  235. for(std::size_t i=0;i<r*c;i++)
  236. {
  237. pDst[i] = T(p[i]);
  238. }
  239. }
  240. #if !defined(DISABLEFLOAT16)
  241. __STATIC_FORCEINLINE float16_t _abs(float16_t x)
  242. {
  243. return((float16_t)fabsf((float)x));
  244. };
  245. #endif
  246. __STATIC_FORCEINLINE float32_t _abs(float32_t x)
  247. {
  248. return(fabsf(x));
  249. };
  250. __STATIC_FORCEINLINE double _abs(double x)
  251. {
  252. return(fabs(x));
  253. };
  254. template<typename T,
  255. int NB,
  256. template<int> typename A,
  257. typename M>
  258. void _matinv(const Matrix<T,NB,NB,A> &a,M && res)
  259. {
  260. Matrix<T,NB,NB,TMP_ALLOC> b = a;
  261. const vector_length_t nb_rows = a.rows();
  262. const vector_length_t nb_cols = a.columns();
  263. for(index_t r=0;r < nb_rows ; r++)
  264. {
  265. res.row(r) = T{};
  266. res(r,r) = number_traits<T>::one();
  267. }
  268. for(index_t c=0;c < nb_cols ; c++)
  269. {
  270. T pivot = b(c,c);
  271. index_t selectedRow = c;
  272. for(index_t r=c+1;r < nb_rows ; r++)
  273. {
  274. T newPivot = b(r,c);
  275. if (_abs(newPivot)>_abs(pivot))
  276. {
  277. pivot = newPivot;
  278. selectedRow = r;
  279. }
  280. }
  281. if ((pivot!=T{}) && (selectedRow != c))
  282. {
  283. swap(b.row(c,c),b.row(selectedRow,c));
  284. swap(res.row(c),res.row(selectedRow));
  285. }
  286. else if (pivot == T{})
  287. {
  288. break;
  289. }
  290. pivot = number_traits<T>::one() / pivot;
  291. b.row(c,c) *= pivot;
  292. res.row(c) *= pivot;
  293. index_t r=0;
  294. for(;r < c ; r++)
  295. {
  296. const T tmp = b(r,c);
  297. b.row(r,c) -= b.row(c,c)*tmp;
  298. res.row(r) -= res.row(c)*tmp;
  299. }
  300. for(r=c+1;r < nb_rows ; r++)
  301. {
  302. const T tmp = b(r,c);
  303. b.row(r,c) -= b.row(c,c)*tmp;
  304. res.row(r) -= res.row(c)*tmp;
  305. }
  306. }
  307. }
  308. template<typename T,
  309. int NB,
  310. template<int> typename A,
  311. typename std::enable_if<(NB>0),bool>::type = true>
  312. Matrix<T,NB,NB,TMP_ALLOC> matinv(const Matrix<T,NB,NB,A> &a)
  313. {
  314. Matrix<T,NB,NB,TMP_ALLOC> res;
  315. _matinv(a,res);
  316. return(res);
  317. }
  318. template<typename T,
  319. int NB,
  320. template<int> typename A,
  321. typename std::enable_if<(NB<0),bool>::type = true>
  322. Matrix<T,DYNAMIC,DYNAMIC,TMP_ALLOC> matinv(const Matrix<T,NB,NB,A> &a)
  323. {
  324. Matrix<T,DYNAMIC,DYNAMIC,TMP_ALLOC> res(a.rows(),a.columns());
  325. return (_matinv(a,res));
  326. return(res);
  327. }
  328. template<typename T,
  329. int NB,
  330. template<int> typename A,
  331. typename std::enable_if<(NB<0),bool>::type = true>
  332. void matinv(Matrix<T,DYNAMIC,DYNAMIC,TMP_ALLOC> &res, const Matrix<T,NB,NB,A> &a)
  333. {
  334. (void)_matinv(a,res);
  335. }
/// @brief Benchmark and validate matinv() against the CMSIS-DSP inverse.
///
/// Fills an R x C matrix with init_mat() and inverts it with matinv(). The
/// inversion is wrapped in cycle measurement and profiling section 1. The
/// result is compared with cmsisdsp_mat_inv() using the
/// ErrThreshold<T>::abserr_inv / relerr_inv tolerances, and "inv failed" is
/// printed on mismatch. Statement order matters: the measurement macros
/// bracket exactly the statements being timed.
template<typename T,int R,int C>
void testinv()
{
    std::cout << "----\r\n";
    std::cout << R << " x " << C << "\r\n";
#if defined(STATIC_TEST)
    PMat<T,R,C> a;
#else
    PMat<T> a(R,C);
#endif
    init_mat(a,R,C);
#if !defined(STATIC_TEST)
    // Dynamic builds allocate the result before the timed region.
    PMat<T> res(R,C);
#endif
    INIT_SYSTICK;
    START_CYCLE_MEASUREMENT;
    startSectionNB(1);
#if defined(STATIC_TEST)
    PMat<T,R,C> res = matinv(a);
#else
    matinv(res,a);
#endif
    stopSectionNB(1);
    STOP_CYCLE_MEASUREMENT;
    // The CMSIS reference gets its own copy of the input (the wrapper's
    // parameter is named `amod`, suggesting it may be overwritten).
    PMat<T> amod(a);
    PMat<T> cmsis_res(R,C);
    INIT_SYSTICK;
    START_CYCLE_MEASUREMENT;
    cmsisdsp_mat_inv(amod.ptr(),cmsis_res.ptr(),R,C);
    STOP_CYCLE_MEASUREMENT;
    if (!validate(res.const_ptr(),cmsis_res.const_ptr(),R*C,
                  ErrThreshold<T>::abserr_inv,ErrThreshold<T>::relerr_inv))
    {
        printf("inv failed \r\n");
    }
    std::cout << "=====\r\n";
}
/// @brief Benchmark and validate element-wise matrix addition.
///
/// Computes `a + b` with the dsppp expression API, timed as profiling
/// section 1. The result is compared with cmsisdsp_mat_add() using the
/// default ErrThreshold<T> tolerances, and "add failed" is printed on
/// mismatch. Statement order matters: the measurement macros bracket
/// exactly the statements being timed.
template<typename T,int R,int C>
void testadd()
{
    std::cout << "----\r\n";
    std::cout << R << " x " << C << "\r\n";
#if defined(STATIC_TEST)
    PMat<T,R,C> a;
    PMat<T,R,C> b;
#else
    PMat<T> a(R,C);
    PMat<T> b(R,C);
#endif
    init_array(a,R*C);
    init_array(b,R*C);
    INIT_SYSTICK;
    START_CYCLE_MEASUREMENT;
    startSectionNB(1);
#if defined(STATIC_TEST)
    PMat<T,R,C> res = a+b;
#else
    PMat<T> res = a+b;
#endif
    stopSectionNB(1);
    STOP_CYCLE_MEASUREMENT;
    // Disabled debugging aids for inspecting the expression-template types:
    //PrintType<decltype(a+b)>();
    //PrintType<decltype(res)>();
    //
    //std::cout << "a: " << IsVector<decltype(a)>::value << "\r\n";
    //std::cout << "b: " << IsVector<decltype(b)>::value << "\r\n";
    //std::cout << "a+b: " << IsVector<decltype(a+b)>::value << "\r\n";
    //std::cout << "res: " << IsVector<decltype(res)>::value << "\r\n";
    //std::cout << "same: " << SameElementType<decltype(res),decltype(a+b)>::value << "\r\n";
    //
    //std::cout << "vec inst: " << has_vector_inst<decltype(res)>() << "\r\n";
    //std::cout << "vec index pair: " << vector_idx_pair<decltype(res),decltype(a+b)>() << "\r\n";
    //std::cout << "must use mat idx: " << must_use_matrix_idx_pair<decltype(res),decltype(a+b)>() << "\r\n";
    INIT_SYSTICK;
    START_CYCLE_MEASUREMENT;
    // Note: the reference result buffer is allocated inside this measured region.
#if defined(STATIC_TEST)
    PMat<T,R,C> cmsis_res;
#else
    PMat<T> cmsis_res(R,C);
#endif
    cmsisdsp_mat_add(a.const_ptr(),b.const_ptr(),cmsis_res.ptr(),R,C);
    STOP_CYCLE_MEASUREMENT;
    if (!validate(res.const_ptr(),cmsis_res.const_ptr(),R*C,
                  ErrThreshold<T>::abserr,ErrThreshold<T>::relerr))
    {
        printf("add failed \r\n");
    }
    std::cout << "=====\r\n";
}
/// @brief Benchmark and validate building a diagonal matrix from a vector.
///
/// Builds `PMat::diagonal(a)` from an R-element vector, timed as profiling
/// section 1. The result is compared with an R x C reference filled by hand
/// (row-major: a[row] on the diagonal, zero elsewhere), and "diag failed" is
/// printed on mismatch. Statement order matters: the measurement macros
/// bracket exactly the statements being timed.
template<typename T,int R,int C>
void testdiag()
{
    std::cout << "----\r\n";
    std::cout << R << " x " << C << "\r\n";
#if defined(STATIC_TEST)
    PVector<T,R> a;
#else
    PVector<T> a(R);
#endif
    init_array(a,R);
    INIT_SYSTICK;
    START_CYCLE_MEASUREMENT;
    startSectionNB(1);
#if defined(STATIC_TEST)
    PMat<T,R,C> res=PMat<T,R,C>::diagonal(a);
#else
    PMat<T> res=PMat<T>::diagonal(a);
#endif
    stopSectionNB(1);
    STOP_CYCLE_MEASUREMENT;
    const T* ap = a.const_ptr();
    INIT_SYSTICK;
    START_CYCLE_MEASUREMENT;
#if defined(STATIC_TEST)
    PMat<T,R,C> cmsis_res;
#else
    PMat<T> cmsis_res(R,C);
#endif
    // Hand-written reference: element (row, col) lives at refp[row*C+col].
    T* refp = cmsis_res.ptr();
    UNROLL_LOOP
    for(index_t row=0;row < R; row++)
    {
        UNROLL_LOOP
        for(index_t col=0;col < C; col++)
        {
            if (row != col)
            {
                refp[row*C+col] = T{};
            }
            else
            {
                refp[row*C+col] = ap[row];
            }
        }
    }
    STOP_CYCLE_MEASUREMENT;
    if (!validate(res.const_ptr(),cmsis_res.const_ptr(),R*C,
                  ErrThreshold<T>::abserr,ErrThreshold<T>::relerr))
    {
        printf("diag failed \r\n");
    }
    std::cout << "=====\r\n";
}
  479. template<typename T,int R,int C>
  480. void testouter()
  481. {
  482. std::cout << "----\r\n";
  483. std::cout << R << " x " << C << "\r\n";
  484. PVector<T,R> a;
  485. PVector<T,C> b;
  486. init_array(a,R);
  487. init_array(b,C);
  488. b = b + b;
  489. INIT_SYSTICK;
  490. START_CYCLE_MEASUREMENT;
  491. startSectionNB(1);
  492. PMat<T,R,C> res = outer(a,b);
  493. stopSectionNB(1);
  494. STOP_CYCLE_MEASUREMENT;
  495. INIT_SYSTICK;
  496. START_CYCLE_MEASUREMENT;
  497. startSectionNB(2);
  498. #if defined(STATIC_TEST)
  499. PMat<T,R,C> cmsis_res;
  500. #else
  501. PMat<T> cmsis_res(R,C);
  502. #endif
  503. CMSISOuter<T>::run(a.const_ptr(),b.const_ptr(),cmsis_res.ptr(),R,C);
  504. startSectionNB(2);
  505. STOP_CYCLE_MEASUREMENT;
  506. //std::cout<<cmsis_res;
  507. if (!validate(res.const_ptr(),cmsis_res.const_ptr(),R*C,
  508. ErrThreshold<T>::abserr,ErrThreshold<T>::relerr))
  509. {
  510. printf("outer failed \r\n");
  511. }
  512. std::cout << "=====\r\n";
  513. }
  514. template<typename T,int R,int C>
  515. void testview()
  516. {
  517. std::cout << "----\r\n";
  518. std::cout << R << " x " << C << "\r\n";
  519. #if defined(STATIC_TEST)
  520. PVector<T,R> a;
  521. #else
  522. PVector<T> a(R);
  523. #endif
  524. init_array(a,R);
  525. #if defined(STATIC_TEST)
  526. PMat<T,R,C> res=PMat<T,R,C>::diagonal(a);
  527. #else
  528. PMat<T> res=PMat<T>::diagonal(a);
  529. #endif
  530. //std::cout << res;
  531. constexpr int subsize = 8;
  532. constexpr int subpos = 8;
  533. auto r = res.sub(Slice(subpos,subpos+subsize),Slice(subpos,subpos+subsize));
  534. #if defined(STATIC_TEST)
  535. PMat<T,subsize,subsize> resb;
  536. #else
  537. PMat<T> resb(subsize,subsize);
  538. #endif
  539. INIT_SYSTICK;
  540. START_CYCLE_MEASUREMENT;
  541. startSectionNB(1);
  542. resb = r+r;
  543. stopSectionNB(1);
  544. STOP_CYCLE_MEASUREMENT;
  545. //std::cout << IsMatrix<decltype(r+r)>::value << "\r\n";
  546. PMat<T,subsize,subsize> cmsis_res;
  547. INIT_SYSTICK;
  548. START_CYCLE_MEASUREMENT;
  549. startSectionNB(2);
  550. DISABLE_LOOP_UNROLL
  551. for(index_t row=0;row < subsize ; row++)
  552. {
  553. DISABLE_LOOP_UNROLL
  554. for(index_t col=0;col < subsize ; col++)
  555. {
  556. cmsis_res(row,col) = r(row,col)+r(row,col);
  557. }
  558. }
  559. startSectionNB(2);
  560. STOP_CYCLE_MEASUREMENT;
  561. //std::cout<<resb;
  562. //std::cout<<cmsis_res;
  563. if (!validate(resb,cmsis_res,
  564. ErrThreshold<T>::abserr,ErrThreshold<T>::relerr))
  565. {
  566. printf("sub matrix failed \r\n");
  567. }
  568. std::cout << "=====\r\n";
  569. }
  570. template<typename T,int R,int C>
  571. void testmatvec()
  572. {
  573. using STO = typename vector_traits<T>::storage_type;
  574. std::cout << "----\r\n";
  575. std::cout << R << " x " << C << "\r\n";
  576. #if defined(STATIC_TEST)
  577. PVector<T,C> a;
  578. #else
  579. PVector<T> a(C);
  580. #endif
  581. init_array(a,C);
  582. #if defined(STATIC_TEST)
  583. PMat<T,R,C> m;
  584. #else
  585. PMat<T> m(R,C);
  586. #endif
  587. init_array(m,R*C);
  588. INIT_SYSTICK;
  589. START_CYCLE_MEASUREMENT;
  590. startSectionNB(1);
  591. #if defined(STATIC_TEST)
  592. PVector<T,R> res = dot(m,a);
  593. #else
  594. PVector<T> res = dot(m,a);
  595. #endif
  596. stopSectionNB(1);
  597. STOP_CYCLE_MEASUREMENT;
  598. //std::cout << IsMatrix<decltype(r+r)>::value << "\r\n";
  599. INIT_SYSTICK;
  600. START_CYCLE_MEASUREMENT;
  601. #if defined(STATIC_TEST)
  602. PVector<T,R> cmsis_res;
  603. #else
  604. PVector<T> cmsis_res(R);
  605. #endif
  606. typename CMSISMatrixType<T>::type S;
  607. S.numRows = R;
  608. S.numCols = C;
  609. S.pData = reinterpret_cast<STO*>(const_cast<T*>(m.ptr()));
  610. startSectionNB(2);
  611. cmsis_mat_vec_mult(&S, a.const_ptr(), cmsis_res.ptr());
  612. startSectionNB(2);
  613. STOP_CYCLE_MEASUREMENT;
  614. //std::cout << cmsis_res;
  615. if (!validate(res.const_ptr(),cmsis_res.const_ptr(),R,
  616. ErrThreshold<T>::abserr,ErrThreshold<T>::relerr))
  617. {
  618. printf("matrix times vector failed \r\n");
  619. }
  620. std::cout << "=====\r\n";
  621. }
  622. template<typename T,int R,int C>
  623. void testcomplexmatvec()
  624. {
  625. const T scalar = MatTestConstant<T>::half;
  626. using STO = typename vector_traits<T>::storage_type;
  627. std::cout << "----\r\n";
  628. std::cout << R << " x " << C << "\r\n";
  629. #if defined(STATIC_TEST)
  630. PVector<T,C> a;
  631. PVector<T,C> b;
  632. #else
  633. PVector<T> a(C);
  634. PVector<T> b(C);
  635. #endif
  636. init_array(a,C);
  637. init_array(b,C);
  638. #if defined(STATIC_TEST)
  639. PMat<T,R,C> m;
  640. #else
  641. PMat<T> m(R,C);
  642. #endif
  643. init_array(m,R*C);
  644. INIT_SYSTICK;
  645. START_CYCLE_MEASUREMENT;
  646. startSectionNB(1);
  647. #if defined(STATIC_TEST)
  648. PVector<T,C> tmpv = a + b * scalar;
  649. PVector<T,R> res = dot(m,tmpv);
  650. #else
  651. PVector<T> tmpv = a + b * scalar;
  652. PVector<T> res = dot(m,tmpv);
  653. #endif
  654. stopSectionNB(1);
  655. STOP_CYCLE_MEASUREMENT;
  656. //std::cout << IsMatrix<decltype(r+r)>::value << "\r\n";
  657. INIT_SYSTICK;
  658. START_CYCLE_MEASUREMENT;
  659. #if defined(STATIC_TEST)
  660. PVector<T,R> cmsis_res;
  661. PVector<T,C> tmp;
  662. #else
  663. PVector<T> cmsis_res(R);
  664. PVector<T> tmp(C);
  665. #endif
  666. typename CMSISMatrixType<T>::type S;
  667. S.numRows = R;
  668. S.numCols = C;
  669. S.pData = reinterpret_cast<STO*>(const_cast<T*>(m.ptr()));
  670. startSectionNB(2);
  671. cmsis_complex_mat_vec(&S,
  672. a.const_ptr(),
  673. b.const_ptr(),
  674. scalar,
  675. tmp.ptr(),
  676. cmsis_res.ptr());
  677. startSectionNB(2);
  678. STOP_CYCLE_MEASUREMENT;
  679. //std::cout << cmsis_res;
  680. if (!validate(res.const_ptr(),cmsis_res.const_ptr(),R,
  681. ErrThreshold<T>::abserr,ErrThreshold<T>::relerr))
  682. {
  683. printf("matrix times vector expression failed \r\n");
  684. }
  685. std::cout << "=====\r\n";
  686. }
  687. template<typename T,int R, int K,int C>
  688. void testmatmult()
  689. {
  690. std::cout << "----\r\n";
  691. std::cout << R << " x " << K << " x " << C << "\r\n";
  692. using S = typename CMSISMatrixType<T>::scalar;
  693. #if defined(STATIC_TEST)
  694. PMat<T,R,K> ma;
  695. #else
  696. PMat<T> ma(R,K);
  697. #endif
  698. init_array(ma,R*K);
  699. #if defined(STATIC_TEST)
  700. PMat<T,K,C> mb;
  701. #else
  702. PMat<T> mb(K,C);
  703. #endif
  704. init_array(mb,K*C);
  705. mb += TestConstant<T>::small;
  706. //std::cout << ma;
  707. //std::cout << mb;
  708. INIT_SYSTICK;
  709. START_CYCLE_MEASUREMENT;
  710. startSectionNB(1);
  711. #if defined(STATIC_TEST)
  712. PMat<T,R,C> res = dot(ma,mb);
  713. #else
  714. PMat<T> res = dot(ma,mb);
  715. #endif
  716. stopSectionNB(1);
  717. STOP_CYCLE_MEASUREMENT;
  718. //PrintType<decltype(ma)>();
  719. //PrintType<decltype(mb)>();
  720. //std::cout << ma;
  721. //std::cout << mb;
  722. //std::cout << res;
  723. //std::cout << IsMatrix<decltype(r+r)>::value << "\r\n";
  724. PMat<T> tmp(C,K);
  725. INIT_SYSTICK;
  726. START_CYCLE_MEASUREMENT;
  727. #if defined(STATIC_TEST)
  728. PMat<T,R,C> cmsis_res;
  729. #else
  730. PMat<T> cmsis_res(R,C);
  731. #endif
  732. typename CMSISMatrixType<T>::type SA;
  733. SA.numRows = R;
  734. SA.numCols = K;
  735. SA.pData = reinterpret_cast<S*>(ma.ptr());
  736. typename CMSISMatrixType<T>::type SB;
  737. SB.numRows = K;
  738. SB.numCols = C;
  739. SB.pData = reinterpret_cast<S*>(mb.ptr());
  740. typename CMSISMatrixType<T>::type RES;
  741. RES.numRows = R;
  742. RES.numCols = C;
  743. RES.pData = reinterpret_cast<S*>(cmsis_res.ptr());
  744. startSectionNB(2);
  745. cmsis_mat_mult(&SA, &SB, &RES,reinterpret_cast<S*>(tmp.ptr()));
  746. startSectionNB(2);
  747. STOP_CYCLE_MEASUREMENT;
  748. //std::cout << cmsis_res;
  749. if (!validate(res,cmsis_res,
  750. ErrThreshold<T>::abserr,ErrThreshold<T>::relerr))
  751. {
  752. printf("matrix times matrix expression failed \r\n");
  753. }
  754. std::cout << "=====\r\n";
  755. }
  756. template<typename T,int R,int K, int C>
  757. void testsubmatmult()
  758. {
  759. std::cout << "----\r\n";
  760. std::cout << R << " x " << K << " x " << C << "\r\n";
  761. using S = typename CMSISMatrixType<T>::scalar;
  762. constexpr int TOTALA = 4 + 2*K + 2*R + K*R;
  763. constexpr int TOTALB = 4 + 2*C + 2*K + C*K;
  764. #if defined(STATIC_TEST)
  765. PMat<T,R+2,K+2> ma;
  766. #else
  767. PMat<T> ma(R+2,K+2);
  768. #endif
  769. init_array(ma,TOTALA);
  770. #if defined(STATIC_TEST)
  771. PMat<T,K+2,C+2> mb;
  772. #else
  773. PMat<T> mb(K+2,C+2);
  774. #endif
  775. init_array(mb,TOTALB);
  776. mb += MatTestConstant<T>::value;
  777. //std::cout << ma;
  778. //std::cout << mb;
  779. INIT_SYSTICK;
  780. START_CYCLE_MEASUREMENT;
  781. #if defined(STATIC_TEST)
  782. PMat<T,R,C> res(T{});
  783. #else
  784. PMat<T> res(R,C,T{});
  785. #endif
  786. startSectionNB(1);
  787. res.sub(Slice(0,R),Slice(0,C)) = copy(dot(ma.sub(Slice(0,R),Slice(0,K)),mb.sub(Slice(0,K),Slice(0,C))));
  788. stopSectionNB(1);
  789. STOP_CYCLE_MEASUREMENT;
  790. //PrintType<decltype(ma)>();
  791. //PrintType<decltype(mb)>();
  792. //std::cout << ma;
  793. //std::cout << mb;
  794. //std::cout << res;
  795. //std::cout << IsMatrix<decltype(r+r)>::value << "\r\n";
  796. PMat<T> cmsis_res(R,C);
  797. PMat<T> cmsis_ma(R,K);
  798. PMat<T> cmsis_mb(K,C);
  799. PMat<T> tmp(C,K);
  800. typename CMSISMatrixType<T>::type SA;
  801. SA.numRows = R;
  802. SA.numCols = K;
  803. SA.pData = reinterpret_cast<S*>(cmsis_ma.ptr());
  804. typename CMSISMatrixType<T>::type SB;
  805. SB.numRows = K;
  806. SB.numCols = C;
  807. SB.pData = reinterpret_cast<S*>(cmsis_mb.ptr());
  808. typename CMSISMatrixType<T>::type RES;
  809. RES.numRows = R;
  810. RES.numCols = C;
  811. RES.pData = reinterpret_cast<S*>(cmsis_res.ptr());
  812. INIT_SYSTICK;
  813. START_CYCLE_MEASUREMENT;
  814. startSectionNB(2);
  815. cmsis_ma = copy(ma.sub(Slice(0,R),Slice(0,K)));
  816. cmsis_mb = copy(mb.sub(Slice(0,K),Slice(0,C)));
  817. cmsis_mat_mult(&SA, &SB, &RES,reinterpret_cast<S*>(tmp.ptr()));
  818. startSectionNB(2);
  819. STOP_CYCLE_MEASUREMENT;
  820. //std::cout << cmsis_res;
  821. if (!validate(res.sub(Slice(0,R),Slice(0,C)),cmsis_res,
  822. ErrThreshold<T>::abserr,ErrThreshold<T>::relerr))
  823. {
  824. printf("matrix times matrix expression failed \r\n");
  825. }
  826. std::cout << "=====\r\n";
  827. }
  828. template<typename T,int R,int C>
  829. void testmattranspose()
  830. {
  831. std::cout << "----\r\n";
  832. std::cout << R << " x " << C << "\r\n";
  833. #if defined(STATIC_TEST)
  834. PMat<T,R,C> ma;
  835. #else
  836. PMat<T> ma(R,C);
  837. #endif
  838. init_array(ma,R*C);
  839. //PrintType<decltype(ma)>();
  840. INIT_SYSTICK;
  841. START_CYCLE_MEASUREMENT;
  842. startSectionNB(1);
  843. #if defined(STATIC_TEST)
  844. PMat<T,C,R> res = ma.transpose();
  845. #else
  846. PMat<T> res = ma.transpose();
  847. #endif
  848. stopSectionNB(1);
  849. STOP_CYCLE_MEASUREMENT;
  850. //std::cout << IsMatrix<decltype(r+r)>::value << "\r\n";
  851. INIT_SYSTICK;
  852. START_CYCLE_MEASUREMENT;
  853. #if defined(STATIC_TEST)
  854. PMat<T,C,R> cmsis_res;
  855. #else
  856. PMat<T> cmsis_res(C,R);
  857. #endif
  858. typename CMSISMatrixType<T>::type SA;
  859. SA.numRows = R;
  860. SA.numCols = C;
  861. SA.pData = reinterpret_cast<typename CMSISMatrixType<T>::scalar*>(ma.ptr());
  862. typename CMSISMatrixType<T>::type RES;
  863. RES.numRows = C;
  864. RES.numCols = R;
  865. RES.pData = reinterpret_cast<typename CMSISMatrixType<T>::scalar*>(cmsis_res.ptr());
  866. startSectionNB(2);
  867. cmsis_mat_trans(&SA, &RES);
  868. startSectionNB(2);
  869. STOP_CYCLE_MEASUREMENT;
  870. //std::cout << cmsis_res;
  871. if (!validate(res,cmsis_res,
  872. ErrThreshold<T>::abserr,ErrThreshold<T>::relerr))
  873. {
  874. printf("matrix transpose failed \r\n");
  875. }
  876. std::cout << "=====\r\n";
  877. }
  878. #if !defined(DISABLEFLOAT16)
  879. static float16_t _gen_sqrt(const float16_t v)
  880. {
  881. return((float16_t)sqrtf(v));
  882. }
  883. #endif
  884. static float32_t _gen_sqrt(const float32_t v)
  885. {
  886. return(sqrtf(v));
  887. }
  888. static float64_t _gen_sqrt(const float64_t v)
  889. {
  890. return(sqrt(v));
  891. }
  892. template<int L,template<int> typename A,
  893. typename V,typename T>
  894. inline T _householder(Vector<T,L,A> &res,const V&v,const T eps)
  895. {
  896. T alpha = v[0];
  897. T tau;
  898. T beta;
  899. if (v.length()==1)
  900. {
  901. res[0] = T{};
  902. return(T{});
  903. }
  904. T xnorm2 = dot(v.sub(1),v.sub(1));
  905. //std::cout << xnorm2 << "\r\n";
  906. if (xnorm2 <= eps)
  907. {
  908. tau = T{};
  909. res = T{};
  910. }
  911. else
  912. {
  913. if (alpha<=0)
  914. {
  915. beta = _gen_sqrt(alpha*alpha+xnorm2);
  916. }
  917. else
  918. {
  919. beta = -_gen_sqrt(alpha*alpha+xnorm2);
  920. }
  921. T r = number_traits<T>::one() / (alpha - beta);
  922. res = v * r;
  923. tau = (beta - alpha)/beta;
  924. res[0] = number_traits<T>::one();
  925. }
  926. return(tau);
  927. }
  928. template<typename V,typename T,
  929. typename std::enable_if<
  930. !IsDynamic<V>::value &&
  931. SameElementType<V,T>::value &&
  932. IsVector<V>::value,bool>::type = true>
  933. auto householder(const V&v,const T threshold)
  934. {
  935. constexpr int NB = StaticLength<V>::value;
  936. Vector<T,NB,TMP_ALLOC> res;
  937. T beta = _householder(res,v,threshold);
  938. return std::tuple<T,Vector<T,NB,TMP_ALLOC>>(beta,res);
  939. }
  940. template<typename V,typename T,
  941. typename std::enable_if<
  942. IsDynamic<V>::value &&
  943. SameElementType<V,T>::value &&
  944. IsVector<V>::value,bool>::type = true>
  945. auto householder(const V&v,const T threshold)
  946. {
  947. Vector<T,DYNAMIC,TMP_ALLOC> res(v.length());
  948. T beta = _householder(res,v,threshold);
  949. return std::tuple<T,Vector<T,DYNAMIC,TMP_ALLOC>>(beta,res);
  950. }
  951. template<typename V,typename T,typename TMP,
  952. typename std::enable_if<
  953. IsDynamic<V>::value &&
  954. SameElementType<V,T>::value &&
  955. IsVector<V>::value,bool>::type = true>
  956. auto householder(const V&v,const T threshold,TMP &res)
  957. {
  958. T beta = _householder(res,v,threshold);
  959. return beta;
  960. }
  961. template<typename T>
  962. struct HouseholderThreshold;
  963. #if !defined(DISABLEFLOAT16)
  964. template<>
  965. struct HouseholderThreshold<float16_t>
  966. {
  967. static constexpr float16_t value = DEFAULT_HOUSEHOLDER_THRESHOLD_F16;
  968. };
  969. #endif
  970. template<>
  971. struct HouseholderThreshold<float64_t>
  972. {
  973. static constexpr float64_t value = DEFAULT_HOUSEHOLDER_THRESHOLD_F64;
  974. };
  975. template<>
  976. struct HouseholderThreshold<float32_t>
  977. {
  978. static constexpr float32_t value = DEFAULT_HOUSEHOLDER_THRESHOLD_F32;
  979. };
  980. template<typename T,int NB>
  981. static void testHouseholder()
  982. {
  983. std::cout << "----\r\n" << "N = " << NB << "\r\n";
  984. #if defined(STATIC_TEST)
  985. PVector<T,NB> a;
  986. #else
  987. PVector<T> a(NB);
  988. #endif
  989. cmsis_init_householder(a.ptr(),NB);
  990. INIT_SYSTICK;
  991. START_CYCLE_MEASUREMENT;
  992. startSectionNB(1);
  993. auto res = householder(a,HouseholderThreshold<T>::value);
  994. //PVector<T,NB> res;// = a + b;
  995. //float res_beta=0;
  996. stopSectionNB(1);
  997. STOP_CYCLE_MEASUREMENT;
  998. INIT_SYSTICK;
  999. START_CYCLE_MEASUREMENT;
  1000. #if defined(STATIC_TEST)
  1001. PVector<T,NB> ref;
  1002. #else
  1003. PVector<T> ref(NB);
  1004. #endif
  1005. T ref_beta = cmsis_householder(a.const_ptr(),ref.ptr(),NB);
  1006. STOP_CYCLE_MEASUREMENT;
  1007. if (!validate(std::get<1>(res).const_ptr(),ref.const_ptr(),NB,
  1008. ErrThreshold<T>::abserr_householder,ErrThreshold<T>::relerr_householder))
  1009. {
  1010. printf("householder vector failed \r\n");
  1011. }
  1012. if (!validate(std::get<0>(res),ref_beta,
  1013. ErrThreshold<T>::abserr_householder,ErrThreshold<T>::relerr_householder))
  1014. {
  1015. printf("householder beta failed \r\n");
  1016. }
  1017. std::cout << "=====\r\n";
  1018. }
  1019. #include "debug_mat.h"
  1020. #if 1
  1021. // R >= C
  1022. template<typename T,int R,int C,template<int> typename A>
  1023. auto QR(const Matrix<T,R,C,A>&m,const T eps,bool wantQ)
  1024. {
  1025. #if defined(STATIC_TEST)
  1026. Vector<T,C,TMP_ALLOC> tau;
  1027. Matrix<T,R,C,TMP_ALLOC> RM = m;
  1028. Matrix<T,R,R,TMP_ALLOC> Q = Matrix<T,R,R>::identity();
  1029. // Temporaries
  1030. Vector<T,R,TMP_ALLOC> tmpvec;
  1031. Matrix<T,1,R,TMP_ALLOC> tmpmat;
  1032. #else
  1033. Vector<T> tau(m.columns());
  1034. Matrix<T> RM = m;
  1035. Matrix<T> Q = Matrix<T>::identity(m.rows());
  1036. // Temporaries
  1037. Vector<T> tmpvec(m.rows());
  1038. Matrix<T> tmpmat(1,m.rows());
  1039. #endif
  1040. const int NBC = m.columns();
  1041. const int NBR = m.rows();
  1042. for(index_t c=0;c<NBC-1;c++)
  1043. {
  1044. auto beta = householder(RM.col(c,c),eps,tmpvec);
  1045. tau[c] = beta;
  1046. MatrixView<T> vt(tmpvec,1,NBR-c);
  1047. dot(tmpmat.sub(0,1,0,NBC-c),vt,RM.sub(c,c));
  1048. RM.sub(c,c) =
  1049. RM.sub(c,c) - beta * outer(tmpvec.sub(0,NBR-c),tmpmat.row(0,0,NBC-c));
  1050. // Copy householder reflector
  1051. // Not valid when c == C-1
  1052. // We don't want to use a test since CMSIS-DSP is not using
  1053. // one and introducing a test would give worse performance
  1054. RM.col(c,c+1) = copy(tmpvec.sub(1,NBR-c));
  1055. }
  1056. auto beta = householder(RM.col(NBC-1,NBC-1),eps,tmpvec);
  1057. tau[NBC-1] = beta;
  1058. MatrixView<T> vt(tmpvec,1,NBR-(NBC-1));
  1059. dot(tmpmat.sub(0,1,0,NBC-(NBC-1)),vt,RM.sub(NBC-1,NBC-1));
  1060. RM.sub(NBC-1,NBC-1) =
  1061. RM.sub(NBC-1,NBC-1) - beta * outer(tmpvec.sub(0,NBR-(NBC-1)),tmpmat.row(0,0,NBC-(NBC-1)));
  1062. if (wantQ)
  1063. {
  1064. for(index_t c=NBC-1;c>=0;c--)
  1065. {
  1066. tmpvec.sub(1) = copy(RM.col(c,c+1));
  1067. tmpvec[0] = number_traits<T>::one();
  1068. MatrixView<T> vt(tmpvec,1,NBR-c);
  1069. dot(tmpmat.sub(0,1,0,NBR-c),vt,Q.sub(c,c));
  1070. Q.sub(c,c) =
  1071. Q.sub(c,c) - tau[c] * outer(tmpvec.sub(0,NBR-c),tmpmat.row(0,0,NBR-c));
  1072. }
  1073. }
  1074. return std::make_tuple(RM,Q,tau);
  1075. }
  1076. template<typename T,int R,int C>
  1077. static void testQR()
  1078. {
  1079. std::cout << "----\r\n";
  1080. std::cout << R << " x " << C << "\r\n";
  1081. #if defined(STATIC_TEST)
  1082. PMat<T,R,C> a;
  1083. #else
  1084. PMat<T> a(R,C);
  1085. #endif
  1086. cmsis_init_qr(a.ptr(),R,C);
  1087. INIT_SYSTICK;
  1088. START_CYCLE_MEASUREMENT;
  1089. startSectionNB(1);
  1090. auto res = QR(a,HouseholderThreshold<T>::value,true);
  1091. stopSectionNB(1);
  1092. STOP_CYCLE_MEASUREMENT;
  1093. //std::cout << "next\r\n";
  1094. //std::cout << std::get<0>(res);
  1095. //std::cout << std::get<1>(res);
  1096. //std::cout << std::get<2>(res);
  1097. // For fair comparison, in dynamic mode we must take into
  1098. // account the memory allocations since they are made
  1099. // by the QR algorithms
  1100. #if !defined(STATIC_TEST)
  1101. INIT_SYSTICK;
  1102. START_CYCLE_MEASUREMENT;
  1103. #endif
  1104. #if 0 //defined(STATIC_TEST)
  1105. PMat<T,R,C> cmsis_res;
  1106. PMat<T,R,C> cmsis_outRp;
  1107. PMat<T,R,R> cmsis_outQp;
  1108. PVector<T,C> cmsis_tau;
  1109. PVector<T,R> cmsis_tmpa;
  1110. PVector<T,C> cmsis_tmpb;
  1111. #else
  1112. PMat<T> cmsis_res(R,C);
  1113. PMat<T> cmsis_outRp(R,C);
  1114. PMat<T> cmsis_outQp(R,R);
  1115. PVector<T> cmsis_tau(C);
  1116. PVector<T> cmsis_tmpa(R);
  1117. PVector<T> cmsis_tmpb(C);
  1118. #endif
  1119. typename CMSISMatrixType<T>::type RP;
  1120. RP.numRows = R;
  1121. RP.numCols = C;
  1122. RP.pData = cmsis_outRp.ptr();
  1123. typename CMSISMatrixType<T>::type QP;
  1124. QP.numRows = R;
  1125. QP.numCols = R;
  1126. QP.pData = cmsis_outQp.ptr();
  1127. typename CMSISMatrixType<T>::type IN;
  1128. IN.numRows = R;
  1129. IN.numCols = C;
  1130. IN.pData = a.ptr();
  1131. //std::cout << "-------\r\n";
  1132. #if defined(STATIC_TEST)
  1133. INIT_SYSTICK;
  1134. START_CYCLE_MEASUREMENT;
  1135. #endif
  1136. arm_status status=cmsis_qr(&IN,HouseholderThreshold<T>::value,
  1137. &RP,&QP,
  1138. cmsis_tau.ptr(),
  1139. cmsis_tmpa.ptr(),
  1140. cmsis_tmpb.ptr());
  1141. (void)status;
  1142. STOP_CYCLE_MEASUREMENT;
  1143. //std::cout << cmsis_outRp;
  1144. //std::cout << cmsis_outQp;
  1145. //std::cout << cmsis_tau;
  1146. if (!validate(std::get<0>(res),cmsis_outRp,
  1147. ErrThreshold<T>::abserr_qr,ErrThreshold<T>::relerr_qr))
  1148. {
  1149. printf("QR Rp matrix failed \r\n");
  1150. }
  1151. if (!validate(std::get<1>(res),cmsis_outQp,
  1152. ErrThreshold<T>::abserr_qr,ErrThreshold<T>::relerr_qr))
  1153. {
  1154. printf("QR Qp matrix failed \r\n");
  1155. }
  1156. if (!validate(std::get<2>(res),cmsis_tau,
  1157. ErrThreshold<T>::abserr_qr,ErrThreshold<T>::relerr_qr))
  1158. {
  1159. printf("QR tau failed \r\n");
  1160. }
  1161. std::cout << "=====\r\n";
  1162. }
  1163. #endif
  1164. template<typename T,int R,template<int> typename A>
  1165. auto cholesky(const Matrix<T,R,R,A>&a)
  1166. {
  1167. // Temporaries
  1168. #if defined(STATIC_TEST)
  1169. Matrix<T,R,R,TMP_ALLOC> g = a;
  1170. Vector<T,R,TMP_ALLOC> tmp;
  1171. #else
  1172. Matrix<T> g = a;
  1173. Vector<T> tmp(a.rows());
  1174. #endif
  1175. const int NBR = a.rows();
  1176. g.col(0,0) = g.col(0,0) * (T)(number_traits<T>::one() / _gen_sqrt(g(0,0)));
  1177. for(int j=1;j<NBR;j++)
  1178. {
  1179. dot(tmp.sub(j),g.sub(j,NBR,0,j) , g.row(j,0,j));
  1180. g.col(j,j) = (g.col(j,j) - tmp.sub(j)) * (T)(number_traits<T>::one() / _gen_sqrt(g(j,j)- tmp[j]));
  1181. }
  1182. return(g);
  1183. }
  1184. template<typename T,int R>
  1185. static void testCholesky()
  1186. {
  1187. std::cout << "----\r\n";
  1188. std::cout << R << " x " << R << "\r\n";
  1189. #if defined(STATIC_TEST)
  1190. PMat<T,R,R> a;
  1191. #else
  1192. PMat<T> a(R,R);
  1193. #endif
  1194. cmsis_init_cholesky(a.ptr(),R,R);
  1195. //std::cout << a;
  1196. INIT_SYSTICK;
  1197. START_CYCLE_MEASUREMENT;
  1198. startSectionNB(1);
  1199. // Not totally equivalent to CMSIS implementation
  1200. // It should be possible to rewrite it to avoid use of
  1201. // temporary buffer like CMSIS-DSP
  1202. auto res = cholesky(a);
  1203. stopSectionNB(1);
  1204. STOP_CYCLE_MEASUREMENT;
  1205. //std::cout << res;
  1206. PMat<T,R,R> cmsis_res(T{});
  1207. typename CMSISMatrixType<T>::type OUT;
  1208. OUT.numRows = R;
  1209. OUT.numCols = R;
  1210. OUT.pData = cmsis_res.ptr();
  1211. typename CMSISMatrixType<T>::type IN;
  1212. IN.numRows = R;
  1213. IN.numCols = R;
  1214. IN.pData = a.ptr();
  1215. //std::cout << "-------\r\n";
  1216. INIT_SYSTICK;
  1217. START_CYCLE_MEASUREMENT;
  1218. arm_status status=cmsis_cholesky(&IN,&OUT);
  1219. (void)status;
  1220. STOP_CYCLE_MEASUREMENT;
  1221. //std::cout << cmsis_res;
  1222. if (!validateLT(res,cmsis_res,
  1223. ErrThreshold<T>::abserr_cholesky,ErrThreshold<T>::relerr_cholesky))
  1224. {
  1225. printf("cholesky failed \r\n");
  1226. }
  1227. std::cout << "=====\r\n";
  1228. }
  1229. template<typename TT,typename ...T>
  1230. struct TESTINV
  1231. {
  1232. static void all()
  1233. {
  1234. testinv<TT,T::value...>();
  1235. }
  1236. };
  1237. template<typename TT,typename ...T>
  1238. struct TESTOUTER
  1239. {
  1240. static void all()
  1241. {
  1242. testouter<TT,T::value...>();
  1243. }
  1244. };
  1245. template<typename TT,typename ...T>
  1246. struct TESTMATVEC
  1247. {
  1248. static void all()
  1249. {
  1250. testmatvec<TT,T::value...>();
  1251. }
  1252. };
  1253. template<typename TT,typename ...T>
  1254. struct TESTCOMPLEXMATVEC
  1255. {
  1256. static void all()
  1257. {
  1258. testcomplexmatvec<TT,T::value...>();
  1259. }
  1260. };
  1261. template<typename TT,typename ...T>
  1262. struct TESTADD
  1263. {
  1264. static void all()
  1265. {
  1266. testadd <TT,T::value...>();
  1267. }
  1268. };
  1269. template<typename TT,typename ...T>
  1270. struct TESTMATTRANSPOSE
  1271. {
  1272. static void all()
  1273. {
  1274. testmattranspose<TT,T::value...>();
  1275. }
  1276. };
  1277. template<typename TT,typename ...T>
  1278. struct TESTMATMULT
  1279. {
  1280. static void all()
  1281. {
  1282. testmatmult<TT,T::value...>();
  1283. }
  1284. };
  1285. template<typename TT,typename ...T>
  1286. struct TESTSUBMATMULT
  1287. {
  1288. static void all()
  1289. {
  1290. testsubmatmult<TT,T::value...>();
  1291. }
  1292. };
  1293. template<typename TT,typename... T>
  1294. struct TEST_CASES
  1295. {
  1296. static void all()
  1297. {
  1298. (mp_push_front<T,TT>::all(),...);
  1299. }
  1300. };
  1301. template<template <typename, typename...> typename F,
  1302. typename... A>
  1303. using ALL_TESTS = mp_rename<mp_push_front<mp_product<F, A...>,float>,TEST_CASES>;
  1304. template<typename T>
  1305. void matrix_all_test()
  1306. {
  1307. #if defined(MATRIX_TEST)
  1308. #if !defined(SUBTEST1) && !defined(SUBTEST2)
  1309. const int nb_tails = TailForTests<T>::tail;
  1310. const int nb_loops = TailForTests<T>::loop;
  1311. using UNROLL = mp_rename<mp_list_v<1,2,4,8,9,11>,mp_list>;
  1312. #endif
  1313. #if defined(SUBTEST8) || defined(SUBTEST14)
  1314. using UNROLLA = mp_rename<mp_list_v<1>,mp_list>;
  1315. #endif
  1316. #if defined(SUBTEST9) || defined(SUBTEST15)
  1317. using UNROLLA = mp_rename<mp_list_v<2>,mp_list>;
  1318. #endif
  1319. #if defined(SUBTEST10) || defined(SUBTEST16)
  1320. using UNROLLA = mp_rename<mp_list_v<4>,mp_list>;
  1321. #endif
  1322. #if defined(SUBTEST11) || defined(SUBTEST17)
  1323. using UNROLLA = mp_rename<mp_list_v<8>,mp_list>;
  1324. #endif
  1325. #if defined(SUBTEST12) || defined(SUBTEST18)
  1326. using UNROLLA = mp_rename<mp_list_v<9>,mp_list>;
  1327. #endif
  1328. #if defined(SUBTEST13) || defined(SUBTEST19)
  1329. using UNROLLA = mp_rename<mp_list_v<11>,mp_list>;
  1330. #endif
  1331. #if !defined(SUBTEST1) && !defined(SUBTEST2)
  1332. using VEC = mp_rename<mp_list_v<1,
  1333. nb_tails,
  1334. nb_loops,
  1335. nb_loops+1,
  1336. nb_loops+nb_tails>,mp_list>;
  1337. #endif
  1338. //using res = mp_rename<mp_push_front<mp_product<TEST, A, B,C>,float>,TEST_CASES>;
  1339. //using res = ALL_TESTS<TEST,A,B,C>;
  1340. //PrintType<test>();
  1341. if constexpr (number_traits<T>::is_float)
  1342. {
  1343. #if defined(SUBTEST1)
  1344. title<T>("Householder");
  1345. testHouseholder<T,NBVEC_4>();
  1346. testHouseholder<T,NBVEC_16>();
  1347. testHouseholder<T,NBVEC_32>();
  1348. title<T>("QR");
  1349. testQR<T,NBVEC_4,NBVEC_4>();
  1350. testQR<T,NBVEC_16,NBVEC_16>();
  1351. testQR<T,NBVEC_32,NBVEC_32>();
  1352. title<T>("Cholesky");
  1353. testCholesky<T,NBVEC_4>();
  1354. testCholesky<T,NBVEC_16>();
  1355. testCholesky<T,NBVEC_32>();
  1356. #endif
  1357. #if defined(SUBTEST2)
  1358. title<T>("Matrix inverse");
  1359. testinv<T,NBVEC_4,NBVEC_4>();
  1360. testinv<T,NBVEC_8,NBVEC_8>();
  1361. testinv<T,NBVEC_16,NBVEC_16>();
  1362. #endif
  1363. }
  1364. #if defined(SUBTEST1)
  1365. title<T>("Matrix outer product");
  1366. testouter<T,NBVEC_4,NBVEC_4>();
  1367. testouter<T,NBVEC_8,NBVEC_8>();
  1368. testouter<T,NBVEC_16,NBVEC_16>();
  1369. #endif
  1370. #if defined(SUBTEST3)
  1371. title<T>("Matrix outer product");
  1372. ALL_TESTS<TESTOUTER,UNROLL,VEC>::all();
  1373. #endif
  1374. if constexpr (!std::is_same<T,double>::value)
  1375. {
  1376. #if defined(SUBTEST1)
  1377. title<T>("Matrix times vector");
  1378. testmatvec<T,NBVEC_4 ,NBVEC_4>();
  1379. testmatvec<T,NBVEC_16,NBVEC_16>();
  1380. testmatvec<T,NBVEC_32,NBVEC_32>();
  1381. testmatvec<T,NBVEC_44,NBVEC_44>();
  1382. testmatvec<T,NBVEC_47,NBVEC_47>();
  1383. #endif
  1384. #if defined(SUBTEST4)
  1385. title<T>("Matrix times vector");
  1386. ALL_TESTS<TESTMATVEC,UNROLL,VEC>::all();
  1387. #endif
  1388. #if defined(SUBTEST1)
  1389. title<T>("Matrix times vector expression");
  1390. testcomplexmatvec<T,NBVEC_4 ,NBVEC_4>();
  1391. testcomplexmatvec<T,NBVEC_16,NBVEC_16>();
  1392. testcomplexmatvec<T,NBVEC_32,NBVEC_32>();
  1393. testcomplexmatvec<T,NBVEC_44,NBVEC_44>();
  1394. testcomplexmatvec<T,NBVEC_47,NBVEC_47>();
  1395. #endif
  1396. #if defined(SUBTEST5)
  1397. title<T>("Matrix times vector expression");
  1398. ALL_TESTS<TESTCOMPLEXMATVEC,UNROLL,VEC>::all();
  1399. #endif
  1400. }
  1401. if constexpr (!std::is_same<T,Q7>::value && !std::is_same<T,double>::value)
  1402. {
  1403. #if defined(SUBTEST1)
  1404. title<T>("Matrix add");
  1405. testadd<T,NBVEC_4,NBVEC_4>();
  1406. testadd<T,NBVEC_8,NBVEC_8>();
  1407. testadd<T,NBVEC_16,NBVEC_16>();
  1408. #endif
  1409. #if defined(SUBTEST6)
  1410. title<T>("Matrix add");
  1411. ALL_TESTS<TESTADD,UNROLL,VEC>::all();
  1412. #endif
  1413. }
  1414. #if defined(SUBTEST1)
  1415. title<T>("Matrix diag");
  1416. testdiag<T,NBVEC_4,NBVEC_4>();
  1417. testdiag<T,NBVEC_8,NBVEC_8>();
  1418. testdiag<T,NBVEC_16,NBVEC_16>();
  1419. title<T>("Matrix submatrix");
  1420. testview<T,NBVEC_16,NBVEC_16>();
  1421. #endif
  1422. #if defined(SUBTEST1)
  1423. title<T>("Matrix multiply");
  1424. testmatmult<T,NBVEC_4,NBVEC_4,NBVEC_4>();
  1425. testmatmult<T,NBVEC_16,NBVEC_16,NBVEC_16>();
  1426. testmatmult<T,NBVEC_32,NBVEC_32,NBVEC_32>();
  1427. #endif
  1428. #if defined(SUBTEST1)
  1429. title<T>("Matrix transpose");
  1430. testmattranspose<T,NBVEC_2,NBVEC_2>();
  1431. testmattranspose<T,NBVEC_3,NBVEC_3>();
  1432. testmattranspose<T,NBVEC_4,NBVEC_4>();
  1433. testmattranspose<T,NBVEC_16,NBVEC_16>();
  1434. testmattranspose<T,NBVEC_32,NBVEC_32>();
  1435. #endif
  1436. #if defined(SUBTEST7)
  1437. title<T>("Matrix transpose");
  1438. ALL_TESTS<TESTMATTRANSPOSE,UNROLL,VEC>::all();
  1439. #endif
  1440. #if defined(SUBTEST8) || defined(SUBTEST9)|| defined(SUBTEST10)|| defined(SUBTEST11)|| defined(SUBTEST12)|| defined(SUBTEST13)
  1441. title<T>("Matrix multiply");
  1442. ALL_TESTS<TESTMATMULT,UNROLLA,VEC,UNROLL>::all();
  1443. #endif
  1444. #if defined(SUBTEST1)
  1445. title<T>("Submatrix multiply");
  1446. testsubmatmult<T,NBVEC_4,NBVEC_4,NBVEC_4>();
  1447. testsubmatmult<T,NBVEC_16,NBVEC_16,NBVEC_16>();
  1448. #endif
  1449. #if defined(SUBTEST14) || defined(SUBTEST15) || defined(SUBTEST16)|| defined(SUBTEST17)|| defined(SUBTEST18)|| defined(SUBTEST19)
  1450. title<T>("Submatrix multiply");
  1451. ALL_TESTS<TESTSUBMATMULT,UNROLLA,VEC,UNROLL>::all();
  1452. #endif
  1453. //testsubmatmult<T,NBVEC_32>();
  1454. #endif
  1455. };
  1456. void matrix_test()
  1457. {
  1458. #if defined(MATRIX_TEST)
  1459. #if defined(F64_DT)
  1460. matrix_all_test<double>();
  1461. #endif
  1462. #if defined(F32_DT)
  1463. matrix_all_test<float>();
  1464. #endif
  1465. #if defined(F16_DT) && !defined(DISABLEFLOAT16)
  1466. matrix_all_test<float16_t>();
  1467. #endif
  1468. #if defined(Q31_DT)
  1469. matrix_all_test<Q31>();
  1470. #endif
  1471. #if defined(Q15_DT)
  1472. matrix_all_test<Q15>();
  1473. #endif
  1474. #if defined(Q7_DT)
  1475. matrix_all_test<Q7>();
  1476. #endif
  1477. #endif
  1478. }