StatsTestsF32.cpp 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. #include "StatsTestsF32.h"
  2. #include "Error.h"
  3. #include "arm_math.h"
  4. #include "Test.h"
  5. #include <cstdio>
  6. void StatsTestsF32::test_entropy_f32()
  7. {
  8. const float32_t *inp = inputA.ptr();
  9. float32_t *refp = ref.ptr();
  10. float32_t *outp = output.ptr();
  11. float32_t *result;
  12. for(int i=0;i < this->nbPatterns; i++)
  13. {
  14. *outp = arm_entropy_f32(inp,this->vecDim);
  15. outp++;
  16. inp += vecDim;
  17. }
  18. ASSERT_NEAR_EQ(ref,output,(float32_t)1e-6);
  19. }
  20. void StatsTestsF32::test_logsumexp_f32()
  21. {
  22. const float32_t *inp = inputA.ptr();
  23. float32_t *refp = ref.ptr();
  24. float32_t *outp = output.ptr();
  25. float32_t *result;
  26. for(int i=0;i < this->nbPatterns; i++)
  27. {
  28. *outp = arm_logsumexp_f32(inp,this->vecDim);
  29. outp++;
  30. inp += vecDim;
  31. }
  32. ASSERT_NEAR_EQ(ref,output,(float32_t)1e-6);
  33. }
  34. void StatsTestsF32::test_kullback_leibler_f32()
  35. {
  36. const float32_t *inpA = inputA.ptr();
  37. const float32_t *inpB = inputB.ptr();
  38. float32_t *refp = ref.ptr();
  39. float32_t *outp = output.ptr();
  40. float32_t *result;
  41. for(int i=0;i < this->nbPatterns; i++)
  42. {
  43. *outp = arm_kullback_leibler_f32(inpA,inpB,this->vecDim);
  44. outp++;
  45. inpA += vecDim;
  46. inpB += vecDim;
  47. }
  48. ASSERT_NEAR_EQ(ref,output,(float32_t)1e-6);
  49. }
  50. void StatsTestsF32::test_logsumexp_dot_prod_f32()
  51. {
  52. const float32_t *inpA = inputA.ptr();
  53. const float32_t *inpB = inputB.ptr();
  54. float32_t *refp = ref.ptr();
  55. float32_t *outp = output.ptr();
  56. float32_t *tmpp = tmp.ptr();
  57. float32_t *result;
  58. for(int i=0;i < this->nbPatterns; i++)
  59. {
  60. *outp = arm_logsumexp_dot_prod_f32(inpA,inpB,this->vecDim,tmpp);
  61. outp++;
  62. inpA += vecDim;
  63. inpB += vecDim;
  64. }
  65. ASSERT_NEAR_EQ(ref,output,(float32_t)1e-6);
  66. }
  67. void StatsTestsF32::setUp(Testing::testID_t id,std::vector<Testing::param_t>& paramsArgs,Client::PatternMgr *mgr)
  68. {
  69. switch(id)
  70. {
  71. case StatsTestsF32::TEST_ENTROPY_F32_1:
  72. {
  73. inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr);
  74. dims.reload(StatsTestsF32::DIM1_S16_ID,mgr);
  75. ref.reload(StatsTestsF32::REF1_ENTROPY_F32_ID,mgr);
  76. output.create(ref.nbSamples(),StatsTestsF32::OUT_F32_ID,mgr);
  77. const int16_t *dimsp = dims.ptr();
  78. this->nbPatterns=dimsp[0];
  79. this->vecDim=dimsp[1];
  80. }
  81. break;
  82. case StatsTestsF32::TEST_LOGSUMEXP_F32_2:
  83. {
  84. inputA.reload(StatsTestsF32::INPUT2_F32_ID,mgr);
  85. dims.reload(StatsTestsF32::DIM2_S16_ID,mgr);
  86. ref.reload(StatsTestsF32::REF2_LOGSUMEXP_F32_ID,mgr);
  87. output.create(ref.nbSamples(),StatsTestsF32::OUT_F32_ID,mgr);
  88. const int16_t *dimsp = dims.ptr();
  89. this->nbPatterns=dimsp[0];
  90. this->vecDim=dimsp[1];
  91. }
  92. break;
  93. case StatsTestsF32::TEST_KULLBACK_LEIBLER_F32_3:
  94. {
  95. inputA.reload(StatsTestsF32::INPUTA3_F32_ID,mgr);
  96. inputB.reload(StatsTestsF32::INPUTB3_F32_ID,mgr);
  97. dims.reload(StatsTestsF32::DIM3_S16_ID,mgr);
  98. ref.reload(StatsTestsF32::REF3_KL_F32_ID,mgr);
  99. output.create(ref.nbSamples(),StatsTestsF32::OUT_F32_ID,mgr);
  100. const int16_t *dimsp = dims.ptr();
  101. this->nbPatterns=dimsp[0];
  102. this->vecDim=dimsp[1];
  103. }
  104. break;
  105. case StatsTestsF32::TEST_LOGSUMEXP_DOT_PROD_F32_4:
  106. {
  107. inputA.reload(StatsTestsF32::INPUTA4_F32_ID,mgr);
  108. inputB.reload(StatsTestsF32::INPUTB4_F32_ID,mgr);
  109. dims.reload(StatsTestsF32::DIM4_S16_ID,mgr);
  110. ref.reload(StatsTestsF32::REF4_LOGSUMEXP_DOT_F32_ID,mgr);
  111. output.create(ref.nbSamples(),StatsTestsF32::OUT_F32_ID,mgr);
  112. const int16_t *dimsp = dims.ptr();
  113. this->nbPatterns=dimsp[0];
  114. this->vecDim=dimsp[1];
  115. tmp.create(this->vecDim,StatsTestsF32::TMP_F32_ID,mgr);
  116. }
  117. break;
  118. }
  119. }
  120. void StatsTestsF32::tearDown(Testing::testID_t id,Client::PatternMgr *mgr)
  121. {
  122. output.dump(mgr);
  123. }