BayesF16.cpp 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. #include "BayesF16.h"
  2. #include <stdio.h>
  3. #include "Error.h"
  4. #include "Test.h"
  5. #define REL_ERROR ((float16_t)3e-3)
  6. void BayesF16::test_gaussian_naive_bayes_predict_f16()
  7. {
  8. const float16_t *inp = input.ptr();
  9. float16_t *bufp = outputProbas.ptr();
  10. float16_t *tempp = temp.ptr();
  11. int16_t *p = outputPredicts.ptr();
  12. for(int i=0; i < this->nbPatterns ; i ++)
  13. {
  14. *p = arm_gaussian_naive_bayes_predict_f16(&bayes,
  15. inp,
  16. bufp,tempp);
  17. inp += this->vecDim;
  18. bufp += this->classNb;
  19. p++;
  20. }
  21. ASSERT_REL_ERROR(outputProbas,probas,REL_ERROR);
  22. ASSERT_EQ(outputPredicts,predicts);
  23. }
  24. void BayesF16::setUp(Testing::testID_t id,std::vector<Testing::param_t>& paramsArgs,Client::PatternMgr *mgr)
  25. {
  26. (void)paramsArgs;
  27. switch(id)
  28. {
  29. case BayesF16::TEST_GAUSSIAN_NAIVE_BAYES_PREDICT_F16_1:
  30. input.reload(BayesF16::INPUTS1_F16_ID,mgr);
  31. params.reload(BayesF16::PARAMS1_F16_ID,mgr);
  32. dims.reload(BayesF16::DIMS1_S16_ID,mgr);
  33. const int16_t *dimsp=dims.ptr();
  34. const float16_t *paramsp = params.ptr();
  35. this->nbPatterns=dimsp[0];
  36. this->classNb=dimsp[1];
  37. this->vecDim=dimsp[2];
  38. this->theta=paramsp;
  39. this->sigma=paramsp + (this->classNb * this->vecDim);
  40. this->classPrior=paramsp + 2*(this->classNb * this->vecDim);
  41. this->epsilon=paramsp[this->classNb + 2*(this->classNb * this->vecDim)];
  42. //printf("%f %f %f\n",this->theta[0],this->sigma[0],this->classPrior[0]);
  43. // Reference patterns are not loaded when we are in dump mode
  44. probas.reload(BayesF16::PROBAS1_F16_ID,mgr);
  45. predicts.reload(BayesF16::PREDICTS1_S16_ID,mgr);
  46. outputProbas.create(this->nbPatterns*this->classNb,BayesF16::OUT_PROBA_F16_ID,mgr);
  47. temp.create(this->nbPatterns*this->classNb,BayesF16::OUT_PROBA_F16_ID,mgr);
  48. outputPredicts.create(this->nbPatterns,BayesF16::OUT_PREDICT_S16_ID,mgr);
  49. bayes.vectorDimension=this->vecDim;
  50. bayes.numberOfClasses=this->classNb;
  51. bayes.theta=this->theta;
  52. bayes.sigma=this->sigma;
  53. bayes.classPriors=this->classPrior;
  54. bayes.epsilon=this->epsilon;
  55. break;
  56. }
  57. }
  58. void BayesF16::tearDown(Testing::testID_t id,Client::PatternMgr *mgr)
  59. {
  60. (void)id;
  61. outputProbas.dump(mgr);
  62. outputPredicts.dump(mgr);
  63. }