| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374 |
- from sklearn.naive_bayes import GaussianNB
- import random
- import numpy as np
- import math
- from pylab import scatter,figure, clf, plot, xlabel, ylabel, xlim, ylim, title, grid, axes, show,semilogx, semilogy
- import matplotlib.pyplot as plt
- from matplotlib.font_manager import FontProperties
# Generation of data to train the classifier.
# NBVECS vectors are generated per class. Vectors have dimension VECDIM = 2,
# so they can be represented as points in the plane.
NBVECS = 100
VECDIM = 2
# Standard deviation of every cluster: each coordinate is ballRadius times a
# standard-normal draw around the cluster centre.
ballRadius = 1.0
# Seed for NumPy's global random generator. None keeps the historical
# behaviour (a new data set, hence new dumped parameters, on every run);
# set an integer to make the CMSIS-DSP parameter dump reproducible.
RANDOM_SEED = None


def gaussian_cluster(center, nb_vecs, vec_dim, radius):
    """Return an (nb_vecs, vec_dim) array of points scattered around center.

    Each point is ``center + radius * z`` where ``z`` is drawn with
    ``np.random.randn`` (independent standard-normal coordinates).
    """
    return np.asarray(center, dtype=float) + radius * np.random.randn(nb_vecs, vec_dim)


if RANDOM_SEED is not None:
    np.random.seed(RANDOM_SEED)

# 3 clusters of points are generated, one per class (drawn in this order).
x1 = gaussian_cluster([1.5, 1], NBVECS, VECDIM, ballRadius)
x2 = gaussian_cluster([-1.5, 1], NBVECS, VECDIM, ballRadius)
x3 = gaussian_cluster([0, -3], NBVECS, VECDIM, ballRadius)
# All points are concatenated
X_train = np.concatenate((x1, x2, x3))
# The classes are 0, 1 and 2 (stored as float64 labels).
Y_train = np.concatenate((np.zeros(NBVECS), np.ones(NBVECS), 2 * np.ones(NBVECS)))
# Fit a Gaussian naive Bayes model on the labelled clusters.
gnb = GaussianNB()
gnb.fit(X_train, Y_train)

# Quick check: classify each cluster centre, in class order 0, 1, 2,
# printing one prediction array per query point.
print("Testing")
for query in ([1.5, 1.0], [-1.5, 1.0], [0, -3.0]):
    y_pred = gnb.predict([query])
    print(y_pred)
# Dump of the fitted model parameters for CMSIS-DSP.
# Each array is flattened in C (row-major) order and converted with
# .tolist() so the values print as plain Python floats; printing a list of
# NumPy scalars shows "np.float64(...)" wrappers under NumPy >= 2.
print("Parameters")
# Gaussian averages
print("Theta = ", gnb.theta_.ravel().tolist())
# Gaussian variances. scikit-learn 1.0 renamed `sigma_` to `var_` and 1.2
# removed `sigma_`; use the new name when present, else the old one.
gnb_variances = gnb.var_ if hasattr(gnb, "var_") else gnb.sigma_
print("Sigma = ", gnb_variances.ravel().tolist())
# Class priors
print("Prior = ", gnb.class_prior_.ravel().tolist())
print("Epsilon = ", gnb.epsilon_)
# Bounds of the training data along each axis (not currently used by the
# plotting calls below).
x_min, x_max = X_train[:, 0].min(), X_train[:, 0].max()
y_min, y_max = X_train[:, 1].min(), X_train[:, 1].max()

# Large font for the letter drawn at each cluster centre.
font = FontProperties()
font.set_size(20)

r = plt.figure()
plt.axis('off')
for (lx, ly), letter in (((1.5, 1.0), "A"), ((-1.5, 1.0), "B"), ((0, -3), "C")):
    plt.text(lx, ly, letter,
             verticalalignment='center',
             horizontalalignment='center',
             fontproperties=font)
for points, colour in ((x1, '#FF6B00'), (x2, '#95D600'), (x3, '#00C1DE')):
    scatter(points[:, 0], points[:, 1], s=1.0, color=colour)
# To write the figure to disk instead, use:
#     r.savefig('fig.jpeg'); plt.close(r)
show()
|