| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155 |
- #include "StatsTestsF32.h"
- #include "Error.h"
- #include "arm_math.h"
- #include "Test.h"
- #include <cstdio>
- void StatsTestsF32::test_entropy_f32()
- {
- const float32_t *inp = inputA.ptr();
- float32_t *refp = ref.ptr();
- float32_t *outp = output.ptr();
- float32_t *result;
- for(int i=0;i < this->nbPatterns; i++)
- {
- *outp = arm_entropy_f32(inp,this->vecDim);
- outp++;
- inp += vecDim;
- }
- ASSERT_NEAR_EQ(ref,output,(float32_t)1e-6);
- }
- void StatsTestsF32::test_logsumexp_f32()
- {
- const float32_t *inp = inputA.ptr();
- float32_t *refp = ref.ptr();
- float32_t *outp = output.ptr();
- float32_t *result;
- for(int i=0;i < this->nbPatterns; i++)
- {
- *outp = arm_logsumexp_f32(inp,this->vecDim);
- outp++;
- inp += vecDim;
- }
- ASSERT_NEAR_EQ(ref,output,(float32_t)1e-6);
- }
- void StatsTestsF32::test_kullback_leibler_f32()
- {
- const float32_t *inpA = inputA.ptr();
- const float32_t *inpB = inputB.ptr();
- float32_t *refp = ref.ptr();
- float32_t *outp = output.ptr();
- float32_t *result;
- for(int i=0;i < this->nbPatterns; i++)
- {
- *outp = arm_kullback_leibler_f32(inpA,inpB,this->vecDim);
- outp++;
- inpA += vecDim;
- inpB += vecDim;
- }
- ASSERT_NEAR_EQ(ref,output,(float32_t)1e-6);
- }
- void StatsTestsF32::test_logsumexp_dot_prod_f32()
- {
- const float32_t *inpA = inputA.ptr();
- const float32_t *inpB = inputB.ptr();
- float32_t *refp = ref.ptr();
- float32_t *outp = output.ptr();
- float32_t *tmpp = tmp.ptr();
- float32_t *result;
- for(int i=0;i < this->nbPatterns; i++)
- {
- *outp = arm_logsumexp_dot_prod_f32(inpA,inpB,this->vecDim,tmpp);
- outp++;
- inpA += vecDim;
- inpB += vecDim;
- }
- ASSERT_NEAR_EQ(ref,output,(float32_t)1e-6);
- }
-
-
- void StatsTestsF32::setUp(Testing::testID_t id,std::vector<Testing::param_t>& paramsArgs,Client::PatternMgr *mgr)
- {
- switch(id)
- {
- case StatsTestsF32::TEST_ENTROPY_F32_1:
- {
- inputA.reload(StatsTestsF32::INPUT1_F32_ID,mgr);
- dims.reload(StatsTestsF32::DIM1_S16_ID,mgr);
- ref.reload(StatsTestsF32::REF1_ENTROPY_F32_ID,mgr);
- output.create(ref.nbSamples(),StatsTestsF32::OUT_F32_ID,mgr);
- const int16_t *dimsp = dims.ptr();
- this->nbPatterns=dimsp[0];
- this->vecDim=dimsp[1];
- }
- break;
- case StatsTestsF32::TEST_LOGSUMEXP_F32_2:
- {
- inputA.reload(StatsTestsF32::INPUT2_F32_ID,mgr);
- dims.reload(StatsTestsF32::DIM2_S16_ID,mgr);
- ref.reload(StatsTestsF32::REF2_LOGSUMEXP_F32_ID,mgr);
- output.create(ref.nbSamples(),StatsTestsF32::OUT_F32_ID,mgr);
- const int16_t *dimsp = dims.ptr();
- this->nbPatterns=dimsp[0];
- this->vecDim=dimsp[1];
- }
- break;
- case StatsTestsF32::TEST_KULLBACK_LEIBLER_F32_3:
- {
- inputA.reload(StatsTestsF32::INPUTA3_F32_ID,mgr);
- inputB.reload(StatsTestsF32::INPUTB3_F32_ID,mgr);
- dims.reload(StatsTestsF32::DIM3_S16_ID,mgr);
- ref.reload(StatsTestsF32::REF3_KL_F32_ID,mgr);
- output.create(ref.nbSamples(),StatsTestsF32::OUT_F32_ID,mgr);
- const int16_t *dimsp = dims.ptr();
- this->nbPatterns=dimsp[0];
- this->vecDim=dimsp[1];
- }
- break;
- case StatsTestsF32::TEST_LOGSUMEXP_DOT_PROD_F32_4:
- {
- inputA.reload(StatsTestsF32::INPUTA4_F32_ID,mgr);
- inputB.reload(StatsTestsF32::INPUTB4_F32_ID,mgr);
- dims.reload(StatsTestsF32::DIM4_S16_ID,mgr);
- ref.reload(StatsTestsF32::REF4_LOGSUMEXP_DOT_F32_ID,mgr);
- output.create(ref.nbSamples(),StatsTestsF32::OUT_F32_ID,mgr);
- const int16_t *dimsp = dims.ptr();
- this->nbPatterns=dimsp[0];
- this->vecDim=dimsp[1];
- tmp.create(this->vecDim,StatsTestsF32::TMP_F32_ID,mgr);
- }
- break;
- }
-
- }
- void StatsTestsF32::tearDown(Testing::testID_t id,Client::PatternMgr *mgr)
- {
- output.dump(mgr);
- }
|