Quaternion.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import os.path
  2. import numpy as np
  3. import itertools
  4. import Tools
  5. from pyquaternion import Quaternion
# mult, multvec, inverse, conjugate, normalize, rot2quat, quat2rot, norm
  7. def flattenQuat(l):
  8. return(np.array([list(x) for x in l]).reshape(4*len(l)))
  9. def flattenRot(l):
  10. return(np.array([list(x) for x in l]).reshape(9*len(l)))
  11. # q and -q are representing the same rotation.
  12. # So there is an ambiguity for the tests.
  13. # We force the real part of be positive.
  14. def mkQuaternion(mat):
  15. q=Quaternion(matrix=mat)
  16. if q.scalar < 0:
  17. return(-q)
  18. else:
  19. return(q)
  20. def writeTests(config,format):
  21. NBSAMPLES=128
  22. a=[Quaternion.random() for x in range(NBSAMPLES)]
  23. b=[Quaternion.random() for x in range(NBSAMPLES)]
  24. config.writeInput(1, flattenQuat(a))
  25. config.writeInput(2, flattenQuat(b))
  26. normTest = [x.norm for x in a]
  27. config.writeReference(1, normTest)
  28. inverseTest = [x.inverse for x in a]
  29. config.writeReference(2, flattenQuat(inverseTest))
  30. conjugateTest = [x.conjugate for x in a]
  31. config.writeReference(3, flattenQuat(conjugateTest))
  32. normalizeTest = [x.normalised for x in a]
  33. config.writeReference(4, flattenQuat(normalizeTest))
  34. multTest = [a[i] * b[i] for i in range(NBSAMPLES)]
  35. config.writeReference(5, flattenQuat(multTest))
  36. quat2RotTest = [x.rotation_matrix for x in a]
  37. config.writeReference(6, flattenRot(quat2RotTest))
  38. config.writeInput(7, flattenRot(quat2RotTest))
  39. rot2QuatTest = [mkQuaternion(x) for x in quat2RotTest]
  40. config.writeReference(7, flattenQuat(rot2QuatTest))
  41. def generatePatterns():
  42. PATTERNDIR = os.path.join("Patterns","DSP","QuaternionMaths","QuaternionMaths")
  43. PARAMDIR = os.path.join("Parameters","DSP","QuaternionMaths","QuaternionMaths")
  44. configf32=Tools.Config(PATTERNDIR,PARAMDIR,"f32")
  45. configf16=Tools.Config(PATTERNDIR,PARAMDIR,"f16")
  46. writeTests(configf32,0)
  47. writeTests(configf16,16)
  48. # Params just as example
  49. someLists=[[1,3,5],[1,3,5],[1,3,5]]
  50. r=np.array([element for element in itertools.product(*someLists)])
  51. configf32.writeParam(1, r.reshape(81))
  52. if __name__ == '__main__':
  53. generatePatterns()