train.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from sklearn import svm
  2. import random
  3. import numpy as np
  4. import math
  5. from pylab import scatter,figure, clf, plot, xlabel, ylabel, xlim, ylim, title, grid, axes, show,semilogx, semilogy
  6. import matplotlib.pyplot as plt
  7. from matplotlib.ticker import MaxNLocator
  8. from matplotlib.colors import BoundaryNorm,ListedColormap
  # Generation of data to train the SVM classifier.
# NBVECS vectors are generated per class. Vectors have dimension VECDIM (= 2),
# so they can be represented as points in the plane.
NBVECS = 100
VECDIM = 2
# Class 0: a Gaussian cluster of points around the origin
# (standard deviation ballRadius along each axis).
ballRadius = 0.5
x = ballRadius * np.random.randn(NBVECS, VECDIM)
# Class 1: an annulus of points around the central cluster.
# Angles are drawn uniformly in [0, 2*pi); radii are normally distributed
# around 3.0 with standard deviation 0.1. Both are 1-D arrays of length NBVECS,
# so they assign directly into the columns of xa (no reliance on broadcasting
# away a leading (1, NBVECS) axis).
angle = np.random.uniform(0.0, 2.0 * math.pi, NBVECS)
radius = 3.0 + 0.1 * np.random.randn(NBVECS)
xa = np.zeros((NBVECS, VECDIM))
xa[:, 0] = radius * np.cos(angle)
xa[:, 1] = radius * np.sin(angle)
# All points are concatenated.
# The first NBVECS rows (central cluster) correspond to class 0;
# the other NBVECS rows (annulus) correspond to class 1.
X_train = np.concatenate((x, xa))
Y_train = np.concatenate((np.zeros(NBVECS), np.ones(NBVECS)))
  # Axis-aligned bounding box of the training set. It sizes the plotting grid
# and is also used to place the second test point (at x_max).
x_min, y_min = X_train[:, :2].min(axis=0)
x_max, y_max = X_train[:, :2].max(axis=0)
  # Training is done with a polynomial-kernel SVM.
# Only kernel, gamma and coef0 are set explicitly. Every other hyper-parameter,
# including the polynomial degree printed further down, keeps its scikit-learn
# default. SVC.fit returns the estimator itself, so the call can be chained.
# NOTE(review): this rebinds `clf`, shadowing the pylab `clf` function
# imported at the top of the file.
clf = svm.SVC(coef0=1.1, gamma='auto', kernel='poly').fit(X_train, Y_train)
  # Quick sanity check of the trained classifier on two hand-picked points.
# predict() is given a 2-D array with a single row (one sample), hence the
# double brackets.
# First point: inside the central cluster -> predicted class should be 0.
test1 = np.array([[0.4, 0.1]])
predicted1 = clf.predict(test1)
print(predicted1)
# Second point: on the positive x axis at the largest training abscissa,
# intended to fall inside the annulus -> predicted class should be 1.
test2 = np.array([[x_max, 0]])
predicted2 = clf.predict(test2)
print(predicted2)
  # Print the trained classifier's parameters so they can be reused
# in CMSIS-DSP.
supportShape = clf.support_vectors_.shape
nbSupportVectors = supportShape[0]
vectorDimensions = supportShape[1]
print("nbSupportVectors = %d" % nbSupportVectors)
print("vectorDimensions = %d" % vectorDimensions)
print("degree = %d" % clf.degree)
print("coef0 = %f" % clf.coef0)
# NOTE: _gamma is a private scikit-learn attribute holding the numeric gamma
# used by the fitted model (gamma='auto' was requested at construction).
# Being private, it may change between scikit-learn releases.
print("gamma = %f" % clf._gamma)
# intercept_ is an ndarray (one entry per binary sub-problem, so a single
# entry here). Index it explicitly: implicitly converting an ndim>0 array to
# a Python float is deprecated in NumPy >= 1.25.
print("intercept = %f" % clf.intercept_[0])
# Flatten the dual coefficients to one value per support vector. The reshape
# also fails loudly if the model is not a single binary classifier.
dualCoefs = clf.dual_coef_.reshape(nbSupportVectors)
# Flatten the support vectors row by row (row-major). Use the dimension
# measured on the fitted model rather than the global VECDIM so the two
# cannot drift apart.
supportVectors = clf.support_vectors_.reshape(nbSupportVectors * vectorDimensions)
print("Dual Coefs")
print(dualCoefs)
print("Support Vectors")
print(supportVectors)
  # Graphical representation of the point clouds and the SVM boundary.
r = plt.figure()
plt.axis('off')
# Evaluate the decision function on a 200x200 grid spanning the training data.
XX, YY = np.mgrid[x_min:x_max:200j, y_min:y_max:200j]
Z = clf.decision_function(np.c_[XX.ravel(), YY.ravel()])
Z = Z.reshape(XX.shape)
# Paint the two predicted regions as a background: Z <= 0 maps to the first
# colour and Z > 0 maps to the second. Both colours are white, so the
# background is blank; change them to make the regions visible.
# The norm has exactly one bin per colour. Previously the norm was built from
# ~15 decision-function levels with a 2-colour map, and Matplotlib >= 3.3
# rejects that: BoundaryNorm raises ValueError when ncolors < number of bins.
cmap = ListedColormap(['#FFFFFF', '#FFFFFF'])
norm = BoundaryNorm([-0.5, 0.5, 1.5], ncolors=cmap.N, clip=True)
plt.pcolormesh(XX, YY, (Z > 0).astype(float), cmap=cmap, norm=norm)
# Solid line: the decision boundary (Z = 0).
# Dashed lines: the Z = -0.5 and Z = +0.5 level sets (not the usual +/-1 margins).
plt.contour(XX, YY, Z, colors=['k', 'k', 'k'],
            linestyles=['--', '-', '--'], levels=[-.5, 0, .5])
# Central cluster (class 0) in orange, annulus (class 1) in green.
scatter(x[:, 0], x[:, 1], s=1.0, color='#FF6B00')
scatter(xa[:, 0], xa[:, 1], s=1.0, color='#95D600')
# The test points are displayed in red.
scatter(test1[:, 0], test1[:, 1], s=6.0, color='Red')
scatter(test2[:, 0], test2[:, 1], s=6.0, color='Red')
# To save instead of displaying: r.savefig('fig1.jpeg'); plt.close(r)
show()
  89. #r=plt.figure()
  90. #plt.axis('off')
  91. #scatter(x[:,0],x[:,1],s=1.0,color='#FF6B00')
  92. #scatter(xa[:,0],xa[:,1],s=1.0,color='#95D600')
  93. #r.savefig('fig2.jpeg')
  94. #plt.close(r)