fusion_test.cpp 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. extern "C" {
  2. extern void fusion_test();
  3. }
#include "allocator.h"
#include <array>
#include <cstdio>
#include <tuple>
#include <dsppp/arch.hpp>
#include <dsppp/fixed_point.hpp>
#include <dsppp/matrix.hpp>
#include <dsppp/unroll.hpp>
#include <iostream>
#include <cmsis_tests.h>
  12. template<typename T,int NB>
  13. static void test()
  14. {
  15. std::cout << "----\r\n" << "N = " << NB << "\r\n";
  16. #if defined(STATIC_TEST)
  17. PVector<T,NB> a;
  18. PVector<T,NB> b;
  19. PVector<T,NB> c;
  20. #else
  21. PVector<T> a(NB);
  22. PVector<T> b(NB);
  23. PVector<T> c(NB);
  24. #endif
  25. init_array(a,NB);
  26. init_array(b,NB);
  27. init_array(c,NB);
  28. #if defined(STATIC_TEST)
  29. PVector<T,NB> resa;
  30. PVector<T,NB> resb;
  31. #else
  32. PVector<T> resa(NB);
  33. PVector<T> resb(NB);
  34. #endif
  35. INIT_SYSTICK;
  36. START_CYCLE_MEASUREMENT;
  37. startSectionNB(1);
  38. results(resa,resb) = Merged{a + b,a + c};
  39. stopSectionNB(1);
  40. STOP_CYCLE_MEASUREMENT;
  41. PVector<T,NB> refa;
  42. PVector<T,NB> refb;
  43. INIT_SYSTICK;
  44. START_CYCLE_MEASUREMENT;
  45. cmsisdsp_add(a.const_ptr(),b.const_ptr(),refa.ptr(),NB);
  46. cmsisdsp_add(a.const_ptr(),c.const_ptr(),refb.ptr(),NB);
  47. STOP_CYCLE_MEASUREMENT;
  48. if (!validate(resa.const_ptr(),refa.const_ptr(),NB))
  49. {
  50. printf("add a failed \r\n");
  51. }
  52. if (!validate(resb.const_ptr(),refb.const_ptr(),NB))
  53. {
  54. printf("add b failed \r\n");
  55. }
  56. std::cout << "=====\r\n";
  57. }
  58. template<typename T,int NB>
  59. static void test2()
  60. {
  61. std::cout << "----\r\n" << "N = " << NB << "\r\n";
  62. #if defined(STATIC_TEST)
  63. PVector<T,NB> a;
  64. PVector<T,NB> b;
  65. PVector<T,NB> c;
  66. #else
  67. PVector<T> a(NB);
  68. PVector<T> b(NB);
  69. PVector<T> c(NB);
  70. #endif
  71. using Acc = typename number_traits<T>::accumulator;
  72. init_array(a,NB);
  73. init_array(b,NB);
  74. init_array(c,NB);
  75. Acc resa,resb,refa,refb;
  76. INIT_SYSTICK;
  77. START_CYCLE_MEASUREMENT;
  78. startSectionNB(2);
  79. std::tie(resa,resb) = dot(Merged{expr(a),expr(a)},
  80. Merged{expr(b),expr(c)});
  81. stopSectionNB(2);
  82. STOP_CYCLE_MEASUREMENT;
  83. INIT_SYSTICK;
  84. START_CYCLE_MEASUREMENT;
  85. cmsisdsp_dot(a.const_ptr(),b.const_ptr(),refa,NB);
  86. cmsisdsp_dot(a.const_ptr(),c.const_ptr(),refb,NB);
  87. STOP_CYCLE_MEASUREMENT;
  88. if (!validate(resa,refa))
  89. {
  90. printf("dot a failed \r\n");
  91. }
  92. if (!validate(resb,refb))
  93. {
  94. printf("dot b failed \r\n");
  95. }
  96. std::cout << "=====\r\n";
  97. }
  98. template<typename T,int NB>
  99. static void test3()
  100. {
  101. std::cout << "----\r\n" << "N = " << NB << "\r\n";
  102. constexpr int U = 2;
  103. #if defined(STATIC_TEST)
  104. PVector<T,NB> a[U];
  105. PVector<T,NB> b[U];
  106. #else
  107. PVector<T> a[U]={PVector<T>(NB),PVector<T>(NB)};
  108. PVector<T> b[U]={PVector<T>(NB),PVector<T>(NB)};
  109. #endif
  110. using Acc = typename number_traits<T>::accumulator;
  111. for(int i=0;i<U;i++)
  112. {
  113. init_array(a[i],NB);
  114. init_array(b[i],NB);
  115. }
  116. std::array<Acc,U> res;
  117. Acc ref[U];
  118. INIT_SYSTICK;
  119. START_CYCLE_MEASUREMENT;
  120. startSectionNB(3);
  121. results(res) = dot(unroll<U>(
  122. [&a](index_t k){return expr(a[k]);}),
  123. unroll<U>(
  124. [&b](index_t k){return expr(b[k]);})
  125. );
  126. stopSectionNB(3);
  127. STOP_CYCLE_MEASUREMENT;
  128. INIT_SYSTICK;
  129. START_CYCLE_MEASUREMENT;
  130. for(int i=0;i<U;i++)
  131. {
  132. cmsisdsp_dot(a[i].const_ptr(),b[i].const_ptr(),ref[i],NB);
  133. }
  134. STOP_CYCLE_MEASUREMENT;
  135. for(int i=0;i<U;i++)
  136. {
  137. if (!validate(res[i],ref[i]))
  138. {
  139. printf("dot failed %d \r\n",i);
  140. }
  141. }
  142. std::cout << "=====\r\n";
  143. }
  144. template<typename T>
  145. void all_fusion_test()
  146. {
  147. const int nb_tails = TailForTests<T>::tail;
  148. const int nb_loops = TailForTests<T>::loop;
  149. title<T>("Vector Fusion");
  150. test<T,NBVEC_256>();
  151. test<T,1>();
  152. test<T,nb_tails>();
  153. test<T,nb_loops>();
  154. test<T,nb_loops+1>();
  155. test<T,nb_loops+nb_tails>();
  156. title<T>("Dot Product Fusion");
  157. test2<T,NBVEC_256>();
  158. test2<T,1>();
  159. test2<T,nb_tails>();
  160. test2<T,nb_loops>();
  161. test2<T,nb_loops+1>();
  162. test2<T,nb_loops+nb_tails>();
  163. title<T>("Unroll Fusion");
  164. test3<T,NBVEC_256>();
  165. test3<T,1>();
  166. test3<T,nb_tails>();
  167. test3<T,nb_loops>();
  168. test3<T,nb_loops+1>();
  169. test3<T,nb_loops+nb_tails>();
  170. }
  171. void fusion_test()
  172. {
  173. #if defined(FUSION_TEST)
  174. #if defined(F64_DT)
  175. all_fusion_test<double>();
  176. #endif
  177. #if defined(F32_DT)
  178. all_fusion_test<float>();
  179. #endif
  180. #if defined(F16_DT) && !defined(DISABLEFLOAT16)
  181. all_fusion_test<float16_t>();
  182. #endif
  183. #if defined(Q31_DT)
  184. all_fusion_test<Q31>();
  185. #endif
  186. #if defined(Q15_DT)
  187. all_fusion_test<Q15>();
  188. #endif
  189. #if defined(Q7_DT)
  190. all_fusion_test<Q7>();
  191. #endif
  192. #endif
  193. }