BayesF32.cpp 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. #include "BayesF32.h"
  2. #include <stdio.h>
  3. #include "Error.h"
  4. #include "Test.h"
  5. void BayesF32::test_gaussian_naive_bayes_predict_f32()
  6. {
  7. const float32_t *inp = input.ptr();
  8. float32_t *bufp = outputProbas.ptr();
  9. float32_t *tempp = temp.ptr();
  10. int16_t *p = outputPredicts.ptr();
  11. for(int i=0; i < this->nbPatterns ; i ++)
  12. {
  13. *p = arm_gaussian_naive_bayes_predict_f32(&bayes,
  14. inp,
  15. bufp,tempp);
  16. inp += this->vecDim;
  17. bufp += this->classNb;
  18. p++;
  19. }
  20. ASSERT_REL_ERROR(outputProbas,probas,(float32_t)5e-6);
  21. ASSERT_EQ(outputPredicts,predicts);
  22. }
  23. void BayesF32::setUp(Testing::testID_t id,std::vector<Testing::param_t>& paramsArgs,Client::PatternMgr *mgr)
  24. {
  25. (void)paramsArgs;
  26. switch(id)
  27. {
  28. case BayesF32::TEST_GAUSSIAN_NAIVE_BAYES_PREDICT_F32_1:
  29. input.reload(BayesF32::INPUTS1_F32_ID,mgr);
  30. params.reload(BayesF32::PARAMS1_F32_ID,mgr);
  31. dims.reload(BayesF32::DIMS1_S16_ID,mgr);
  32. const int16_t *dimsp=dims.ptr();
  33. const float32_t *paramsp = params.ptr();
  34. this->nbPatterns=dimsp[0];
  35. this->classNb=dimsp[1];
  36. this->vecDim=dimsp[2];
  37. this->theta=paramsp;
  38. this->sigma=paramsp + (this->classNb * this->vecDim);
  39. this->classPrior=paramsp + 2*(this->classNb * this->vecDim);
  40. this->epsilon=paramsp[this->classNb + 2*(this->classNb * this->vecDim)];
  41. //printf("%f %f %f\n",this->theta[0],this->sigma[0],this->classPrior[0]);
  42. // Reference patterns are not loaded when we are in dump mode
  43. probas.reload(BayesF32::PROBAS1_F32_ID,mgr);
  44. predicts.reload(BayesF32::PREDICTS1_S16_ID,mgr);
  45. outputProbas.create(this->nbPatterns*this->classNb,BayesF32::OUT_PROBA_F32_ID,mgr);
  46. temp.create(this->nbPatterns*this->classNb,BayesF32::OUT_PROBA_F32_ID,mgr);
  47. outputPredicts.create(this->nbPatterns,BayesF32::OUT_PREDICT_S16_ID,mgr);
  48. bayes.vectorDimension=this->vecDim;
  49. bayes.numberOfClasses=this->classNb;
  50. bayes.theta=this->theta;
  51. bayes.sigma=this->sigma;
  52. bayes.classPriors=this->classPrior;
  53. bayes.epsilon=this->epsilon;
  54. break;
  55. }
  56. }
  57. void BayesF32::tearDown(Testing::testID_t id,Client::PatternMgr *mgr)
  58. {
  59. (void)id;
  60. outputProbas.dump(mgr);
  61. outputPredicts.dump(mgr);
  62. }