train.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. from sklearn.naive_bayes import GaussianNB
  2. import random
  3. import numpy as np
  4. import math
  5. from pylab import scatter,figure, clf, plot, xlabel, ylabel, xlim, ylim, title, grid, axes, show,semilogx, semilogy
  6. import matplotlib.pyplot as plt
  7. from matplotlib.font_manager import FontProperties
  8. # Generation of data to train the classifier
  9. # 100 vectors are generated. Vector have dimension 2 so can be represented as points
  10. NBVECS = 100
  11. VECDIM = 2
  12. # 3 cluster of points are generated
  13. ballRadius = 1.0
  14. x1 = [1.5, 1] + ballRadius * np.random.randn(NBVECS,VECDIM)
  15. x2 = [-1.5, 1] + ballRadius * np.random.randn(NBVECS,VECDIM)
  16. x3 = [0, -3] + ballRadius * np.random.randn(NBVECS,VECDIM)
  17. # All points are concatenated
  18. X_train=np.concatenate((x1,x2,x3))
  19. # The classes are 0,1 and 2.
  20. Y_train=np.concatenate((np.zeros(NBVECS),np.ones(NBVECS),2*np.ones(NBVECS)))
  21. gnb = GaussianNB()
  22. gnb.fit(X_train, Y_train)
  23. print("Testing")
  24. y_pred = gnb.predict([[1.5,1.0]])
  25. print(y_pred)
  26. y_pred = gnb.predict([[-1.5,1.0]])
  27. print(y_pred)
  28. y_pred = gnb.predict([[0,-3.0]])
  29. print(y_pred)
  30. # Dump of data for CMSIS-DSP
  31. print("Parameters")
  32. # Gaussian averages
  33. print("Theta = ",list(np.reshape(gnb.theta_,np.size(gnb.theta_))))
  34. # Gaussian variances
  35. print("Sigma = ",list(np.reshape(gnb.sigma_,np.size(gnb.sigma_))))
  36. # Class priors
  37. print("Prior = ",list(np.reshape(gnb.class_prior_,np.size(gnb.class_prior_))))
  38. print("Epsilon = ",gnb.epsilon_)
  39. # Some bounds are computed for the graphical representation
  40. x_min = X_train[:, 0].min()
  41. x_max = X_train[:, 0].max()
  42. y_min = X_train[:, 1].min()
  43. y_max = X_train[:, 1].max()
  44. font = FontProperties()
  45. font.set_size(20)
  46. r=plt.figure()
  47. plt.axis('off')
  48. plt.text(1.5,1.0,"A", verticalalignment='center', horizontalalignment='center',fontproperties=font)
  49. plt.text(-1.5,1.0,"B",verticalalignment='center', horizontalalignment='center', fontproperties=font)
  50. plt.text(0,-3,"C", verticalalignment='center', horizontalalignment='center',fontproperties=font)
  51. scatter(x1[:,0],x1[:,1],s=1.0,color='#FF6B00')
  52. scatter(x2[:,0],x2[:,1],s=1.0,color='#95D600')
  53. scatter(x3[:,0],x3[:,1],s=1.0,color='#00C1DE')
  54. #r.savefig('fig.jpeg')
  55. #plt.close(r)
  56. show()