# example_1_11.py (4.3 KB)
  1. # Bug corrections for version 1.9
  2. import cmsisdsp as dsp
  3. import cmsisdsp.fixedpoint as f
  4. import numpy as np
  5. import colorama
  6. from colorama import init,Fore, Back, Style
  7. from numpy.testing import assert_allclose
  8. from numpy.linalg import norm
  9. import matplotlib
  10. import matplotlib as mpl
  11. import matplotlib.pyplot as plt
  12. init()
  13. def printTitle(s):
  14. print("\n" + Fore.GREEN + Style.BRIGHT + s + Style.RESET_ALL)
  15. def printSubTitle(s):
  16. print("\n" + Style.BRIGHT + s + Style.RESET_ALL)
  17. printTitle("DTW Window")
  18. printSubTitle("SAKOE_CHIBA_WINDOW")
  19. refWin1=np.array([[1, 1, 1, 0, 0],
  20. [1, 1, 1, 1, 0],
  21. [1, 1, 1, 1, 1],
  22. [0, 1, 1, 1, 1],
  23. [0, 0, 1, 1, 1],
  24. [0, 0, 0, 1, 1],
  25. [0, 0, 0, 0, 1],
  26. [0, 0, 0, 0, 0],
  27. [0, 0, 0, 0, 0],
  28. [0, 0, 0, 0, 0]], dtype=np.int8)
  29. dtwWindow=np.zeros((10,5),dtype=np.int8)
  30. wsize=2
  31. status,w=dsp.arm_dtw_init_window_q7(dsp.ARM_DTW_SAKOE_CHIBA_WINDOW,wsize,dtwWindow)
  32. assert (w==refWin1).all()
  33. printSubTitle("SLANTED_BAND_WINDOW")
  34. refWin2=np.array([[1, 1, 0, 0, 0],
  35. [1, 1, 0, 0, 0],
  36. [1, 1, 1, 0, 0],
  37. [0, 1, 1, 0, 0],
  38. [0, 1, 1, 1, 0],
  39. [0, 0, 1, 1, 0],
  40. [0, 0, 1, 1, 1],
  41. [0, 0, 0, 1, 1],
  42. [0, 0, 0, 1, 1],
  43. [0, 0, 0, 0, 1]], dtype=np.int8)
  44. dtwWindow=np.zeros((10,5),dtype=np.int8)
  45. wsize=1
  46. status,w=dsp.arm_dtw_init_window_q7(dsp.ARM_DTW_SLANTED_BAND_WINDOW,wsize,dtwWindow)
  47. assert (w==refWin2).all()
  48. printTitle("DTW Cost Matrix and DTW Distance")
  49. QUERY_LENGTH = 10
  50. TEMPLATE_LENGTH = 5
  51. query=np.array([ 0.08387197, 0.68082274, 1.06756417, 0.88914541, 0.42513398, -0.3259053,
  52. -0.80934885, -0.90979435, -0.64026483, 0.06923695])
  53. template=np.array([ 1.00000000e+00, 7.96326711e-04, -9.99998732e-01, -2.38897811e-03,
  54. 9.99994927e-01])
  55. cols=np.array([1,2,3])
  56. rows=np.array([10,11,12])
  57. printSubTitle("Without a window")
  58. referenceCost=np.array([[0.91612804, 0.9992037 , 2.0830743 , 2.1693354 , 3.0854583 ],
  59. [1.2353053 , 1.6792301 , 3.3600516 , 2.8525472 , 2.8076797 ],
  60. [1.3028694 , 2.3696373 , 4.4372 , 3.9225004 , 2.875249 ],
  61. [1.4137241 , 2.302073 , 4.1912174 , 4.814035 , 2.9860985 ],
  62. [1.98859 , 2.2623994 , 3.6875322 , 4.115055 , 3.5609593 ],
  63. [3.3144953 , 2.589101 , 3.2631946 , 3.586711 , 4.8868594 ],
  64. [5.123844 , 3.3992462 , 2.9704008 , 3.7773607 , 5.5867043 ],
  65. [7.0336385 , 4.309837 , 3.0606053 , 3.9680107 , 5.8778 ],
  66. [8.673903 , 4.950898 , 3.420339 , 4.058215 , 5.698475 ],
  67. [9.604667 , 5.0193386 , 4.489575 , 3.563591 , 4.494349 ]],
  68. dtype=np.float32)
  69. referenceDistance = 0.2996232807636261
  70. # Each row is a new query
  71. a,b = np.meshgrid(template,query)
  72. distance=abs(a-b).astype(np.float32)
  73. status,dtwDistance,dtwMatrix = dsp.arm_dtw_distance_f32(distance,None)
  74. assert_allclose(referenceDistance,dtwDistance)
  75. assert_allclose(referenceCost,dtwMatrix)
  76. printSubTitle("Path")
  77. path=dsp.arm_dtw_path_f32(np.copy(dtwMatrix))
  78. #print(path)
  79. pathMatrix=np.zeros(dtwMatrix.shape)
  80. for x in list(zip(path[0::2],path[1::2])):
  81. pathMatrix[x] = 1
  82. fig, ax = plt.subplots()
  83. im = ax.imshow(pathMatrix,vmax=2.0)
  84. for i in range(QUERY_LENGTH):
  85. for j in range(TEMPLATE_LENGTH):
  86. text = ax.text(j, i, "%.1f" % dtwMatrix[i, j],
  87. ha="center", va="center", color="w")
  88. fig.tight_layout()
  89. plt.show()
  90. printSubTitle("With a window")
  91. referenceDistance = 0.617099940776825
  92. referenceCost=np.array([[9.1612804e-01, 9.9920368e-01, np.NAN, np.NAN,
  93. np.NAN],
  94. [1.2353053e+00, 1.6792301e+00, np.NAN, np.NAN,
  95. np.NAN],
  96. [1.3028694e+00, 2.3696373e+00, 4.4372001e+00, np.NAN,
  97. np.NAN],
  98. [np.NAN, 3.0795674e+00, 4.9687119e+00, np.NAN,
  99. np.NAN],
  100. [np.NAN, 3.5039051e+00, 4.9290380e+00, 5.3565612e+00,
  101. np.NAN],
  102. [np.NAN, np.NAN, 4.8520918e+00, 5.1756082e+00,
  103. np.NAN],
  104. [np.NAN, np.NAN, 5.0427418e+00, 5.8497019e+00,
  105. 7.6590457e+00],
  106. [np.NAN, np.NAN, np.NAN, 6.7571073e+00,
  107. 8.6668968e+00],
  108. [np.NAN, np.NAN, np.NAN, 7.3949833e+00,
  109. 9.0352430e+00],
  110. [np.NAN, np.NAN, np.NAN, np.NAN,
  111. 9.2564993e+00]], dtype=np.float32)
  112. status,dtwDistance,dtwMatrix = dsp.arm_dtw_distance_f32(distance,w)
  113. assert_allclose(referenceDistance,dtwDistance)
  114. assert_allclose(referenceCost[w==1],dtwMatrix[w==1])