StatsTestsF32.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665
  1. #include "StatsTestsF32.h"
  2. #include <stdio.h>
  3. #include "Error.h"
  4. #include "arm_math.h"
  5. #include "Test.h"
  6. #define SNR_THRESHOLD 120
  7. /*
  8. Reference patterns are generated with
  9. a double precision computation.
  10. */
  11. #define REL_ERROR (1.0e-5)
  12. void StatsTestsF32::test_max_f32()
  13. {
  14. const float32_t *inp = inputA.ptr();
  15. float32_t result;
  16. uint32_t indexval;
  17. float32_t *refp = ref.ptr();
  18. int16_t *refind = maxIndexes.ptr();
  19. float32_t *outp = output.ptr();
  20. int16_t *ind = index.ptr();
  21. arm_max_f32(inp,
  22. inputA.nbSamples(),
  23. &result,
  24. &indexval);
  25. outp[0] = result;
  26. ind[0] = indexval;
  27. ASSERT_EQ(result,refp[this->refOffset]);
  28. ASSERT_EQ((int16_t)indexval,refind[this->refOffset]);
  29. }
  30. void StatsTestsF32::test_max_no_idx_f32()
  31. {
  32. const float32_t *inp = inputA.ptr();
  33. float32_t result;
  34. float32_t *refp = ref.ptr();
  35. float32_t *outp = output.ptr();
  36. arm_max_no_idx_f32(inp,
  37. inputA.nbSamples(),
  38. &result);
  39. outp[0] = result;
  40. ASSERT_EQ(result,refp[this->refOffset]);
  41. }
  42. void StatsTestsF32::test_min_f32()
  43. {
  44. const float32_t *inp = inputA.ptr();
  45. float32_t result;
  46. uint32_t indexval;
  47. float32_t *refp = ref.ptr();
  48. int16_t *refind = minIndexes.ptr();
  49. float32_t *outp = output.ptr();
  50. int16_t *ind = index.ptr();
  51. arm_min_f32(inp,
  52. inputA.nbSamples(),
  53. &result,
  54. &indexval);
  55. outp[0] = result;
  56. ind[0] = indexval;
  57. ASSERT_EQ(result,refp[this->refOffset]);
  58. ASSERT_EQ((int16_t)indexval,refind[this->refOffset]);
  59. }
  60. void StatsTestsF32::test_mean_f32()
  61. {
  62. const float32_t *inp = inputA.ptr();
  63. float32_t result;
  64. float32_t *refp = ref.ptr();
  65. float32_t *outp = output.ptr();
  66. arm_mean_f32(inp,
  67. inputA.nbSamples(),
  68. &result);
  69. outp[0] = result;
  70. ASSERT_SNR(result,refp[this->refOffset],(float32_t)SNR_THRESHOLD);
  71. ASSERT_REL_ERROR(result,refp[this->refOffset],REL_ERROR);
  72. }
  73. void StatsTestsF32::test_power_f32()
  74. {
  75. const float32_t *inp = inputA.ptr();
  76. float32_t result;
  77. float32_t *refp = ref.ptr();
  78. float32_t *outp = output.ptr();
  79. arm_power_f32(inp,
  80. inputA.nbSamples(),
  81. &result);
  82. outp[0] = result;
  83. ASSERT_SNR(result,refp[this->refOffset],(float32_t)SNR_THRESHOLD);
  84. ASSERT_REL_ERROR(result,refp[this->refOffset],REL_ERROR);
  85. }
  86. void StatsTestsF32::test_rms_f32()
  87. {
  88. const float32_t *inp = inputA.ptr();
  89. float32_t result;
  90. float32_t *refp = ref.ptr();
  91. float32_t *outp = output.ptr();
  92. arm_rms_f32(inp,
  93. inputA.nbSamples(),
  94. &result);
  95. outp[0] = result;
  96. ASSERT_SNR(result,refp[this->refOffset],(float32_t)SNR_THRESHOLD);
  97. ASSERT_REL_ERROR(result,refp[this->refOffset],REL_ERROR);
  98. }
  99. void StatsTestsF32::test_std_f32()
  100. {
  101. const float32_t *inp = inputA.ptr();
  102. float32_t result;
  103. float32_t *refp = ref.ptr();
  104. float32_t *outp = output.ptr();
  105. arm_std_f32(inp,
  106. inputA.nbSamples(),
  107. &result);
  108. outp[0] = result;
  109. ASSERT_SNR(result,refp[this->refOffset],(float32_t)SNR_THRESHOLD);
  110. ASSERT_REL_ERROR(result,refp[this->refOffset],REL_ERROR);
  111. }
  112. void StatsTestsF32::test_var_f32()
  113. {
  114. const float32_t *inp = inputA.ptr();
  115. float32_t result;
  116. float32_t *refp = ref.ptr();
  117. float32_t *outp = output.ptr();
  118. arm_var_f32(inp,
  119. inputA.nbSamples(),
  120. &result);
  121. outp[0] = result;
  122. ASSERT_SNR(result,refp[this->refOffset],(float32_t)SNR_THRESHOLD);
  123. ASSERT_REL_ERROR(result,refp[this->refOffset],REL_ERROR);
  124. }
  125. void StatsTestsF32::test_entropy_f32()
  126. {
  127. const float32_t *inp = inputA.ptr();
  128. const int16_t *dimsp = dims.ptr();
  129. float32_t *refp = ref.ptr();
  130. float32_t *outp = output.ptr();
  131. for(int i=0;i < this->nbPatterns; i++)
  132. {
  133. *outp = arm_entropy_f32(inp,dimsp[i+1]);
  134. outp++;
  135. inp += dimsp[i+1];
  136. }
  137. ASSERT_SNR(ref,output,(float32_t)SNR_THRESHOLD);
  138. ASSERT_REL_ERROR(ref,output,REL_ERROR);
  139. }
  140. void StatsTestsF32::test_logsumexp_f32()
  141. {
  142. const float32_t *inp = inputA.ptr();
  143. const int16_t *dimsp = dims.ptr();
  144. float32_t *refp = ref.ptr();
  145. float32_t *outp = output.ptr();
  146. for(int i=0;i < this->nbPatterns; i++)
  147. {
  148. *outp = arm_logsumexp_f32(inp,dimsp[i+1]);
  149. outp++;
  150. inp += dimsp[i+1];
  151. }
  152. ASSERT_SNR(ref,output,(float32_t)SNR_THRESHOLD);
  153. ASSERT_REL_ERROR(ref,output,REL_ERROR);
  154. }
  155. void StatsTestsF32::test_kullback_leibler_f32()
  156. {
  157. const float32_t *inpA = inputA.ptr();
  158. const float32_t *inpB = inputB.ptr();
  159. const int16_t *dimsp = dims.ptr();
  160. float32_t *refp = ref.ptr();
  161. float32_t *outp = output.ptr();
  162. for(int i=0;i < this->nbPatterns; i++)
  163. {
  164. *outp = arm_kullback_leibler_f32(inpA,inpB,dimsp[i+1]);
  165. outp++;
  166. inpA += dimsp[i+1];
  167. inpB += dimsp[i+1];
  168. }
  169. ASSERT_SNR(ref,output,(float32_t)SNR_THRESHOLD);
  170. ASSERT_REL_ERROR(ref,output,REL_ERROR);
  171. }
  172. void StatsTestsF32::test_logsumexp_dot_prod_f32()
  173. {
  174. const float32_t *inpA = inputA.ptr();
  175. const float32_t *inpB = inputB.ptr();
  176. const int16_t *dimsp = dims.ptr();
  177. float32_t *refp = ref.ptr();
  178. float32_t *outp = output.ptr();
  179. float32_t *tmpp = tmp.ptr();
  180. for(int i=0;i < this->nbPatterns; i++)
  181. {
  182. *outp = arm_logsumexp_dot_prod_f32(inpA,inpB,dimsp[i+1],tmpp);
  183. outp++;
  184. inpA += dimsp[i+1];
  185. inpB += dimsp[i+1];
  186. }
  187. ASSERT_SNR(ref,output,(float32_t)SNR_THRESHOLD);
  188. ASSERT_REL_ERROR(ref,output,REL_ERROR);
  189. }
  190. void StatsTestsF32::setUp(Testing::testID_t id,std::vector<Testing::param_t>& paramsArgs,Client::PatternMgr *mgr)
  191. {
  192. switch(id)
  193. {
  194. case StatsTestsF32::TEST_MAX_F32_1:
  195. {
  196. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,3);
  197. maxIndexes.reload(StatsTestsF32::MAXINDEXES_S16_ID,mgr);
  198. ref.reload(StatsTestsF32::MAXVALS_F32_ID,mgr);
  199. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  200. index.create(1,StatsTestsF32::OUT_S16_ID,mgr);
  201. refOffset = 0;
  202. }
  203. break;
  204. case StatsTestsF32::TEST_MAX_F32_2:
  205. {
  206. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,8);
  207. maxIndexes.reload(StatsTestsF32::MAXINDEXES_S16_ID,mgr);
  208. ref.reload(StatsTestsF32::MAXVALS_F32_ID,mgr);
  209. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  210. index.create(1,StatsTestsF32::OUT_S16_ID,mgr);
  211. refOffset = 1;
  212. }
  213. break;
  214. case StatsTestsF32::TEST_MAX_F32_3:
  215. {
  216. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,11);
  217. maxIndexes.reload(StatsTestsF32::MAXINDEXES_S16_ID,mgr);
  218. ref.reload(StatsTestsF32::MAXVALS_F32_ID,mgr);
  219. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  220. index.create(1,StatsTestsF32::OUT_S16_ID,mgr);
  221. refOffset = 2;
  222. }
  223. break;
  224. case StatsTestsF32::TEST_MEAN_F32_4:
  225. {
  226. inputA.reload(StatsTestsF32::INPUT2_F32_ID,mgr,3);
  227. ref.reload(StatsTestsF32::MEANVALS_F32_ID,mgr);
  228. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  229. refOffset = 0;
  230. }
  231. break;
  232. case StatsTestsF32::TEST_MEAN_F32_5:
  233. {
  234. inputA.reload(StatsTestsF32::INPUT2_F32_ID,mgr,8);
  235. ref.reload(StatsTestsF32::MEANVALS_F32_ID,mgr);
  236. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  237. refOffset = 1;
  238. }
  239. break;
  240. case StatsTestsF32::TEST_MEAN_F32_6:
  241. {
  242. inputA.reload(StatsTestsF32::INPUT2_F32_ID,mgr,11);
  243. ref.reload(StatsTestsF32::MEANVALS_F32_ID,mgr);
  244. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  245. refOffset = 2;
  246. }
  247. break;
  248. case StatsTestsF32::TEST_MIN_F32_7:
  249. {
  250. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,3);
  251. minIndexes.reload(StatsTestsF32::MININDEXES_S16_ID,mgr);
  252. ref.reload(StatsTestsF32::MINVALS_F32_ID,mgr);
  253. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  254. index.create(1,StatsTestsF32::OUT_S16_ID,mgr);
  255. refOffset = 0;
  256. }
  257. break;
  258. case StatsTestsF32::TEST_MIN_F32_8:
  259. {
  260. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,8);
  261. minIndexes.reload(StatsTestsF32::MININDEXES_S16_ID,mgr);
  262. ref.reload(StatsTestsF32::MINVALS_F32_ID,mgr);
  263. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  264. index.create(1,StatsTestsF32::OUT_S16_ID,mgr);
  265. refOffset = 1;
  266. }
  267. break;
  268. case StatsTestsF32::TEST_MIN_F32_9:
  269. {
  270. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,11);
  271. minIndexes.reload(StatsTestsF32::MININDEXES_S16_ID,mgr);
  272. ref.reload(StatsTestsF32::MINVALS_F32_ID,mgr);
  273. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  274. index.create(1,StatsTestsF32::OUT_S16_ID,mgr);
  275. refOffset = 2;
  276. }
  277. break;
  278. case StatsTestsF32::TEST_POWER_F32_10:
  279. {
  280. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,3);
  281. ref.reload(StatsTestsF32::POWERVALS_F32_ID,mgr);
  282. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  283. refOffset = 0;
  284. }
  285. break;
  286. case StatsTestsF32::TEST_POWER_F32_11:
  287. {
  288. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,8);
  289. ref.reload(StatsTestsF32::POWERVALS_F32_ID,mgr);
  290. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  291. refOffset = 1;
  292. }
  293. break;
  294. case StatsTestsF32::TEST_POWER_F32_12:
  295. {
  296. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,11);
  297. ref.reload(StatsTestsF32::POWERVALS_F32_ID,mgr);
  298. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  299. refOffset = 2;
  300. }
  301. break;
  302. case StatsTestsF32::TEST_RMS_F32_13:
  303. {
  304. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,3);
  305. ref.reload(StatsTestsF32::RMSVALS_F32_ID,mgr);
  306. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  307. refOffset = 0;
  308. }
  309. break;
  310. case StatsTestsF32::TEST_RMS_F32_14:
  311. {
  312. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,8);
  313. ref.reload(StatsTestsF32::RMSVALS_F32_ID,mgr);
  314. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  315. refOffset = 1;
  316. }
  317. break;
  318. case StatsTestsF32::TEST_RMS_F32_15:
  319. {
  320. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,11);
  321. ref.reload(StatsTestsF32::RMSVALS_F32_ID,mgr);
  322. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  323. refOffset = 2;
  324. }
  325. break;
  326. case StatsTestsF32::TEST_STD_F32_16:
  327. {
  328. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,3);
  329. ref.reload(StatsTestsF32::STDVALS_F32_ID,mgr);
  330. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  331. refOffset = 0;
  332. }
  333. break;
  334. case StatsTestsF32::TEST_STD_F32_17:
  335. {
  336. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,8);
  337. ref.reload(StatsTestsF32::STDVALS_F32_ID,mgr);
  338. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  339. refOffset = 1;
  340. }
  341. break;
  342. case StatsTestsF32::TEST_STD_F32_18:
  343. {
  344. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,11);
  345. ref.reload(StatsTestsF32::STDVALS_F32_ID,mgr);
  346. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  347. refOffset = 2;
  348. }
  349. break;
  350. case StatsTestsF32::TEST_VAR_F32_19:
  351. {
  352. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,3);
  353. ref.reload(StatsTestsF32::VARVALS_F32_ID,mgr);
  354. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  355. refOffset = 0;
  356. }
  357. break;
  358. case StatsTestsF32::TEST_VAR_F32_20:
  359. {
  360. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,8);
  361. ref.reload(StatsTestsF32::VARVALS_F32_ID,mgr);
  362. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  363. refOffset = 1;
  364. }
  365. break;
  366. case StatsTestsF32::TEST_VAR_F32_21:
  367. {
  368. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,11);
  369. ref.reload(StatsTestsF32::VARVALS_F32_ID,mgr);
  370. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  371. refOffset = 2;
  372. }
  373. break;
  374. case StatsTestsF32::TEST_ENTROPY_F32_22:
  375. {
  376. inputA.reload(StatsTestsF32::INPUT22_F32_ID,mgr);
  377. dims.reload(StatsTestsF32::DIM22_S16_ID,mgr);
  378. ref.reload(StatsTestsF32::REF22_ENTROPY_F32_ID,mgr);
  379. output.create(ref.nbSamples(),StatsTestsF32::OUT_F32_ID,mgr);
  380. const int16_t *dimsp = dims.ptr();
  381. this->nbPatterns=dimsp[0];
  382. }
  383. break;
  384. case StatsTestsF32::TEST_LOGSUMEXP_F32_23:
  385. {
  386. inputA.reload(StatsTestsF32::INPUT23_F32_ID,mgr);
  387. dims.reload(StatsTestsF32::DIM23_S16_ID,mgr);
  388. ref.reload(StatsTestsF32::REF23_LOGSUMEXP_F32_ID,mgr);
  389. output.create(ref.nbSamples(),StatsTestsF32::OUT_F32_ID,mgr);
  390. const int16_t *dimsp = dims.ptr();
  391. this->nbPatterns=dimsp[0];
  392. }
  393. break;
  394. case StatsTestsF32::TEST_KULLBACK_LEIBLER_F32_24:
  395. {
  396. inputA.reload(StatsTestsF32::INPUTA24_F32_ID,mgr);
  397. inputB.reload(StatsTestsF32::INPUTB24_F32_ID,mgr);
  398. dims.reload(StatsTestsF32::DIM24_S16_ID,mgr);
  399. ref.reload(StatsTestsF32::REF24_KL_F32_ID,mgr);
  400. output.create(ref.nbSamples(),StatsTestsF32::OUT_F32_ID,mgr);
  401. const int16_t *dimsp = dims.ptr();
  402. this->nbPatterns=dimsp[0];
  403. }
  404. break;
  405. case StatsTestsF32::TEST_LOGSUMEXP_DOT_PROD_F32_25:
  406. {
  407. inputA.reload(StatsTestsF32::INPUTA25_F32_ID,mgr);
  408. inputB.reload(StatsTestsF32::INPUTB25_F32_ID,mgr);
  409. dims.reload(StatsTestsF32::DIM25_S16_ID,mgr);
  410. ref.reload(StatsTestsF32::REF25_LOGSUMEXP_DOT_F32_ID,mgr);
  411. output.create(ref.nbSamples(),StatsTestsF32::OUT_F32_ID,mgr);
  412. const int16_t *dimsp = dims.ptr();
  413. this->nbPatterns=dimsp[0];
  414. /* 12 is max vecDim as defined in Python script generating the data */
  415. tmp.create(12,StatsTestsF32::TMP_F32_ID,mgr);
  416. }
  417. break;
  418. case StatsTestsF32::TEST_MAX_NO_IDX_F32_26:
  419. {
  420. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,3);
  421. ref.reload(StatsTestsF32::MAXVALS_F32_ID,mgr);
  422. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  423. refOffset = 0;
  424. }
  425. break;
  426. case StatsTestsF32::TEST_MAX_NO_IDX_F32_27:
  427. {
  428. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,8);
  429. ref.reload(StatsTestsF32::MAXVALS_F32_ID,mgr);
  430. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  431. refOffset = 1;
  432. }
  433. break;
  434. case StatsTestsF32::TEST_MAX_NO_IDX_F32_28:
  435. {
  436. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr,11);
  437. ref.reload(StatsTestsF32::MAXVALS_F32_ID,mgr);
  438. output.create(1,StatsTestsF32::OUT_F32_ID,mgr);
  439. refOffset = 2;
  440. }
  441. break;
  442. }
  443. }
  444. void StatsTestsF32::tearDown(Testing::testID_t id,Client::PatternMgr *mgr)
  445. {
  446. switch(id)
  447. {
  448. case StatsTestsF32::TEST_MAX_F32_1:
  449. case StatsTestsF32::TEST_MAX_F32_2:
  450. case StatsTestsF32::TEST_MAX_F32_3:
  451. case StatsTestsF32::TEST_MIN_F32_7:
  452. case StatsTestsF32::TEST_MIN_F32_8:
  453. case StatsTestsF32::TEST_MIN_F32_9:
  454. index.dump(mgr);
  455. output.dump(mgr);
  456. break;
  457. default:
  458. output.dump(mgr);
  459. }
  460. }