# FastMath.py
  1. import os.path
  2. import numpy as np
  3. import itertools
  4. import Tools
  5. import math
# These patterns are used for both tests and benchmarks.
# For tests, cases that exercise saturation still need to be added.
# For benchmarks, NBSAMPLES random input samples are generated.
  9. NBSAMPLES=256
  10. def cartesian(*somelists):
  11. r=[]
  12. for element in itertools.product(*somelists):
  13. r.append(element)
  14. return(r)
  15. # Fixed point division should not be called with a denominator of zero.
  16. # But if it is, it should return a saturated result.
  17. def divide(f,r):
  18. e = 0
  19. if f == Tools.Q31:
  20. e = 1.0 / (1<<31)
  21. if f == Tools.Q15:
  22. e = 1.0 / (1<<15)
  23. if f == Tools.Q7:
  24. e = 1.0 / (1<<7)
  25. a,b=r
  26. if b == 0.0:
  27. if a >= 0.0:
  28. return(1.0,0)
  29. else:
  30. return(-1.0,0)
  31. k = 0
  32. while abs(a) > abs(b):
  33. a = a / 2.0
  34. k = k + 1
  35. # In C code we don't saturate but instead generate the right value
  36. # with a shift of 1.
  37. # So this test is to ease the comparison between the Python reference
  38. # and the output of the division algorithm in C
  39. if abs(a/b) > 1 - e:
  40. a = a / 2.0
  41. k = k + 1
  42. return(a/b,k)
def writeTests(config,format):
    """Write the angle, square-root, benchmark-sample and division patterns.

    Patterns written (by name): Angles, SqrtInput, Cos, Sin, Sqrt, Samples,
    Numerator, Denominator, DivisionValue and DivisionShift.

    Parameters:
        config -- presumably a Tools.Config (as built in generatePatterns);
                  only its setOverwrite, writeInput, writeReference and
                  writeReferenceS16 methods are used here.
        format -- format code.  0 and 16 are treated as floating point
                  (angles stay in radians).  Any other value takes the
                  fixed-point path, where angles are scaled by 1/(2*pi).
                  The code is also passed to divide() as its format selector.

    NOTE(review): np.random is not seeded here, so "Samples" differs from one
    run to the next.
    """
    config.setOverwrite(False)
    # a1: angles in [0, 2*pi), ending just below 2*pi.
    a1=np.array([0,math.pi/4,math.pi/2,3*math.pi/4,math.pi,5*math.pi/4,3*math.pi/2,2*math.pi-1e-6])
    # a2: negative angles.
    # NOTE(review): the last entry is -2*pi-1e-6, i.e. slightly *below* -2*pi,
    # while a1 ends just inside the period at 2*pi-1e-6.  After the 1/(2*pi)
    # scaling on the fixed-point path this value falls below -1.  Confirm it is
    # an intended out-of-range/saturation case and not a sign typo for
    # -2*pi+1e-6.
    a2=np.array([-math.pi/4,-math.pi/2,-3*math.pi/4,-math.pi,-5*math.pi/4,-3*math.pi/2,-2*math.pi-1e-6])
    # a3: a1 shifted by one full period.
    a3 = a1 + 2*math.pi
    angles=np.concatenate((a1,a2,a3))
    # Reference values are computed from the radian angles including a3.
    refcos = np.cos(angles)
    refsin = np.sin(angles)
    vals=np.array([0.0, 0.0, 0.1,1.0,2.0,3.0,3.5,3.6])
    sqrtvals=np.sqrt(vals)
    # Negative values in CMSIS are giving 0
    # The first input is therefore replaced by a negative value whose
    # expected square root is 0.
    vals[0] = -0.4
    sqrtvals[0] = 0.0
    if format != 0 and format != 16:
        # Fixed-point path: a3 is replaced by a1.  cos/sin are 2*pi-periodic,
        # so refcos/refsin computed above still apply.  The angles are then
        # normalized by 2*pi.
        angles=np.concatenate((a1,a2,a1))
        angles = angles / (2*math.pi)
    config.writeInput(1, angles,"Angles")
    config.writeInput(1, vals,"SqrtInput")
    config.writeReference(1, refcos,"Cos")
    config.writeReference(1, refsin,"Sin")
    config.writeReference(1, sqrtvals,"Sqrt")
    # For benchmarks
    # Random magnitudes, normalized by Tools.normalize and made non-negative.
    samples=np.random.randn(NBSAMPLES)
    samples = np.abs(Tools.normalize(samples))
    config.writeInput(1, samples,"Samples")
    # Overwrite is enabled only for the division patterns below and disabled
    # again at the end.  Presumably this makes these files always regenerate.
    config.setOverwrite(True)
    # np.linspace defaults to 50 points, giving a 50x50 grid of
    # (numerator, denominator) pairs.  With 50 points placed symmetrically
    # around 0, the value 0 is never a grid point, so divide()'s
    # zero-denominator branch is not exercised by these patterns.
    numerator=np.linspace(-0.9,0.9)
    denominator=np.linspace(-0.9,0.9)
    samples=cartesian(numerator,denominator)
    numerator=[x[0] for x in samples]
    denominator=[x[1] for x in samples]
    result=[divide(format,x) for x in samples]
    resultValue=[x[0] for x in result]
    resultShift=[x[1] for x in result]
    config.writeInput(1, numerator,"Numerator")
    config.writeInput(1, denominator,"Denominator")
    config.writeReference(1, resultValue,"DivisionValue")
    # Shifts are written as 16-bit integers.
    config.writeReferenceS16(1, resultShift,"DivisionShift")
    config.setOverwrite(False)
  82. def writeTestsFloat(config,format):
  83. config.setOverwrite(False)
  84. writeTests(config,format)
  85. data1 = np.random.randn(20)
  86. data1 = np.abs(data1)
  87. data1 = data1 + 1e-3 # To avoid zero values
  88. data1 = Tools.normalize(data1)
  89. samples=np.concatenate((np.array([0.1,0.3,0.5,1.0,2.0]) , data1))
  90. config.writeInput(1, samples,"LogInput")
  91. v = np.log(samples)
  92. config.writeReference(1, v,"Log")
  93. samples=np.concatenate((np.array([0.0,1.0]),np.linspace(-0.4,0.4)))
  94. config.writeInput(1, samples,"ExpInput")
  95. v = np.exp(samples)
  96. config.writeReference(1, v,"Exp")
  97. # For benchmarks and other tests
  98. samples=np.random.randn(NBSAMPLES)
  99. samples = np.abs(Tools.normalize(samples))
  100. config.writeInput(1, samples,"Samples")
  101. v = 1.0 / samples
  102. config.writeReference(1, v,"Inverse")
  103. config.setOverwrite(True)
  104. def generatePatterns():
  105. PATTERNDIR = os.path.join("Patterns","DSP","FastMath","FastMath")
  106. PARAMDIR = os.path.join("Parameters","DSP","FastMath","FastMath")
  107. configf32=Tools.Config(PATTERNDIR,PARAMDIR,"f32")
  108. configf16=Tools.Config(PATTERNDIR,PARAMDIR,"f16")
  109. configq31=Tools.Config(PATTERNDIR,PARAMDIR,"q31")
  110. configq15=Tools.Config(PATTERNDIR,PARAMDIR,"q15")
  111. #writeTestsFloat(configf32,0)
  112. #writeTestsFloat(configf16,16)
  113. #writeTests(configq31,31)
  114. writeTests(configq15,15)
  115. if __name__ == '__main__':
  116. generatePatterns()