# BasicMaths.py (5.2 KB)
  1. import os.path
  2. import numpy as np
  3. import itertools
  4. import Tools
# These patterns are used for both tests and benchmarks.
# The tests also need saturation cases; writeTestsWithSat adds them for the q31, q15 and q7 formats.
  7. def writeTests(config,format):
  8. NBSAMPLES=256
  9. data1=np.random.randn(NBSAMPLES)
  10. data2=np.random.randn(NBSAMPLES)
  11. data3=np.random.randn(1)
  12. data1 = Tools.normalize(data1)
  13. data2 = Tools.normalize(data2)
  14. config.writeInput(1, data1)
  15. config.writeInput(2, data2)
  16. ref = data1 + data2
  17. config.writeReference(1, ref)
  18. ref = data1 - data2
  19. config.writeReference(2, ref)
  20. ref = data1 * data2
  21. config.writeReference(3, ref)
  22. ref = -data1
  23. config.writeReference(4, ref)
  24. ref = data1 + 0.5
  25. config.writeReference(5, ref)
  26. ref = data1 * 0.5
  27. config.writeReference(6, ref)
  28. nb = Tools.loopnb(format,Tools.TAILONLY)
  29. ref = np.array([np.dot(data1[0:nb] ,data2[0:nb])])
  30. if format == 31 or format == 15:
  31. if format==31:
  32. ref = ref / 2**15 # Because CMSIS format is 16.48
  33. if format==15:
  34. ref = ref / 2**33 # Because CMSIS format is 34.30
  35. config.writeReferenceQ63(7, ref)
  36. elif format == 7:
  37. ref = ref / 2**17 # Because CMSIS format is 18.14
  38. config.writeReferenceQ31(7, ref)
  39. else:
  40. config.writeReference(7, ref)
  41. nb = Tools.loopnb(format,Tools.BODYONLY)
  42. ref = np.array([np.dot(data1[0:nb] ,data2[0:nb])])
  43. if format == 31 or format == 15:
  44. if format==31:
  45. ref = ref / 2**15 # Because CMSIS format is 16.48
  46. if format==15:
  47. ref = ref / 2**33 # Because CMSIS format is 34.30
  48. config.writeReferenceQ63(8, ref)
  49. elif format == 7:
  50. ref = ref / 2**17 # Because CMSIS format is 18.14
  51. config.writeReferenceQ31(8, ref)
  52. else:
  53. config.writeReference(8, ref)
  54. nb = Tools.loopnb(format,Tools.BODYANDTAIL)
  55. ref = np.array([np.dot(data1[0:nb] ,data2[0:nb])])
  56. if format == 31 or format == 15:
  57. if format==31:
  58. ref = ref / 2**15 # Because CMSIS format is 16.48
  59. if format==15:
  60. ref = ref / 2**33 # Because CMSIS format is 34.30
  61. config.writeReferenceQ63(9, ref)
  62. elif format == 7:
  63. ref = ref / 2**17 # Because CMSIS format is 18.14
  64. config.writeReferenceQ31(9, ref)
  65. else:
  66. config.writeReference(9, ref)
  67. ref = abs(data1)
  68. config.writeReference(10, ref)
  69. #ref = np.array([np.dot(data1 ,data2)])
  70. #config.writeReference(11, ref)
  71. return(11)
  72. def writeTestsWithSat(config,format):
  73. if format == 31:
  74. NBSAMPLES=9
  75. if format == 15:
  76. NBSAMPLES=17
  77. if format == 7:
  78. NBSAMPLES=33
  79. nb = writeTests(config,format)
  80. data1 = np.full(NBSAMPLES, 2**format - 1)
  81. data1[1::2] = 2
  82. data2 = np.full(NBSAMPLES, -2**format)
  83. data2[1::2] = -2
  84. datar=np.random.randn(NBSAMPLES)
  85. datar = Tools.normalize(datar)
  86. datar = datar / 3.0 # Because used to test shift of 2 without saturation
  87. config.writeInput(nb+1, datar)
  88. if format == 31:
  89. config.writeInputS32(nb+1,data1-1,"MaxPosInput")
  90. config.writeInputS32(nb+1,data2+1,"MaxNegInput")
  91. config.writeInputS32(nb+1,data2,"MaxNeg2Input")
  92. if format == 15:
  93. config.writeInputS16(nb+1,data1-1,"MaxPosInput")
  94. config.writeInputS16(nb+1,data2+1,"MaxNegInput")
  95. config.writeInputS16(nb+1,data2,"MaxNeg2Input")
  96. if format == 7:
  97. config.writeInputS8(nb+1,data1-1,"MaxPosInput")
  98. config.writeInputS8(nb+1,data2+1,"MaxNegInput")
  99. config.writeInputS8(nb+1,data2,"MaxNeg2Input")
  100. d1 = 1.0*(data1-1) / 2**format
  101. d2 = 1.0*(data2+1) / 2**format
  102. d3 = 1.0*(data2) / 2**format
  103. ref = d1 + d1
  104. config.writeReference(nb+1, ref,"PosSat")
  105. ref = d2 + d2
  106. config.writeReference(nb+2, ref,"NegSat")
  107. d1 = 1.0*(data1-1) / 2**format
  108. d2 = 1.0*(data2+1) / 2**format
  109. ref = d1 - d2
  110. config.writeReference(nb+3, ref,"PosSat")
  111. ref = d2 - d1
  112. config.writeReference(nb+4, ref,"NegSat")
  113. ref = d3*d3
  114. config.writeReference(nb+5, ref,"PosSat")
  115. ref = -d3
  116. config.writeReference(nb+6, ref,"PosSat")
  117. ref = d1 + 0.9
  118. config.writeReference(nb+7, ref,"PosSat")
  119. ref = d2 - 0.9
  120. config.writeReference(nb+8, ref,"NegSat")
  121. ref = d3 * d3[0]
  122. config.writeReference(nb+9, ref,"PosSat")
  123. ref = datar * 2.0
  124. config.writeReference(nb+10, ref,"Shift")
  125. ref = d1 * 2.0
  126. config.writeReference(nb+11, ref,"Shift")
  127. ref = d2 * 2.0
  128. config.writeReference(nb+12, ref,"Shift")
  129. def generatePatterns():
  130. PATTERNDIR = os.path.join("Patterns","DSP","BasicMaths","BasicMaths")
  131. PARAMDIR = os.path.join("Parameters","DSP","BasicMaths","BasicMaths")
  132. configf32=Tools.Config(PATTERNDIR,PARAMDIR,"f32")
  133. configq31=Tools.Config(PATTERNDIR,PARAMDIR,"q31")
  134. configq15=Tools.Config(PATTERNDIR,PARAMDIR,"q15")
  135. configq7=Tools.Config(PATTERNDIR,PARAMDIR,"q7")
  136. writeTests(configf32,0)
  137. writeTestsWithSat(configq31,31)
  138. writeTestsWithSat(configq15,15)
  139. writeTestsWithSat(configq7,7)
  140. # Params just as example
  141. someLists=[[1,3,5],[1,3,5],[1,3,5]]
  142. r=np.array([element for element in itertools.product(*someLists)])
  143. configf32.writeParam(1, r.reshape(81))
  144. if __name__ == '__main__':
  145. generatePatterns()