MNIST数据处理
包含60000张数据作为训练数据,10000张数据数据作为测试数据。数据集中的每一张图片都代表了0-9中的一个数字,图片大小都为28*28,数字都在图片的正中间。tensorflow提供了一个类来处理MNIST数据。
one-hot编码
One-Hot编码,又称为一位有效编码,主要是采用N位状态寄存器来对N个状态进行编码,并且在任意时候只有一位有效。One-Hot编码是分类变量作为二进制向量的表示。这首先要求将分类值映射到整数值。然后,每个整数值被表示为二进制向量,除了整数的索引之外,它都是零值,它被标记为1。
1 from tensorflow.examples.tutorials.mnist import input_data 2 3 mnist = input_data.read_data_sets(\'path/to/MNIST_data/\',one_hot=True) #得到一个mnist类 4 5 print (\'Training data size:\',mnist.train.num_examples) 6 print(\'Validation data size:\',mnist.validation.num_examples) 7 print(\'Testing data size:\',mnist.test.num_examples) 8 9 a = mnist.train.images[0] 10 #28*28=784=size,像素矩阵的取值范围为[0,1],代表颜色深浅,0代表白色背景,1代表黑色背景 11 print(\'Example training data:\',a,\'\n\',\'size:\',a.shape) 12 13 b = mnist.train.labels[0] 14 print(\'Example training data label:\',\'\n\',b,\'size:\',b.shape)
书上将滑动平均,衰减的学习率,正则化都用到了模型训练中,并进行了预测准确率对比。
此处我都没有使用这些技巧,进行了最简单的训练,代码如下:
1 # -*- coding: utf-8 -*- 2 """ 3 Created on Wed Oct 23 21:25:30 2019 4 5 创建一个三层的全连接网络,进行数字识别。 6 采用了衰减学习率、L2正则化、滑动平均模型来训练网络 7 """ 8 import tensorflow as tf 9 from tensorflow.examples.tutorials.mnist import input_data 10 11 INPUT_NODE = 784 12 OUTPUT_NODE = 10 13 #隐藏层节点数 14 LAYER1_NODE = 500 15 BATCH_SIZE = 100 16 LEARNING_RATE_BASE = 0.8 17 LEARNING_RATE_DECAY = 0.99 18 REGULARIZATION_RATE = 0.0001 19 TRAINING_STEPS = 30000 20 MOVING_AVERAGE_DECAY = 0.99 21 learning_rate = 0.001 22 #定义网络的前向传播,avg_class为滑动平均类 23 def feed_forward(input_tensor,avg_class,weights1,bias1,weights2,bias2): 24 if avg_class==None: 25 layer1 = tf.nn.relu(tf.matmul(input_tensor,weights1) + bias1) 26 return tf.matmul(layer1,weights2) + bias2 27 else: 28 layer1 = tf.nn.relu(tf.matmul(input_tensor,avg_class.average(weights1)) + avg_class.average(bias1)) 29 return tf.matmul(layer1,avg_class.average(weights2)) + avg_class.average(bias2) 30 31 #训练模型的过程 32 #def train(mnist): 33 x = tf.placeholder(tf.float32,[None,INPUT_NODE],name=\'x-input\') 34 y_ = tf.placeholder(tf.float32,[None,OUTPUT_NODE],name=\'Y-output\') 35 #生成隐藏层的参数 使用tf.random_normal 36 # tf.truncated_normal的不同之处在于其平均值大于 2 个标准差的值将被丢弃并重新选择 37 weights1 = tf.Variable(tf.random_normal([INPUT_NODE,LAYER1_NODE],dtype=tf.float32)) 38 bias1 = tf.Variable(tf.constant(0.1,shape=[LAYER1_NODE])) 39 weights2 = tf.Variable(tf.random_normal([LAYER1_NODE,OUTPUT_NODE],dtype=tf.float32)) 40 bias2 = tf.Variable(tf.constant(0.1,shape=[OUTPUT_NODE])) 41 y = feed_forward(x,None,weights1,bias1,weights2,bias2) 42 global_step = tf.Variable(0,trainable=False) 43 44 #滑动平均部分,滑动平均类.apply(variable)计算滑动平均值,滑动平均类.average(variable)调用滑动平均值? 45 #tf.argmax(tensor,axis) 返回axis维上tensor最大值的下标 46 47 #计算L2正则化损失函数,一般只计算神经网络边上权重的正则化损失,不计算偏置项的 48 #labels的秩应该等于logits的秩-1 49 cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y_,1),logits=y) 50 loss = cross_entropy 51 52 train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step = global_step) 53 54 #若采用滑动平均,每过一遍数据既需要通过反向传播来更新神经网络中的参数,又要更新每一个参数的滑动平均值。 55 #为了一次完成多个操作,tensorflow提供了tf.group和tf.contol_dependencies机制 56 57 #tf.cast(True,tf.float32)=1.0 tf.cast(False,tf.float32)=0.0 58 correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1)) 59 accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) 60 61 with tf.Session() as sess: 62 tf.initialize_all_variables().run() 63 64 validate_feed = {x:mnist.validation.images,y_:mnist.validation.labels} 65 test_feed = {x:mnist.test.images,y_:mnist.test.labels} 66 67 for i in range(TRAINING_STEPS): 68 if i % 1000 == 0: 69 validation_acc = sess.run(accuracy,feed_dict=validate_feed) 70 print(\'after %d training steps,validation accuracy is %g\'%(i,validation_acc)) 71 #mnist类提供了next_batch函数 72 xs,ys = mnist.train.next_batch(BATCH_SIZE) 73 sess.run(train_step,feed_dict={x:xs,y_:ys}) 74 75 #训练结束之后,在测试集上进行测试 76 test_acc = sess.run(accuracy,feed_dict=test_feed) 77 print(\'after %d training steps,test accuracy is %g\'%(TRAINING_STEPS,test_acc)) 78 79 #主程序 80 #def main(argv=None): 81 # mnist = input_data.read_data_sets(\'/mnistdata\',one_hot = True) 82 #train(mnist) 83 84 #if __name__ == \'__main__\': 85 # tf.app.run() 86
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:《tensorflow实战Google深度学习框架》第五章mnist数字识别问题 - Python技术站