ComplexMaths.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import os.path
  2. import numpy as np
  3. import itertools
  4. import Tools
# These patterns are used for tests and benchmarks.
# TODO: tests for saturation still need to be added.
  7. def randComplex(nb):
  8. data = np.random.randn(2*nb)
  9. data = Tools.normalize(data)
  10. data_comp = data.view(dtype=np.complex128)
  11. return(data_comp)
  12. def asReal(a):
  13. #return(a.view(dtype=np.float64))
  14. return(a.reshape(np.size(a)).view(dtype=np.float64))
  15. def writeTests(config,format):
  16. NBSAMPLES=256
  17. data1=randComplex(NBSAMPLES)
  18. data2=randComplex(NBSAMPLES)
  19. data3=np.random.randn(NBSAMPLES)
  20. data3 = Tools.normalize(data3)
  21. config.writeInput(1, asReal(data1))
  22. config.writeInput(2, asReal(data2))
  23. config.writeInput(3, data3)
  24. ref = np.conj(data1)
  25. config.writeReference(1, asReal(ref))
  26. nb = Tools.loopnb(format,Tools.TAILONLY)
  27. ref = np.array(np.dot(data1[0:nb],data2[0:nb]))
  28. if format==31:
  29. ref = ref / 2**15 # Because CMSIS format is 16.48
  30. config.writeReferenceQ63(2, asReal(ref))
  31. elif format==15:
  32. ref = ref / 2**7 # Because CMSIS format is 8.24
  33. config.writeReferenceQ31(2, asReal(ref))
  34. else:
  35. config.writeReference(2, asReal(ref))
  36. nb = Tools.loopnb(format,Tools.BODYONLY)
  37. ref = np.array(np.dot(data1[0:nb] ,data2[0:nb]))
  38. if format==31:
  39. ref = ref / 2**15 # Because CMSIS format is 16.48
  40. config.writeReferenceQ63(3, asReal(ref))
  41. elif format==15:
  42. ref = ref / 2**7 # Because CMSIS format is 8.24
  43. config.writeReferenceQ31(3, asReal(ref))
  44. else:
  45. config.writeReference(3, asReal(ref))
  46. #
  47. nb = Tools.loopnb(format,Tools.BODYANDTAIL)
  48. ref = np.array(np.dot(data1[0:nb] ,data2[0:nb]))
  49. if format==31:
  50. ref = ref / 2**15 # Because CMSIS format is 16.48
  51. config.writeReferenceQ63(4, asReal(ref))
  52. elif format==15:
  53. ref = ref / 2**7 # Because CMSIS format is 8.24
  54. config.writeReferenceQ31(4, asReal(ref))
  55. else:
  56. config.writeReference(4, asReal(ref))
  57. #
  58. ref = np.absolute(data1)
  59. if format==31:
  60. ref = ref / 2 # Because CMSIS format is 2.30
  61. elif format==15:
  62. ref = ref / 2 # Because CMSIS format is 2.14
  63. config.writeReference(5, ref)
  64. #
  65. ref = np.absolute(data1)**2
  66. if format==31:
  67. ref = ref / 4 # Because CMSIS format is 3.29
  68. elif format==15:
  69. ref = ref / 4 # Because CMSIS format is 3.13
  70. config.writeReference(6, ref)
  71. #
  72. ref = data1 * data2
  73. if format==31:
  74. ref = ref / 4 # Because CMSIS format is 3.29
  75. elif format==15:
  76. ref = ref / 4 # Because CMSIS format is 3.13
  77. config.writeReference(7, asReal(ref))
  78. #
  79. ref = data1 * data3
  80. config.writeReference(8, asReal(ref))
  81. ref = np.array(np.dot(data1 ,data2))
  82. if format==31:
  83. ref = ref / 2**15 # Because CMSIS format is 16.48
  84. config.writeReferenceQ63(9, asReal(ref))
  85. elif format==15:
  86. ref = ref / 2**7 # Because CMSIS format is 8.24
  87. config.writeReferenceQ31(9, asReal(ref))
  88. else:
  89. config.writeReference(9, asReal(ref))
  90. def generatePatterns():
  91. PATTERNDIR = os.path.join("Patterns","DSP","ComplexMaths","ComplexMaths")
  92. PARAMDIR = os.path.join("Parameters","DSP","ComplexMaths","ComplexMaths")
  93. configf64=Tools.Config(PATTERNDIR,PARAMDIR,"f64")
  94. configf32=Tools.Config(PATTERNDIR,PARAMDIR,"f32")
  95. configf16=Tools.Config(PATTERNDIR,PARAMDIR,"f16")
  96. configq31=Tools.Config(PATTERNDIR,PARAMDIR,"q31")
  97. configq15=Tools.Config(PATTERNDIR,PARAMDIR,"q15")
  98. configf32.setOverwrite(False)
  99. configf16.setOverwrite(False)
  100. configq31.setOverwrite(False)
  101. configq15.setOverwrite(False)
  102. writeTests(configf64,Tools.F64)
  103. writeTests(configf32,0)
  104. writeTests(configf16,16)
  105. writeTests(configq31,31)
  106. writeTests(configq15,15)
  107. if __name__ == '__main__':
  108. generatePatterns()