row_test.cpp 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  // Entry point of this test file, declared with C language linkage (unmangled
  // symbol name) so it matches the row_test() definition at the end of the file.
  extern "C" void row_test();
  4. #include "allocator.h"
  5. #include <dsppp/arch.hpp>
  6. #include <dsppp/fixed_point.hpp>
  7. #include <dsppp/matrix.hpp>
  8. #include <iostream>
  9. #include <cmsis_tests.h>
  10. #include "dsp/matrix_functions.h"
  11. #include "matrix_utils.h"
  12. template<typename T,int R,int C>
  13. static void test()
  14. {
  15. constexpr int NBOUT = C-1;
  16. std::cout << "----\r\n";
  17. std::cout << R << " x " << C << "\r\n";
  18. std::cout << "NBOUT = " << NBOUT << "\r\n";
  19. #if defined(STATIC_TEST)
  20. PMat<T,R,C> a;
  21. #else
  22. PMat<T> a(R,C);
  23. #endif
  24. init_array(a,R*C);
  25. INIT_SYSTICK
  26. START_CYCLE_MEASUREMENT;
  27. startSectionNB(1);
  28. #if defined(STATIC_TEST)
  29. PVector<T,NBOUT> res = a.row(0,1) + a.row(1,1);
  30. #else
  31. PVector<T> res = a.row(0,1) + a.row(1,1);
  32. #endif
  33. stopSectionNB(1);
  34. STOP_CYCLE_MEASUREMENT;
  35. INIT_SYSTICK;
  36. START_CYCLE_MEASUREMENT;
  37. PVector<T,NBOUT> da (a.row(0,1));
  38. PVector<T,NBOUT> db (a.row(1,1));
  39. PVector<T,NBOUT> ref;
  40. cmsisdsp_add(da.const_ptr(),db.const_ptr(),ref.ptr(),NBOUT);
  41. STOP_CYCLE_MEASUREMENT;
  42. if (!validate(res.const_ptr(),ref.const_ptr(),NBOUT))
  43. {
  44. printf("row add failed \r\n");
  45. }
  46. std::cout << "=====\r\n";
  47. }
  48. template<typename T,int R,int C>
  49. static void swaptest()
  50. {
  51. constexpr int NBOUT = C-2;
  52. std::cout << "----\r\n";
  53. std::cout << R << " x " << C << "\r\n";
  54. std::cout << "NBOUT = " << NBOUT << "\r\n";
  55. #if defined(STATIC_TEST)
  56. PMat<T,R,C> a;
  57. PMat<T,R,C> b;
  58. #else
  59. PMat<T> a(R,C);
  60. PMat<T> b(R,C);
  61. #endif
  62. init_array(a,R*C);
  63. init_array(b,R*C);
  64. INIT_SYSTICK;
  65. START_CYCLE_MEASUREMENT;
  66. startSectionNB(1);
  67. swap(a.row(0,2) , a.row(1,2));
  68. stopSectionNB(1);
  69. STOP_CYCLE_MEASUREMENT;
  70. typename CMSISMatrixType<T>::type mat;
  71. mat.numCols = C;
  72. mat.numRows = R;
  73. mat.pData = b.ptr();
  74. INIT_SYSTICK;
  75. START_CYCLE_MEASUREMENT;
  76. SWAP_ROWS_F32(&mat,2,0,1);
  77. STOP_CYCLE_MEASUREMENT;
  78. if (!validate(a.const_ptr(),(const float32_t*)mat.pData,R*C))
  79. {
  80. printf("row add failed \r\n");
  81. }
  82. std::cout << "=====\r\n";
  83. }
  84. template<typename T>
  85. void all_row_test()
  86. {
  87. const int nb_tails = TailForTests<T>::tail;
  88. const int nb_loops = TailForTests<T>::loop;
  89. title<T>("Row test");
  90. test<T,2,NBVEC_4>();
  91. test<T,4,NBVEC_4>();
  92. test<T,5,NBVEC_4>();
  93. test<T,9,NBVEC_4>();
  94. test<T,2,NBVEC_8>();
  95. test<T,4,NBVEC_8>();
  96. test<T,5,NBVEC_8>();
  97. test<T,9,NBVEC_8>();
  98. test<T,2,NBVEC_16>();
  99. test<T,4,NBVEC_16>();
  100. test<T,5,NBVEC_16>();
  101. test<T,9,NBVEC_16>();
  102. test<T,2,nb_loops>();
  103. test<T,4,nb_loops>();
  104. test<T,5,nb_loops>();
  105. test<T,9,nb_loops>();
  106. test<T,2,nb_loops+1>();
  107. test<T,4,nb_loops+1>();
  108. test<T,5,nb_loops+1>();
  109. test<T,9,nb_loops+1>();
  110. test<T,2,nb_loops+nb_tails>();
  111. test<T,4,nb_loops+nb_tails>();
  112. test<T,5,nb_loops+nb_tails>();
  113. test<T,9,nb_loops+nb_tails>();
  114. if constexpr (std::is_same<T,float>::value)
  115. {
  116. title<T>("Swap test");
  117. swaptest<T,2,NBVEC_32>();
  118. swaptest<T,4,NBVEC_32>();
  119. swaptest<T,5,NBVEC_32>();
  120. swaptest<T,9,NBVEC_32>();
  121. swaptest<T,2,nb_loops>();
  122. swaptest<T,4,nb_loops>();
  123. swaptest<T,5,nb_loops>();
  124. swaptest<T,9,nb_loops>();
  125. swaptest<T,2,nb_loops+1>();
  126. swaptest<T,4,nb_loops+1>();
  127. swaptest<T,5,nb_loops+1>();
  128. swaptest<T,9,nb_loops+1>();
  129. swaptest<T,2,nb_loops+nb_tails>();
  130. swaptest<T,4,nb_loops+nb_tails>();
  131. swaptest<T,5,nb_loops+nb_tails>();
  132. swaptest<T,9,nb_loops+nb_tails>();
  133. }
  134. //print_map("Stats",max_stats);
  135. }
  // Runs the row test suite once per element type enabled by the build
  // configuration. The body compiles to nothing unless ROW_TEST is defined.
  void row_test()
  {
  #ifdef ROW_TEST
  #  ifdef F64_DT
    all_row_test<double>();
  #  endif
  #  ifdef F32_DT
    all_row_test<float>();
  #  endif
  #  ifdef F16_DT
  #    ifndef DISABLEFLOAT16
    all_row_test<float16_t>();
  #    endif
  #  endif
  #  ifdef Q31_DT
    all_row_test<Q31>();
  #  endif
  #  ifdef Q15_DT
    all_row_test<Q15>();
  #  endif
  #  ifdef Q7_DT
    all_row_test<Q7>();
  #  endif
  #endif
  }