dot_test.cpp 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. extern "C" {
  2. extern void dot_test();
  3. }
  4. #include "allocator.h"
  5. #include <dsppp/arch.hpp>
  6. #include <dsppp/fixed_point.hpp>
  7. #include <dsppp/matrix.hpp>
  8. #include <iostream>
  9. #include <cmsis_tests.h>
  10. #include "dsp/basic_math_functions.h"
  11. #include "dsp/basic_math_functions_f16.h"
  12. template<typename T,int NB,typename O>
  13. static void complex_test(const T scale)
  14. {
  15. std::cout << "----\r\n" << "N = " << NB << "\r\n";
  16. #if defined(STATIC_TEST)
  17. PVector<T,NB> a;
  18. PVector<T,NB> b;
  19. PVector<T,NB> c;
  20. PVector<T,NB> d;
  21. PVector<T,NB> res;
  22. #else
  23. PVector<T> a(NB);
  24. PVector<T> b(NB);
  25. PVector<T> c(NB);
  26. PVector<T> d(NB);
  27. PVector<T> res(NB);
  28. #endif
  29. init_array(a,NB);
  30. init_array(b,NB);
  31. init_array(c,NB);
  32. init_array(d,NB);
  33. INIT_SYSTICK;
  34. START_CYCLE_MEASUREMENT;
  35. startSectionNB(1);
  36. O result = dot(scale*(a+b),c*d);
  37. stopSectionNB(1);
  38. STOP_CYCLE_MEASUREMENT;
  39. O ref;
  40. PVector<T,NB> tmp1;
  41. PVector<T,NB> tmp2;
  42. INIT_SYSTICK;
  43. START_CYCLE_MEASUREMENT;
  44. cmsisdsp_dot_expr(a.const_ptr(),
  45. b.const_ptr(),
  46. c.const_ptr(),
  47. d.const_ptr(),
  48. tmp1.ptr(),
  49. tmp2.ptr(),
  50. scale,
  51. ref,NB);
  52. STOP_CYCLE_MEASUREMENT;
  53. if (!validate(result,ref))
  54. {
  55. printf("dot expr failed \r\n");
  56. }
  57. std::cout << "=====\r\n";
  58. }
  59. template<typename T,int NB,typename O>
  60. static void test()
  61. {
  62. std::cout << "----\r\n" << "N = " << NB << "\r\n";
  63. #if defined(STATIC_TEST)
  64. PVector<T,NB> a;
  65. PVector<T,NB> b;
  66. PVector<T,NB> res;
  67. #else
  68. PVector<T> a(NB);
  69. PVector<T> b(NB);
  70. PVector<T> res(NB);
  71. #endif
  72. init_array(a,NB);
  73. init_array(b,NB);
  74. INIT_SYSTICK;
  75. START_CYCLE_MEASUREMENT;
  76. startSectionNB(1);
  77. O result = dot(a,b);
  78. stopSectionNB(1);
  79. STOP_CYCLE_MEASUREMENT;
  80. O ref;
  81. INIT_SYSTICK;
  82. START_CYCLE_MEASUREMENT;
  83. cmsisdsp_dot(a.const_ptr(),b.const_ptr(),ref,NB);
  84. STOP_CYCLE_MEASUREMENT;
  85. if (!validate(result,ref))
  86. {
  87. printf("dot failed \r\n");
  88. }
  89. std::cout << "=====\r\n";
  90. }
  91. template<typename T>
  92. void all_dot_test()
  93. {
  94. const int nb_tails = TailForTests<T>::tail;
  95. const int nb_loops = TailForTests<T>::loop;
  96. using ACC = typename number_traits<T>::accumulator;
  97. constexpr auto v = TestConstant<T>::v;
  98. title<T>("Dot product");
  99. test<T,NBVEC_4,ACC>();
  100. test<T,NBVEC_8,ACC>();
  101. test<T,NBVEC_9,ACC>();
  102. test<T,NBVEC_16,ACC>();
  103. test<T,NBVEC_32,ACC>();
  104. test<T,NBVEC_64,ACC>();
  105. test<T,NBVEC_128,ACC>();
  106. test<T,NBVEC_256,ACC>();
  107. test<T,NBVEC_258,ACC>();
  108. test<T,NBVEC_512,ACC>();
  109. test<T,NBVEC_1024,ACC>();
  110. if constexpr (!std::is_same<T,double>::value)
  111. {
  112. test<T,NBVEC_2048,ACC>();
  113. }
  114. test<T,1,ACC>();
  115. test<T,nb_tails,ACC>();
  116. test<T,nb_loops,ACC>();
  117. test<T,nb_loops+1,ACC>();
  118. test<T,nb_loops+nb_tails,ACC>();
  119. title<T>("Dot product with expressions");
  120. complex_test<T,NBVEC_4,ACC>(v);
  121. complex_test<T,NBVEC_8,ACC>(v);
  122. complex_test<T,NBVEC_9,ACC>(v);
  123. complex_test<T,NBVEC_32,ACC>(v);
  124. complex_test<T,NBVEC_64,ACC>(v);
  125. complex_test<T,NBVEC_128,ACC>(v);
  126. complex_test<T,NBVEC_256,ACC>(v);
  127. complex_test<T,NBVEC_258,ACC>(v);
  128. complex_test<T,NBVEC_512,ACC>(v);
  129. complex_test<T,NBVEC_1024,ACC>(v);
  130. if constexpr (!std::is_same<T,double>::value)
  131. {
  132. complex_test<T,NBVEC_2048,ACC>(v);
  133. }
  134. complex_test<T,1,ACC>(v);
  135. complex_test<T,nb_tails,ACC>(v);
  136. complex_test<T,nb_loops,ACC>(v);
  137. complex_test<T,nb_loops+1,ACC>(v);
  138. complex_test<T,nb_loops+nb_tails,ACC>(v);
  139. //print_map("Stats",max_stats);
  140. }
  /**
   * C-linkage entry point for the dot-product tests.
   *
   * The body is empty unless DOT_TEST is defined. Each element type is then
   * enabled by its own *_DT macro. float16_t also requires DISABLEFLOAT16 to
   * be undefined. Enabled suites run in this order: double, float,
   * float16_t, Q31, Q15, Q7.
   */
  void dot_test()
  {
  #ifdef DOT_TEST
  #  ifdef F64_DT
    all_dot_test<double>();
  #  endif
  #  ifdef F32_DT
    all_dot_test<float>();
  #  endif
  #  ifdef F16_DT
  #    ifndef DISABLEFLOAT16
    all_dot_test<float16_t>();
  #    endif
  #  endif
  #  ifdef Q31_DT
    all_dot_test<Q31>();
  #  endif
  #  ifdef Q15_DT
    all_dot_test<Q15>();
  #  endif
  #  ifdef Q7_DT
    all_dot_test<Q7>();
  #  endif
  #endif
  }