人工神经网络是由一系列简单的单元相互紧密联系构成的,每个单元有一定数量的实数输入和唯一的实数输出。神经网络的一个重要用途是接受和处理传感器产生的复杂输入并进行自适应学习;它是一种模式匹配算法,通常用于解决分类和回归问题。

  常用的人工神经网络算法包括:感知机神经网络(Perceptron Neural Network)、反向传播网络(Back Propagation,BP)、Hopfield网络、自组织映射网络(Self-Organizing Map,SOM)、学习矢量量化网络(Learning Vector Quantization,LVQ)等。

1、感知机模型

  感知机是一种线性分类器,用于二类分类问题:它将一个实例分类为正类(取值+1)或负类(取值-1)。其几何意义:模型对应于输入空间(特征空间)中的一个分离超平面w·x+b=0,该超平面将特征空间划分为正负两类。

  输入:线性可分训练数据集T,学习率η

  输出:感知机参数w,b

  算法步骤:

    1)选取初始值w0和b0

    2)在训练数据集中选取数据(xi,yi)

    3)若yi(w·xi+b)<=0(即该实例为误分类点),则更新参数:w=w+η·yi·xi,b=b+η·yi

    4)在训练数据集中重复选取数据来更新w,b,直到训练数据集中没有误分类点为止(训练数据线性可分时,该过程必在有限步内收敛;否则会无限循环)

  实验代码:

  1 from matplotlib import pyplot as plt
  2 from mpl_toolkits.mplot3d import Axes3D
  3 import numpy as np
  4 from sklearn.datasets import load_iris
  5 from sklearn.neural_network import MLPClassifier
  6 
  7 
  8 def creat_data(n):
  9     np.random.seed(1)
 10     x_11=np.random.randint(0,100,(n,1))
 11     x_12=np.random.randint(0,100,(n,1,))
 12     x_13 = 20+np.random.randint(0, 10, (n, 1,))
 13     x_21 = np.random.randint(0, 100, (n, 1))
 14     x_22 = np.random.randint(0, 100, (n, 1))
 15     x_23 = 10-np.random.randint(0, 10, (n, 1,))
 16 
 17     # print(x_11)
 18     # print(x_12)
 19     # print(x_13)
 20     # print(x_21)
 21     # print(x_22)
 22     # print(x_23)
 23 
 24     # rotate 45 degrees along the X axis
 25     new_x_12=x_12*np.sqrt(2)/2-x_13*np.sqrt(2)/2
 26     new_x_13 = x_12 * np.sqrt(2) / 2 + x_13 * np.sqrt(2) / 2
 27     new_x_22=x_22*np.sqrt(2)/2-x_23*np.sqrt(2)/2
 28     new_x_23 = x_22 * np.sqrt(2) / 2 + x_23 * np.sqrt(2) / 2
 29 
 30     # print(new_x_12)
 31     # print(new_x_13)
 32     # print(new_x_22)
 33     # print(new_x_23)
 34 
 35     plus_samples=np.hstack([x_11,new_x_12,new_x_13,np.ones((n,1))])
 36     minus_samples=np.hstack([x_11,new_x_22,new_x_23,-np.ones((n,1))])
 37     samples=np.vstack([plus_samples,minus_samples])
 38     # print(samples)
 39     np.random.shuffle(samples)
 40 
 41     # print(plus_samples)
 42     # print(minus_samples)
 43     # print(samples)
 44 
 45     return  samples
 46 
 47 def plot_samples(ax,samples):
 48     Y=samples[:,-1]
 49     Y=samples[:,-1]
 50     # print(Y)
 51     position_p=Y==1 ##the position of positve class
 52     position_m=Y==-1 ##the position of minus class
 53     # print(position_p)
 54     # print(position_m)
 55     ax.scatter(samples[position_p,0],samples[position_p,1],samples[position_p,2],marker='+',label="+",color='b')
 56     ax.scatter(samples[position_m,0],samples[position_m,1],samples[position_m,2],marker='^',label='-',color='y')
 57 
 58 def perceptron(train_data,eta,w_0,b_0):
 59     x=train_data[:,:-1] #x data
 60     y=train_data[:,-1] #corresponding classification
 61     length=train_data.shape[0] #the size of sample==the row number of the train_data
 62     w=w_0
 63     b=b_0
 64     step_num=0
 65     while True:
 66         i=0
 67         while(i<length): #traverse all sample points in a sample set
 68             step_num+=1
 69             x_i=x[i].reshape((x.shape[1],1))
 70             y_i=y[i]
 71             if y_i*(np.dot(np.transpose(w),x_i)+b)<=0: #the point is misclassified
 72                 w=w+eta*y_i*x_i #gradient descent
 73                 b=b+eta*y_i
 74                 break;#perform the next round of screening
 75             else: #the point is not a misclassification point select the next sample point
 76                 i=i+1
 77         if(i==length):
 78             break
 79     return (w,b,step_num)
 80 
 81 def creat_hyperplane(x,y,w,b):
 82     return (-w[0][0]*x-w[1][0]*y-b)/w[2][0]  #w0*x+w1*y+w2*z+b=0
 83 
 84 
 85 
 86 
 87 data=creat_data(100)
 88 eta,w_0,b_0=0.1,np.ones((3,1),dtype=float),1
 89 w,b,num=perceptron(data,eta,w_0,b_0)
 90 
 91 fig=plt.figure()
 92 plt.suptitle("perceptron")
 93 ax=Axes3D(fig)
 94 #draw samplt point
 95 plot_samples(ax,data)
 96 #draw hyperplane
 97 x=np.linspace(-30,100,100)
 98 y=np.linspace(-30,100,100)
 99 x,y=np.meshgrid(x,y)
100 z=creat_hyperplane(x,y,w,b)
101 ax.plot_surface(x,y,z,rstride=1,cstride=1,color='g',alpha=0.2)
102 
103 ax.legend(loc='best')
104 plt.show()

View Code