UnaryTestsF32.cpp 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054
  1. #include "UnaryTestsF32.h"
  2. #include "Error.h"
/*
 SNR threshold passed to ASSERT_SNR by the generic matrix tests
 (add, sub, scale, transpose, complex transpose, matrix x vector,
 triangular solves). Presumably expressed in dB — confirm against the
 test framework.
*/
#define SNR_THRESHOLD 120
/*
Reference patterns are generated with
a double precision computation.
*/
/* Absolute / relative tolerances for ASSERT_CLOSE_ERROR in the generic tests. */
#define REL_ERROR (1.0e-5)
#define ABS_ERROR (1.0e-5)
/*
Comparisons for Householder
(used for both the reflector vectors and the beta values)
*/
#define SNR_HOUSEHOLDER_THRESHOLD 140
#define REL_HOUSEHOLDER_ERROR (1.0e-7)
#define ABS_HOUSEHOLDER_ERROR (1.0e-7)
/*
Comparisons for QR decomposition
(shared by the Q, R and tau outputs)
*/
#define SNR_QR_THRESHOLD 90
#define REL_QR_ERROR (1.0e-4)
#define ABS_QR_ERROR (2.0e-4)
/*
Comparisons for inverse
*/
/* Not very accurate for big matrix.
But big matrix needed for checking the vectorized code */
#define SNR_THRESHOLD_INV 100
#define REL_ERROR_INV (3.0e-5)
#define ABS_ERROR_INV (1.0e-5)
/*
Comparison for Cholesky
*/
#define SNR_THRESHOLD_CHOL 92
#define REL_ERROR_CHOL (1.0e-5)
#define ABS_ERROR_CHOL (5.0e-4)
/* LDLT comparison: tolerances used to compare P A P^t with L D L^t.
   The *_LDLT pair is used for the definite positive test. */
#define REL_ERROR_LDLT (1e-5)
#define ABS_ERROR_LDLT (1e-5)
/* Tolerances for the semi definite positive LDLT test.
   NOTE(review): the suffix is spelled SPDO on the relative tolerance and
   SDPO on the absolute one; setUp references both spellings as-is. */
#define REL_ERROR_LDLT_SPDO (1e-5)
#define ABS_ERROR_LDLT_SDPO (2e-1)
/* Upper bound of maximum matrix dimension used by Python.
   Scratch buffers are sized MAXMATRIXDIM*MAXMATRIXDIM elements in setUp. */
#define MAXMATRIXDIM 40
  43. static void checkInnerTailOverflow(float32_t *b)
  44. {
  45. ASSERT_TRUE(b[0] == 0);
  46. ASSERT_TRUE(b[1] == 0);
  47. ASSERT_TRUE(b[2] == 0);
  48. ASSERT_TRUE(b[3] == 0);
  49. }
/*
 Declare the locals shared by the two-operand matrix tests:
   inp1/inp2  read pointers on the input1/input2 patterns,
   ap/bp      scratch buffers a/b that receive copies of the operands,
   outp       write pointer on output,
   dimsp      cursor on dims, which holds one (rows, columns) pair per matrix,
   nbMatrixes number of such pairs,
   rows, columns, i  loop variables (left uninitialized).
 Declares variables, so it can only be expanded once per scope.
*/
#define LOADDATA2() \
const float32_t *inp1=input1.ptr(); \
const float32_t *inp2=input2.ptr(); \
\
float32_t *ap=a.ptr(); \
float32_t *bp=b.ptr(); \
\
float32_t *outp=output.ptr(); \
int16_t *dimsp = dims.ptr(); \
int nbMatrixes = dims.nbSamples() >> 1;\
int rows,columns; \
int i;
/*
 Single-operand variant of LOADDATA2. It declares inp1 (input1 read
 pointer), ap (scratch buffer a), outp (output write pointer), dimsp
 (cursor on dims) and nbMatrixes (number of (rows, columns) pairs). It also
 declares the uninitialized loop variables rows, columns and i.
 Declares variables, so it can only be expanded once per scope.
*/
#define LOADDATA1() \
const float32_t *inp1=input1.ptr(); \
\
float32_t *ap=a.ptr(); \
\
float32_t *outp=output.ptr(); \
int16_t *dimsp = dims.ptr(); \
int nbMatrixes = dims.nbSamples() >> 1;\
int rows,columns; \
int i;
/*
 Build the operands of one two-operand test case.
 Copies rows*columns floats from inp1 into ap and from inp2 into bp.
 in1 and in2 are then set up as rows x columns matrices on those copies,
 and out as a rows x columns matrix writing at outp.
 Uses the locals declared by LOADDATA2. The macro does not advance
 inp1/inp2.
*/
#define PREPAREDATA2() \
in1.numRows=rows; \
in1.numCols=columns; \
memcpy((void*)ap,(const void*)inp1,sizeof(float32_t)*rows*columns);\
in1.pData = ap; \
\
in2.numRows=rows; \
in2.numCols=columns; \
memcpy((void*)bp,(const void*)inp2,sizeof(float32_t)*rows*columns);\
in2.pData = bp; \
\
out.numRows=rows; \
out.numCols=columns; \
out.pData = outp;
/*
 Build the operands of one triangular-solve case.
 in1 is a square rows x rows matrix copied from inp1 into ap. in2 is a
 rows x columns right-hand side copied from inp2 into bp. out is a
 rows x columns matrix writing at outp.
 The caller declares ap, bp, inp1, inp2, outp, rows and columns, and
 advances inp1/inp2 itself.
*/
#define PREPAREDATALT() \
in1.numRows=rows; \
in1.numCols=rows; \
memcpy((void*)ap,(const void*)inp1,sizeof(float32_t)*rows*rows); \
in1.pData = ap; \
\
in2.numRows=rows; \
in2.numCols=columns; \
memcpy((void*)bp,(const void*)inp2,sizeof(float32_t)*rows*columns);\
in2.pData = bp; \
\
out.numRows=rows; \
out.numCols=columns; \
out.pData = outp;
/*
 Build the operand of one single-operand test case.
 Copies rows*columns floats from inp1 into ap and sets in1 as a
 rows x columns matrix on that copy.
 out writes at outp. It is columns x rows when TRANSPOSED is true,
 otherwise rows x columns. TRANSPOSED is evaluated with a plain if,
 so callers pass true/false.
*/
#define PREPAREDATA1(TRANSPOSED) \
in1.numRows=rows; \
in1.numCols=columns; \
memcpy((void*)ap,(const void*)inp1,sizeof(float32_t)*rows*columns);\
in1.pData = ap; \
\
if (TRANSPOSED) \
{ \
out.numRows=columns; \
out.numCols=rows; \
} \
else \
{ \
out.numRows=rows; \
out.numCols=columns; \
} \
out.pData = outp;
/*
 Complex variant of PREPAREDATA1. The copy into ap covers 2*rows*columns
 float32_t values (two per complex element), so the scratch buffer a must
 hold at least that many floats.
 Matrix dimensions are still given in elements. out writes at outp and is
 columns x rows when TRANSPOSED is true.
*/
#define PREPAREDATA1C(TRANSPOSED) \
in1.numRows=rows; \
in1.numCols=columns; \
memcpy((void*)ap,(const void*)inp1,2*sizeof(float32_t)*rows*columns);\
in1.pData = ap; \
\
if (TRANSPOSED) \
{ \
out.numRows=columns; \
out.numCols=rows; \
} \
else \
{ \
out.numRows=rows; \
out.numCols=columns; \
} \
out.pData = outp;
/*
 Locals for the matrix x vector test: the same pointers as LOADDATA2.
 dims holds one (rows, internal) pair per case, where internal is the
 matrix column count and the vector length. Also declares the
 uninitialized loop variables rows, internal and i.
 Declares variables, so it can only be expanded once per scope.
*/
#define LOADVECDATA2() \
const float32_t *inp1=input1.ptr(); \
const float32_t *inp2=input2.ptr(); \
\
float32_t *ap=a.ptr(); \
float32_t *bp=b.ptr(); \
\
float32_t *outp=output.ptr(); \
int16_t *dimsp = dims.ptr(); \
int nbMatrixes = dims.nbSamples() / 2;\
int rows,internal; \
int i;
/*
 Build the operands of one matrix x vector case.
 Copies the rows x internal matrix from inp1 into ap and points in1 at it.
 Copies the internal-length vector from inp2 into bp.
 The output is not described by a matrix instance here; the caller passes
 outp directly.
*/
#define PREPAREVECDATA2() \
in1.numRows=rows; \
in1.numCols=internal; \
memcpy((void*)ap,(const void*)inp1,sizeof(float32_t)*rows*internal);\
in1.pData = ap; \
\
memcpy((void*)bp,(const void*)inp2,sizeof(float32_t)*internal);
/*
 Build the input and L output of one LDL^t case.
 Copies rows*columns floats from inp1 into ap and sets in1 on the copy.
 outll is set up as a rows x columns matrix writing at outllp.
 The caller (test_mat_ldl_f32) declares ap, inp1, outllp, rows and columns.
*/
#define PREPAREDATALL1() \
in1.numRows=rows; \
in1.numCols=columns; \
memcpy((void*)ap,(const void*)inp1,sizeof(float32_t)*rows*columns);\
in1.pData = ap; \
\
outll.numRows=rows; \
outll.numCols=columns; \
\
outll.pData = outllp;
  163. #define SWAP_ROWS(A,i,j) \
  164. for(int w=0;w < n; w++) \
  165. { \
  166. float64_t tmp; \
  167. tmp = A[i*n + w]; \
  168. A[i*n + w] = A[j*n + w];\
  169. A[j*n + w] = tmp; \
  170. }
  171. void UnaryTestsF32::test_householder_f32()
  172. {
  173. int32_t vecDim;
  174. const int16_t *dimsp = dims.ptr();
  175. const int nbVectors = dims.nbSamples();
  176. const float32_t *inp1=input1.ptr();
  177. float32_t *outp=output.ptr();
  178. float32_t *outBetap=outputBeta.ptr();
  179. for(int i=0; i < nbVectors ; i++)
  180. {
  181. vecDim = *dimsp++;
  182. float32_t beta = arm_householder_f32(inp1,DEFAULT_HOUSEHOLDER_THRESHOLD_F32,vecDim,outp);
  183. *outBetap = beta;
  184. outp += vecDim;
  185. inp1 += vecDim;
  186. outBetap++;
  187. checkInnerTailOverflow(outp);
  188. checkInnerTailOverflow(outBetap);
  189. }
  190. ASSERT_EMPTY_TAIL(output);
  191. ASSERT_EMPTY_TAIL(outputBeta);
  192. ASSERT_SNR(output,ref,(float32_t)SNR_HOUSEHOLDER_THRESHOLD);
  193. ASSERT_SNR(outputBeta,refBeta,(float32_t)SNR_HOUSEHOLDER_THRESHOLD);
  194. ASSERT_CLOSE_ERROR(output,ref,ABS_HOUSEHOLDER_ERROR,REL_HOUSEHOLDER_ERROR);
  195. ASSERT_CLOSE_ERROR(outputBeta,refBeta,ABS_HOUSEHOLDER_ERROR,REL_HOUSEHOLDER_ERROR);
  196. }
  197. void UnaryTestsF32::test_mat_qr_f32()
  198. {
  199. int32_t rows, columns, rank;
  200. int nb;
  201. const int16_t *dimsp = dims.ptr();
  202. const int nbMatrixes = dims.nbSamples() / 3;
  203. const float32_t *inp1=input1.ptr();
  204. float32_t *outTaup=outputTau.ptr();
  205. float32_t *outRp=outputR.ptr();
  206. float32_t *outQp=outputQ.ptr();
  207. float32_t *pTmpA=a.ptr();
  208. float32_t *pTmpB=b.ptr();
  209. (void) outTaup;
  210. (void) outRp;
  211. (void) outQp;
  212. (void)nbMatrixes;
  213. (void)nb;
  214. nb=0;
  215. for(int i=0; i < nbMatrixes ; i++)
  216. //for(int i=0; i < 1 ; i++)
  217. {
  218. rows = *dimsp++;
  219. columns = *dimsp++;
  220. rank = *dimsp++;
  221. (void)rank;
  222. //printf("--> %d %d\n",nb,i);
  223. nb += rows * columns;
  224. in1.numRows=rows;
  225. in1.numCols=columns;
  226. in1.pData = (float32_t*)inp1;
  227. outR.numRows = rows;
  228. outR.numCols = columns;
  229. outR.pData = (float32_t*)outRp;
  230. outQ.numRows = rows;
  231. outQ.numCols = rows;
  232. outQ.pData = (float32_t*)outQp;
  233. arm_status status=arm_mat_qr_f32(&in1,DEFAULT_HOUSEHOLDER_THRESHOLD_F32,&outR,&outQ,outTaup,pTmpA,pTmpB);
  234. ASSERT_TRUE(status==ARM_MATH_SUCCESS);
  235. inp1 += rows * columns;
  236. outRp += rows * columns;
  237. outQp += rows * rows;
  238. outTaup += columns;
  239. checkInnerTailOverflow(outRp);
  240. checkInnerTailOverflow(outQp);
  241. checkInnerTailOverflow(outTaup);
  242. }
  243. ASSERT_EMPTY_TAIL(outputR);
  244. ASSERT_EMPTY_TAIL(outputQ);
  245. ASSERT_EMPTY_TAIL(outputTau);
  246. ASSERT_SNR(refQ,outputQ,(float32_t)SNR_QR_THRESHOLD);
  247. ASSERT_SNR(refR,outputR,(float32_t)SNR_QR_THRESHOLD);
  248. ASSERT_SNR(refTau,outputTau,(float32_t)SNR_QR_THRESHOLD);
  249. ASSERT_CLOSE_ERROR(refQ,outputQ,ABS_QR_ERROR,REL_QR_ERROR);
  250. ASSERT_CLOSE_ERROR(refR,outputR,ABS_QR_ERROR,REL_QR_ERROR);
  251. ASSERT_CLOSE_ERROR(refTau,outputTau,ABS_QR_ERROR,REL_QR_ERROR);
  252. }
  253. void UnaryTestsF32::test_mat_vec_mult_f32()
  254. {
  255. LOADVECDATA2();
  256. for(i=0;i < nbMatrixes ; i ++)
  257. {
  258. rows = *dimsp++;
  259. internal = *dimsp++;
  260. PREPAREVECDATA2();
  261. arm_mat_vec_mult_f32(&this->in1, bp, outp);
  262. outp += rows ;
  263. checkInnerTailOverflow(outp);
  264. }
  265. ASSERT_EMPTY_TAIL(output);
  266. ASSERT_SNR(output,ref,(float32_t)SNR_THRESHOLD);
  267. ASSERT_CLOSE_ERROR(output,ref,ABS_ERROR,REL_ERROR);
  268. }
  269. void UnaryTestsF32::test_mat_add_f32()
  270. {
  271. LOADDATA2();
  272. arm_status status;
  273. for(i=0;i < nbMatrixes ; i ++)
  274. {
  275. rows = *dimsp++;
  276. columns = *dimsp++;
  277. PREPAREDATA2();
  278. status=arm_mat_add_f32(&this->in1,&this->in2,&this->out);
  279. ASSERT_TRUE(status==ARM_MATH_SUCCESS);
  280. outp += (rows * columns);
  281. checkInnerTailOverflow(outp);
  282. }
  283. ASSERT_EMPTY_TAIL(output);
  284. ASSERT_SNR(output,ref,(float32_t)SNR_THRESHOLD);
  285. ASSERT_CLOSE_ERROR(output,ref,ABS_ERROR,REL_ERROR);
  286. }
  287. void UnaryTestsF32::test_mat_sub_f32()
  288. {
  289. LOADDATA2();
  290. arm_status status;
  291. for(i=0;i < nbMatrixes ; i ++)
  292. {
  293. rows = *dimsp++;
  294. columns = *dimsp++;
  295. PREPAREDATA2();
  296. status=arm_mat_sub_f32(&this->in1,&this->in2,&this->out);
  297. ASSERT_TRUE(status==ARM_MATH_SUCCESS);
  298. outp += (rows * columns);
  299. checkInnerTailOverflow(outp);
  300. }
  301. ASSERT_EMPTY_TAIL(output);
  302. ASSERT_SNR(output,ref,(float32_t)SNR_THRESHOLD);
  303. ASSERT_CLOSE_ERROR(output,ref,ABS_ERROR,REL_ERROR);
  304. }
  305. void UnaryTestsF32::test_mat_scale_f32()
  306. {
  307. LOADDATA1();
  308. arm_status status;
  309. for(i=0;i < nbMatrixes ; i ++)
  310. {
  311. rows = *dimsp++;
  312. columns = *dimsp++;
  313. PREPAREDATA1(false);
  314. status=arm_mat_scale_f32(&this->in1,0.5f,&this->out);
  315. ASSERT_TRUE(status==ARM_MATH_SUCCESS);
  316. outp += (rows * columns);
  317. checkInnerTailOverflow(outp);
  318. }
  319. ASSERT_EMPTY_TAIL(output);
  320. ASSERT_SNR(output,ref,(float32_t)SNR_THRESHOLD);
  321. ASSERT_CLOSE_ERROR(output,ref,ABS_ERROR,REL_ERROR);
  322. }
  323. void UnaryTestsF32::test_mat_trans_f32()
  324. {
  325. LOADDATA1();
  326. arm_status status;
  327. for(i=0;i < nbMatrixes ; i ++)
  328. {
  329. rows = *dimsp++;
  330. columns = *dimsp++;
  331. PREPAREDATA1(true);
  332. status=arm_mat_trans_f32(&this->in1,&this->out);
  333. ASSERT_TRUE(status==ARM_MATH_SUCCESS);
  334. outp += (rows * columns);
  335. checkInnerTailOverflow(outp);
  336. }
  337. ASSERT_EMPTY_TAIL(output);
  338. ASSERT_SNR(output,ref,(float32_t)SNR_THRESHOLD);
  339. ASSERT_CLOSE_ERROR(output,ref,ABS_ERROR,REL_ERROR);
  340. }
  341. void UnaryTestsF32::test_mat_cmplx_trans_f32()
  342. {
  343. LOADDATA1();
  344. arm_status status;
  345. for(i=0;i < nbMatrixes ; i ++)
  346. {
  347. rows = *dimsp++;
  348. columns = *dimsp++;
  349. PREPAREDATA1C(true);
  350. status=arm_mat_cmplx_trans_f32(&this->in1,&this->out);
  351. ASSERT_TRUE(status==ARM_MATH_SUCCESS);
  352. outp += 2*(rows * columns);
  353. checkInnerTailOverflow(outp);
  354. }
  355. ASSERT_EMPTY_TAIL(output);
  356. ASSERT_SNR(output,ref,(float32_t)SNR_THRESHOLD);
  357. ASSERT_CLOSE_ERROR(output,ref,ABS_ERROR,REL_ERROR);
  358. }
  359. static void refInnerTail(float32_t *b)
  360. {
  361. b[0] = 1.0f;
  362. b[1] = -2.0f;
  363. b[2] = 3.0f;
  364. b[3] = -4.0f;
  365. }
  366. static void checkInnerTail(float32_t *b)
  367. {
  368. ASSERT_TRUE(b[0] == 1.0f);
  369. ASSERT_TRUE(b[1] == -2.0f);
  370. ASSERT_TRUE(b[2] == 3.0f);
  371. ASSERT_TRUE(b[3] == -4.0f);
  372. }
  373. void UnaryTestsF32::test_mat_inverse_f32()
  374. {
  375. const float32_t *inp1=input1.ptr();
  376. float32_t *ap=a.ptr();
  377. float32_t *outp=output.ptr();
  378. int16_t *dimsp = dims.ptr();
  379. int nbMatrixes = dims.nbSamples();
  380. int rows,columns;
  381. int i;
  382. arm_status status;
  383. for(i=0;i < nbMatrixes ; i ++)
  384. {
  385. rows = *dimsp++;
  386. columns = rows;
  387. PREPAREDATA1(false);
  388. refInnerTail(outp+(rows * columns));
  389. status=arm_mat_inverse_f32(&this->in1,&this->out);
  390. ASSERT_TRUE(status==ARM_MATH_SUCCESS);
  391. outp += (rows * columns);
  392. inp1 += (rows * columns);
  393. checkInnerTail(outp);
  394. }
  395. ASSERT_SNR(output,ref,(float32_t)SNR_THRESHOLD_INV);
  396. ASSERT_CLOSE_ERROR(output,ref,ABS_ERROR_INV,REL_ERROR_INV);
  397. }
  398. void UnaryTestsF32::test_mat_cholesky_dpo_f32()
  399. {
  400. float32_t *ap=a.ptr();
  401. const float32_t *inp1=input1.ptr();
  402. float32_t *outp=output.ptr();
  403. int16_t *dimsp = dims.ptr();
  404. int nbMatrixes = dims.nbSamples();
  405. int rows,columns;
  406. int i;
  407. arm_status status;
  408. for(i=0;i < nbMatrixes ; i ++)
  409. {
  410. rows = *dimsp++;
  411. columns = rows;
  412. PREPAREDATA1(false);
  413. status=arm_mat_cholesky_f32(&this->in1,&this->out);
  414. ASSERT_TRUE(status==ARM_MATH_SUCCESS);
  415. outp += (rows * columns);
  416. inp1 += (rows * columns);
  417. checkInnerTailOverflow(outp);
  418. }
  419. ASSERT_EMPTY_TAIL(output);
  420. ASSERT_SNR(output,ref,(float32_t)SNR_THRESHOLD_CHOL);
  421. ASSERT_CLOSE_ERROR(ref,output,ABS_ERROR_CHOL,REL_ERROR_CHOL);
  422. }
  423. void UnaryTestsF32::test_solve_upper_triangular_f32()
  424. {
  425. float32_t *ap=a.ptr();
  426. const float32_t *inp1=input1.ptr();
  427. float32_t *bp=b.ptr();
  428. const float32_t *inp2=input2.ptr();
  429. float32_t *outp=output.ptr();
  430. int16_t *dimsp = dims.ptr();
  431. int nbMatrixes = dims.nbSamples()>>1;
  432. int rows,columns;
  433. int i;
  434. arm_status status;
  435. for(i=0;i < nbMatrixes ; i ++)
  436. {
  437. rows = *dimsp++;
  438. columns = *dimsp++;
  439. PREPAREDATALT();
  440. status=arm_mat_solve_upper_triangular_f32(&this->in1,&this->in2,&this->out);
  441. ASSERT_TRUE(status==ARM_MATH_SUCCESS);
  442. outp += (rows * columns);
  443. inp1 += (rows * rows);
  444. inp2 += (rows * columns);
  445. checkInnerTailOverflow(outp);
  446. }
  447. ASSERT_EMPTY_TAIL(output);
  448. ASSERT_SNR(output,ref,(float32_t)SNR_THRESHOLD);
  449. ASSERT_CLOSE_ERROR(ref,output,ABS_ERROR,REL_ERROR);
  450. }
  451. void UnaryTestsF32::test_solve_lower_triangular_f32()
  452. {
  453. float32_t *ap=a.ptr();
  454. const float32_t *inp1=input1.ptr();
  455. float32_t *bp=b.ptr();
  456. const float32_t *inp2=input2.ptr();
  457. float32_t *outp=output.ptr();
  458. int16_t *dimsp = dims.ptr();
  459. int nbMatrixes = dims.nbSamples() >> 1;
  460. int rows,columns;
  461. int i;
  462. arm_status status;
  463. for(i=0;i < nbMatrixes ; i ++)
  464. {
  465. rows = *dimsp++;
  466. columns = *dimsp++;
  467. PREPAREDATALT();
  468. status=arm_mat_solve_lower_triangular_f32(&this->in1,&this->in2,&this->out);
  469. ASSERT_TRUE(status==ARM_MATH_SUCCESS);
  470. outp += (rows * columns);
  471. inp1 += (rows * rows);
  472. inp2 += (rows * columns);
  473. checkInnerTailOverflow(outp);
  474. }
  475. ASSERT_EMPTY_TAIL(output);
  476. ASSERT_SNR(output,ref,(float32_t)SNR_THRESHOLD);
  477. ASSERT_CLOSE_ERROR(ref,output,ABS_ERROR,REL_ERROR);
  478. }
  479. static void trans_f64(const float64_t *src, float64_t *dst, int n)
  480. {
  481. for(int r=0; r<n ; r++)
  482. {
  483. for(int c=0; c<n ; c++)
  484. {
  485. dst[c*n+r] = src[r*n+c];
  486. }
  487. }
  488. }
  489. static void trans_f32_f64(const float32_t *src, float64_t *dst, int n)
  490. {
  491. for(int r=0; r<n ; r++)
  492. {
  493. for(int c=0; c<n ; c++)
  494. {
  495. dst[c*n+r] = (float64_t)src[r*n+c];
  496. }
  497. }
  498. }
  499. static void mult_f32_f64(const float32_t *srcA, const float64_t *srcB, float64_t *dst,int n)
  500. {
  501. for(int r=0; r<n ; r++)
  502. {
  503. for(int c=0; c<n ; c++)
  504. {
  505. float64_t sum=0.0;
  506. for(int k=0; k < n ; k++)
  507. {
  508. sum += (float64_t)srcA[r*n+k] * srcB[k*n+c];
  509. }
  510. dst[r*n+c] = sum;
  511. }
  512. }
  513. }
  514. static void mult_f64_f64(const float64_t *srcA, const float64_t *srcB, float64_t *dst,int n)
  515. {
  516. for(int r=0; r<n ; r++)
  517. {
  518. for(int c=0; c<n ; c++)
  519. {
  520. float64_t sum=0.0;
  521. for(int k=0; k < n ; k++)
  522. {
  523. sum += srcA[r*n+k] * srcB[k*n+c];
  524. }
  525. dst[r*n+c] = sum;
  526. }
  527. }
  528. }
  529. void UnaryTestsF32::compute_ldlt_error(const int n,const int16_t *outpp)
  530. {
  531. float64_t *tmpa = tmpapat.ptr() ;
  532. float64_t *tmpb = tmpbpat.ptr() ;
  533. float64_t *tmpc = tmpcpat.ptr() ;
  534. /* Compute P A P^t */
  535. // Create identiy matrix
  536. for(int r=0; r < n; r++)
  537. {
  538. for(int c=0; c < n; c++)
  539. {
  540. if (r == c)
  541. {
  542. tmpa[r*n+c] = 1.0;
  543. }
  544. else
  545. {
  546. tmpa[r*n+c] = 0.0;
  547. }
  548. }
  549. }
  550. // Create permutation matrix
  551. for(int r=0;r < n; r++)
  552. {
  553. SWAP_ROWS(tmpa,r,outpp[r]);
  554. }
  555. trans_f64((const float64_t*)tmpa,tmpb,n);
  556. mult_f32_f64((const float32_t*)this->in1.pData,(const float64_t*)tmpb,tmpc,n);
  557. mult_f64_f64((const float64_t*)tmpa,(const float64_t*)tmpc,outa,n);
  558. /* Compute L D L^t */
  559. trans_f32_f64((const float32_t*)this->outll.pData,tmpc,n);
  560. mult_f32_f64((const float32_t*)this->outd.pData,(const float64_t*)tmpc,tmpa,n);
  561. mult_f32_f64((const float32_t*)this->outll.pData,(const float64_t*)tmpa,outb,n);
  562. }
  563. void UnaryTestsF32::test_mat_ldl_f32()
  564. {
  565. float32_t *ap=a.ptr();
  566. const float32_t *inp1=input1.ptr();
  567. float32_t *outllp=outputll.ptr();
  568. float32_t *outdp=outputd.ptr();
  569. int16_t *outpp=outputp.ptr();
  570. outa=outputa.ptr();
  571. outb=outputb.ptr();
  572. int16_t *dimsp = dims.ptr();
  573. int nbMatrixes = dims.nbSamples();
  574. int rows,columns;
  575. int i;
  576. arm_status status;
  577. for(i=0;i < nbMatrixes ; i ++)
  578. {
  579. rows = *dimsp++;
  580. columns = rows;
  581. PREPAREDATALL1();
  582. outd.numRows=rows;
  583. outd.numCols=columns;
  584. outd.pData=outdp;
  585. memset(outpp,0,rows*sizeof(uint16_t));
  586. memset(outdp,0,columns*rows*sizeof(float32_t));
  587. status=arm_mat_ldlt_f32(&this->in1,&this->outll,&this->outd,(uint16_t*)outpp);
  588. ASSERT_TRUE(status==ARM_MATH_SUCCESS);
  589. compute_ldlt_error(rows,outpp);
  590. outllp += (rows * columns);
  591. outdp += (rows * columns);
  592. outpp += rows;
  593. outa += (rows * columns);
  594. outb +=(rows * columns);
  595. inp1 += (rows * columns);
  596. checkInnerTailOverflow(outllp);
  597. checkInnerTailOverflow(outdp);
  598. }
  599. ASSERT_EMPTY_TAIL(outputll);
  600. ASSERT_EMPTY_TAIL(outputd);
  601. ASSERT_EMPTY_TAIL(outputp);
  602. ASSERT_EMPTY_TAIL(outputa);
  603. ASSERT_EMPTY_TAIL(outputb);
  604. ASSERT_CLOSE_ERROR(outputa,outputb,snrAbs,snrRel);
  605. }
  606. void UnaryTestsF32::setUp(Testing::testID_t id,std::vector<Testing::param_t>& params,Client::PatternMgr *mgr)
  607. {
  608. (void)params;
  609. switch(id)
  610. {
  611. case TEST_MAT_ADD_F32_1:
  612. input1.reload(UnaryTestsF32::INPUTS1_F32_ID,mgr);
  613. input2.reload(UnaryTestsF32::INPUTS2_F32_ID,mgr);
  614. dims.reload(UnaryTestsF32::DIMSUNARY1_S16_ID,mgr);
  615. ref.reload(UnaryTestsF32::REFADD1_F32_ID,mgr);
  616. output.create(ref.nbSamples(),UnaryTestsF32::OUT_F32_ID,mgr);
  617. a.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPA_F32_ID,mgr);
  618. b.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPB_F32_ID,mgr);
  619. break;
  620. case TEST_MAT_SUB_F32_2:
  621. input1.reload(UnaryTestsF32::INPUTS1_F32_ID,mgr);
  622. input2.reload(UnaryTestsF32::INPUTS2_F32_ID,mgr);
  623. dims.reload(UnaryTestsF32::DIMSUNARY1_S16_ID,mgr);
  624. ref.reload(UnaryTestsF32::REFSUB1_F32_ID,mgr);
  625. output.create(ref.nbSamples(),UnaryTestsF32::OUT_F32_ID,mgr);
  626. a.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPA_F32_ID,mgr);
  627. b.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPB_F32_ID,mgr);
  628. break;
  629. case TEST_MAT_SCALE_F32_3:
  630. input1.reload(UnaryTestsF32::INPUTS1_F32_ID,mgr);
  631. dims.reload(UnaryTestsF32::DIMSUNARY1_S16_ID,mgr);
  632. ref.reload(UnaryTestsF32::REFSCALE1_F32_ID,mgr);
  633. output.create(ref.nbSamples(),UnaryTestsF32::OUT_F32_ID,mgr);
  634. a.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPA_F32_ID,mgr);
  635. break;
  636. case TEST_MAT_TRANS_F32_4:
  637. input1.reload(UnaryTestsF32::INPUTS1_F32_ID,mgr);
  638. dims.reload(UnaryTestsF32::DIMSUNARY1_S16_ID,mgr);
  639. ref.reload(UnaryTestsF32::REFTRANS1_F32_ID,mgr);
  640. output.create(ref.nbSamples(),UnaryTestsF32::OUT_F32_ID,mgr);
  641. a.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPA_F32_ID,mgr);
  642. break;
  643. case TEST_MAT_INVERSE_F32_5:
  644. input1.reload(UnaryTestsF32::INPUTSINV_F32_ID,mgr);
  645. dims.reload(UnaryTestsF32::DIMSINVERT1_S16_ID,mgr);
  646. ref.reload(UnaryTestsF32::REFINV1_F32_ID,mgr);
  647. output.create(ref.nbSamples(),UnaryTestsF32::OUT_F32_ID,mgr);
  648. a.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPA_F32_ID,mgr);
  649. break;
  650. case TEST_MAT_VEC_MULT_F32_6:
  651. input1.reload(UnaryTestsF32::INPUTS1_F32_ID,mgr);
  652. input2.reload(UnaryTestsF32::INPUTVEC1_F32_ID,mgr);
  653. dims.reload(UnaryTestsF32::DIMSUNARY1_S16_ID,mgr);
  654. ref.reload(UnaryTestsF32::REFVECMUL1_F32_ID,mgr);
  655. output.create(ref.nbSamples(),UnaryTestsF32::OUT_F32_ID,mgr);
  656. a.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPA_F32_ID,mgr);
  657. b.create(MAXMATRIXDIM,UnaryTestsF32::TMPB_F32_ID,mgr);
  658. break;
  659. case TEST_MAT_CMPLX_TRANS_F32_7:
  660. input1.reload(UnaryTestsF32::INPUTSC1_F32_ID,mgr);
  661. dims.reload(UnaryTestsF32::DIMSUNARY1_S16_ID,mgr);
  662. ref.reload(UnaryTestsF32::REFTRANSC1_F32_ID,mgr);
  663. output.create(ref.nbSamples(),UnaryTestsF32::OUT_F32_ID,mgr);
  664. a.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPA_F32_ID,mgr);
  665. break;
  666. case TEST_MAT_CHOLESKY_DPO_F32_8:
  667. input1.reload(UnaryTestsF32::INPUTSCHOLESKY1_DPO_F32_ID,mgr);
  668. dims.reload(UnaryTestsF32::DIMSCHOLESKY1_DPO_S16_ID,mgr);
  669. ref.reload(UnaryTestsF32::REFCHOLESKY1_DPO_F32_ID,mgr);
  670. output.create(ref.nbSamples(),UnaryTestsF32::OUT_F32_ID,mgr);
  671. a.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPA_F32_ID,mgr);
  672. break;
  673. case TEST_SOLVE_UPPER_TRIANGULAR_F32_9:
  674. input1.reload(UnaryTestsF32::INPUT_MAT_UTSOLVE_F32_ID,mgr);
  675. input2.reload(UnaryTestsF32::INPUT_VEC_LTSOLVE_F32_ID,mgr);
  676. dims.reload(UnaryTestsF32::DIM_LTSOLVE_F32_ID,mgr);
  677. ref.reload(UnaryTestsF32::REF_UT_SOLVE_F32_ID,mgr);
  678. output.create(ref.nbSamples(),UnaryTestsF32::OUT_F32_ID,mgr);
  679. a.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPA_F32_ID,mgr);
  680. b.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPB_F32_ID,mgr);
  681. break;
  682. case TEST_SOLVE_LOWER_TRIANGULAR_F32_10:
  683. input1.reload(UnaryTestsF32::INPUT_MAT_LTSOLVE_F32_ID,mgr);
  684. input2.reload(UnaryTestsF32::INPUT_VEC_LTSOLVE_F32_ID,mgr);
  685. dims.reload(UnaryTestsF32::DIM_LTSOLVE_F32_ID,mgr);
  686. ref.reload(UnaryTestsF32::REF_LT_SOLVE_F32_ID,mgr);
  687. output.create(ref.nbSamples(),UnaryTestsF32::OUT_F32_ID,mgr);
  688. a.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPA_F32_ID,mgr);
  689. b.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPB_F32_ID,mgr);
  690. break;
  691. case TEST_MAT_LDL_F32_11:
  692. // Definite positive test
  693. input1.reload(UnaryTestsF32::INPUTSCHOLESKY1_DPO_F32_ID,mgr);
  694. dims.reload(UnaryTestsF32::DIMSCHOLESKY1_DPO_S16_ID,mgr);
  695. outputll.create(input1.nbSamples(),UnaryTestsF32::LL_F32_ID,mgr);
  696. outputd.create(input1.nbSamples(),UnaryTestsF32::D_F32_ID,mgr);
  697. outputp.create(input1.nbSamples(),UnaryTestsF32::PERM_S16_ID,mgr);
  698. outputa.create(input1.nbSamples(),UnaryTestsF32::OUTA_F64_ID,mgr);
  699. outputb.create(input1.nbSamples(),UnaryTestsF32::OUTB_F64_ID,mgr);
  700. a.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPA_F32_ID,mgr);
  701. tmpapat.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPB_F64_ID,mgr);
  702. tmpbpat.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPC_F64_ID,mgr);
  703. tmpcpat.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPD_F64_ID,mgr);
  704. this->snrRel=REL_ERROR_LDLT;
  705. this->snrAbs=ABS_ERROR_LDLT;
  706. break;
  707. case TEST_MAT_LDL_F32_12:
  708. // Semi definite positive test
  709. input1.reload(UnaryTestsF32::INPUTSCHOLESKY1_SDPO_F32_ID,mgr);
  710. dims.reload(UnaryTestsF32::DIMSCHOLESKY1_SDPO_S16_ID,mgr);
  711. outputll.create(input1.nbSamples(),UnaryTestsF32::LL_F32_ID,mgr);
  712. outputd.create(input1.nbSamples(),UnaryTestsF32::D_F32_ID,mgr);
  713. outputp.create(input1.nbSamples(),UnaryTestsF32::PERM_S16_ID,mgr);
  714. outputa.create(input1.nbSamples(),UnaryTestsF32::OUTA_F64_ID,mgr);
  715. outputb.create(input1.nbSamples(),UnaryTestsF32::OUTB_F64_ID,mgr);
  716. a.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPA_F32_ID,mgr);
  717. tmpapat.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPB_F64_ID,mgr);
  718. tmpbpat.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPC_F64_ID,mgr);
  719. tmpcpat.create(MAXMATRIXDIM*MAXMATRIXDIM,UnaryTestsF32::TMPD_F64_ID,mgr);
  720. this->snrRel=REL_ERROR_LDLT_SPDO;
  721. this->snrAbs=ABS_ERROR_LDLT_SDPO;
  722. break;
  723. case TEST_HOUSEHOLDER_F32_13:
  724. input1.reload(UnaryTestsF32::INPUTS_HOUSEHOLDER_F32_ID,mgr);
  725. dims.reload(UnaryTestsF32::DIMS_HOUSEHOLDER_S16_ID,mgr);
  726. ref.reload(UnaryTestsF32::REF_HOUSEHOLDER_V_F32_ID,mgr);
  727. refBeta.reload(UnaryTestsF32::REF_HOUSEHOLDER_BETA_F32_ID,mgr);
  728. output.create(ref.nbSamples(),UnaryTestsF32::TMPA_F32_ID,mgr);
  729. outputBeta.create(refBeta.nbSamples(),UnaryTestsF32::TMPB_F32_ID,mgr);
  730. break;
  731. case TEST_MAT_QR_F32_14:
  732. input1.reload(UnaryTestsF32::INPUTS_QR_F32_ID,mgr);
  733. dims.reload(UnaryTestsF32::DIMS_QR_S16_ID,mgr);
  734. refTau.reload(UnaryTestsF32::REF_QR_TAU_F32_ID,mgr);
  735. refR.reload(UnaryTestsF32::REF_QR_R_F32_ID,mgr);
  736. refQ.reload(UnaryTestsF32::REF_QR_Q_F32_ID,mgr);
  737. outputTau.create(refTau.nbSamples(),UnaryTestsF32::TMPA_F32_ID,mgr);
  738. outputR.create(refR.nbSamples(),UnaryTestsF32::TMPB_F32_ID,mgr);
  739. outputQ.create(refQ.nbSamples(),UnaryTestsF32::TMPC_F32_ID,mgr);
  740. a.create(47,UnaryTestsF32::TMPC_F32_ID,mgr);
  741. b.create(47,UnaryTestsF32::TMPD_F32_ID,mgr);
  742. break;
  743. }
  744. }
  745. void UnaryTestsF32::tearDown(Testing::testID_t id,Client::PatternMgr *mgr)
  746. {
  747. (void)id;
  748. (void)mgr;
  749. switch(id)
  750. {
  751. case TEST_MAT_LDL_F32_11:
  752. //outputll.dump(mgr);
  753. break;
  754. case TEST_MAT_QR_F32_14:
  755. //outputR.dump(mgr);
  756. break;
  757. }
  758. //output.dump(mgr);
  759. }