Softmax.cpp 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. #include "Softmax.h"
  2. #include <stdio.h>
  3. #include "Error.h"
  4. #include "arm_nnfunctions.h"
  5. #include "Test.h"
/*

Tests have shown that, compared to a float32 implementation,
there is an average error of 4.2 percent with a standard deviation
of 0.89.

This means that out of 100 batches, about 4 batches will give the
wrong position for the max.

However, the error rate depends strongly on the vector dimension.

Regressions give:

Average error rate = -0.555548 + 0.246918 vecDim
Variance           = -0.0112281 + 0.0382476 vecDim

So for vecDim = 21 we have:

Average error rate      = 4.6 percent
Variance for error rate = 0.8

This data is used to define the threshold for the tests: THRESHOLD is the
largest percentage of samples whose max position may differ from the
reference before a test is reported as failed.

*/
#define THRESHOLD 7.5
  22. int16_t findMaxIndex(q7_t *vec_in, int length)
  23. {
  24. int16_t currentIndex=0;
  25. int16_t i=1;
  26. q7_t currentMax=vec_in[0];
  27. while(i<length)
  28. {
  29. if (vec_in[i] > currentMax)
  30. {
  31. currentMax = vec_in[i];
  32. currentIndex = i;
  33. }
  34. i++;
  35. }
  36. return(currentIndex+1);
  37. }
  38. int16_t differences(int16_t *pa,int16_t *pb, int length)
  39. {
  40. int16_t d=0;
  41. int i=0;
  42. while(i < length)
  43. {
  44. if (*pa != *pb)
  45. {
  46. d++;
  47. }
  48. pa++;
  49. pb++;
  50. i++;
  51. }
  52. return(d);
  53. }
  54. void Softmax::test_softmax_q7()
  55. {
  56. const q7_t *vec_in = input.ptr();
  57. q7_t *pTmp = temp.ptr();
  58. int16_t *pOut = output.ptr();
  59. int16_t maxIndex;
  60. for(int i=0; i <this->nbSamples;i++)
  61. {
  62. arm_softmax_q7(vec_in, this->vecDim, pTmp );
  63. maxIndex=findMaxIndex(pTmp,this->vecDim);
  64. *pOut++ = maxIndex;
  65. vec_in += this->vecDim;
  66. pTmp += this->vecDim;
  67. }
  68. int diff = differences(ref.ptr(),output.ptr(),this->nbSamples);
  69. ASSERT_TRUE(100.0*diff/this->nbSamples <= THRESHOLD);
  70. }
  71. void Softmax::test_softmax_with_batch_q7()
  72. {
  73. const q7_t *vec_in = input.ptr();
  74. q7_t *pTmp = temp.ptr();
  75. int16_t *pOut = output.ptr();
  76. int16_t maxIndex;
  77. arm_softmax_with_batch_q7(vec_in, this->nbSamples,this->vecDim, pTmp );
  78. for(int i=0; i <this->nbSamples;i++)
  79. {
  80. maxIndex=findMaxIndex(pTmp,this->vecDim);
  81. *pOut++ = maxIndex;
  82. pTmp += this->vecDim;
  83. }
  84. int diff = differences(ref.ptr(),output.ptr(),this->nbSamples);
  85. ASSERT_TRUE(100.0*diff/this->nbSamples <= THRESHOLD);
  86. }
  87. void Softmax::setUp(Testing::testID_t id,std::vector<Testing::param_t>& paramsArgs,Client::PatternMgr *mgr)
  88. {
  89. switch(id)
  90. {
  91. case Softmax::TEST_SOFTMAX_Q7_1:
  92. {
  93. ref.reload(Softmax::REF1_S16_ID,mgr);
  94. dims.reload(Softmax::DIMS1_S16_ID,mgr);
  95. input.reload(Softmax::INPUT1_Q7_ID,mgr);
  96. const int16_t *pDims=dims.ptr();
  97. this->nbSamples = pDims[0];
  98. this->vecDim = pDims[1];
  99. }
  100. break;
  101. case Softmax::TEST_SOFTMAX_WITH_BATCH_Q7_2:
  102. {
  103. ref.reload(Softmax::REF1_S16_ID,mgr);
  104. dims.reload(Softmax::DIMS1_S16_ID,mgr);
  105. input.reload(Softmax::INPUT1_Q7_ID,mgr);
  106. const int16_t *pDims=dims.ptr();
  107. this->nbSamples = pDims[0];
  108. this->vecDim = pDims[1];
  109. }
  110. break;
  111. }
  112. output.create(ref.nbSamples(),Softmax::OUTPUT_S16_ID,mgr);
  113. // Used to compare bit exactness of the reference C version
  114. // and the optimized version.
  115. temp.create(this->vecDim*this->nbSamples,Softmax::TEMP_Q7_ID,mgr);
  116. }
  117. void Softmax::tearDown(Testing::testID_t id,Client::PatternMgr *mgr)
  118. {
  119. // Array are big so by default they are not dumped and only
  120. // used for debug.
  121. //output.dump(mgr);
  122. //temp.dump(mgr);
  123. }