testrfft_all.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. import cmsisdsp as dsp
  2. import cmsisdsp.fixedpoint as f
  3. import numpy as np
  4. from scipy import signal
  5. #import matplotlib.pyplot as plt
  6. import scipy.fft
  7. import colorama
  8. from colorama import init,Fore, Back, Style
  9. from numpy.testing import assert_allclose
  10. init()
  11. def printTitle(s):
  12. print("\n" + Fore.GREEN + Style.BRIGHT + s + Style.RESET_ALL)
  13. def printSubTitle(s):
  14. print("\n" + Style.BRIGHT + s + Style.RESET_ALL)
  15. def chop(A, eps = 1e-6):
  16. B = np.copy(A)
  17. B[np.abs(A) < eps] = 0
  18. return B
  19. # For fixed point version, compare that
  20. # the conjugate part is really the conjugate part
  21. def compareWithConjugatePart(r):
  22. res = r[0::2] + 1j * r[1::2]
  23. conjPart = res[nb:nb//2:-1].conj()
  24. refPart = res[1:nb//2]
  25. assert(np.equal(refPart , conjPart).all())
  26. nb = 32
  27. signal = np.cos(2 * np.pi * np.arange(nb) / nb)*np.cos(0.2*2 * np.pi * np.arange(nb) / nb)
  28. ref=scipy.fft.rfft(signal)
  29. invref = scipy.fft.irfft(ref)
  30. assert(len(ref) == (nb // 2) + 1)
  31. assert(len(invref) == nb)
  32. # Length of arrays for the float implementation
  33. # of the RFFT (in float so there is a factor 2
  34. # when the samples are complex)
  35. RFFT_F_IN_LENGTH = nb # real
  36. RFFT_F_OUT_LENGTH = nb # complex (so nb // 2 complex)
  37. RIFFT_F_IN_LENGTH = nb # complex
  38. RIFFT_F_OUT_LENGTH = nb # real
  39. # Length of arrays for the fixed point implementation
  40. # of the RFFT
  41. RFFT_Q_IN_LENGTH = nb
  42. RFFT_Q_OUT_LENGTH = 2*nb
  43. # Conjugate part ignored
  44. RIFFT_Q_IN_LENGTH = nb + 2
  45. RIFFT_Q_OUT_LENGTH = nb
  46. # Convert ref to CMSIS-DSP format
  47. referenceFloat=np.zeros(nb)
  48. # Replace complex datatype by real datatype
  49. referenceFloat[0::2] = np.real(ref)[:-1]
  50. referenceFloat[1::2] = np.imag(ref)[:-1]
  51. # Copy Nyquist frequency value into first
  52. # sample.This is just a storage trick so that the
  53. # output of the RFFT has same length as input
  54. # It is legacy behavior that we need to keep
  55. # for backward compatibility but it is not
  56. # very pretty
  57. referenceFloat[1] = np.real(ref[-1])
  58. referenceFixed=np.zeros(2*len(ref))
  59. referenceFixed[0::2] = np.real(ref)
  60. referenceFixed[1::2] = np.imag(ref)
  61. printTitle("RFFT FAST F64")
  62. printSubTitle("RFFT")
  63. rfftf64=dsp.arm_rfft_fast_instance_f64()
  64. status=dsp.arm_rfft_fast_init_f64(rfftf64,nb)
  65. result = dsp.arm_rfft_fast_f64(rfftf64,signal,0)
  66. assert(len(signal) == RFFT_F_IN_LENGTH)
  67. assert(len(result) == RFFT_F_OUT_LENGTH)
  68. assert_allclose(referenceFloat,result)
  69. printSubTitle("RIFFT")
  70. rifftf64=dsp.arm_rfft_fast_instance_f64()
  71. status=dsp.arm_rfft_fast_init_f64(rifftf64,nb)
  72. result = dsp.arm_rfft_fast_f64(rifftf64,referenceFloat,1)
  73. assert(len(referenceFloat) == RIFFT_F_IN_LENGTH)
  74. assert(len(result) == RIFFT_F_OUT_LENGTH)
  75. assert_allclose(invref,result,atol=1e-15)
  76. printTitle("RFFT FAST F32")
  77. printSubTitle("RFFT")
  78. rfftf32=dsp.arm_rfft_fast_instance_f32()
  79. status=dsp.arm_rfft_fast_init_f32(rfftf32,nb)
  80. result = dsp.arm_rfft_fast_f32(rfftf32,signal,0)
  81. assert(len(signal) == RFFT_F_IN_LENGTH)
  82. assert(len(result) == RFFT_F_OUT_LENGTH)
  83. assert_allclose(referenceFloat,result,rtol=3e-6)
  84. printSubTitle("RIFFT")
  85. rifftf32=dsp.arm_rfft_fast_instance_f32()
  86. status=dsp.arm_rfft_fast_init_f32(rifftf32,nb)
  87. result = dsp.arm_rfft_fast_f32(rifftf32,referenceFloat,1)
  88. assert(len(referenceFloat) == RIFFT_F_IN_LENGTH)
  89. assert(len(result) == RIFFT_F_OUT_LENGTH)
  90. assert_allclose(invref,result,atol=1e-7)
  91. # Fixed point
  92. printTitle("RFFT Q31")
  93. printSubTitle("RFFT")
  94. signalQ31 = f.toQ31(signal)
  95. rfftQ31=dsp.arm_rfft_instance_q31()
  96. status=dsp.arm_rfft_init_q31(rfftQ31,nb,0,1)
  97. resultQ31 = dsp.arm_rfft_q31(rfftQ31,signalQ31)
  98. assert(len(signalQ31) == RFFT_Q_IN_LENGTH)
  99. assert(len(resultQ31) == RFFT_Q_OUT_LENGTH)
  100. compareWithConjugatePart(resultQ31)
  101. # Drop the conjugate part which is not computed by scipy
  102. resultQ31 = resultQ31[:nb+2]
  103. assert(len(resultQ31) == RIFFT_Q_IN_LENGTH)
  104. resultF = f.Q31toF32(resultQ31) * nb
  105. assert_allclose(referenceFixed,resultF,rtol=1e-6,atol=1e-6)
  106. printSubTitle("RIFFT")
  107. rifftQ31=dsp.arm_rfft_instance_q31()
  108. status=dsp.arm_rfft_init_q31(rifftQ31,nb,1,1)
  109. # Apply CMSIS-DSP scaling
  110. referenceQ31 = f.toQ31(referenceFixed/ nb)
  111. resultQ31 = dsp.arm_rfft_q31(rifftQ31,referenceQ31)
  112. resultF = f.Q31toF32(resultQ31)
  113. assert(len(referenceQ31) == RIFFT_Q_IN_LENGTH)
  114. assert(len(resultQ31) == RIFFT_Q_OUT_LENGTH)
  115. assert_allclose(invref/nb,resultF,atol=1e-6)
  116. printTitle("RFFT Q15")
  117. printSubTitle("RFFT")
  118. signalQ15 = f.toQ15(signal)
  119. rfftQ15=dsp.arm_rfft_instance_q15()
  120. status=dsp.arm_rfft_init_q15(rfftQ15,nb,0,1)
  121. resultQ15 = dsp.arm_rfft_q15(rfftQ15,signalQ15)
  122. assert(len(signalQ15) == RFFT_Q_IN_LENGTH)
  123. assert(len(resultQ15) == RFFT_Q_OUT_LENGTH)
  124. compareWithConjugatePart(resultQ15)
  125. # Drop the conjugate part which is not computed by scipy
  126. resultQ15 = resultQ15[:nb+2]
  127. assert(len(resultQ15) == RIFFT_Q_IN_LENGTH)
  128. resultF = f.Q15toF32(resultQ15) * nb
  129. assert_allclose(referenceFixed,resultF,rtol=1e-6,atol=1e-2)
  130. printSubTitle("RIFFT")
  131. rifftQ15=dsp.arm_rfft_instance_q15()
  132. status=dsp.arm_rfft_init_q15(rifftQ15,nb,1,1)
  133. # Apply CMSIS-DSP scaling
  134. referenceQ15 = f.toQ15(referenceFixed / nb)
  135. resultQ15 = dsp.arm_rfft_q15(rifftQ15,referenceQ15)
  136. resultF = f.Q15toF32(resultQ15)
  137. assert(len(referenceQ15) == RIFFT_Q_IN_LENGTH)
  138. assert(len(resultQ15) == RIFFT_Q_OUT_LENGTH)
  139. assert_allclose(invref/nb,resultF,atol=1e-3)