MNIST数据处理

包含60000张数据作为训练数据,10000张数据数据作为测试数据。数据集中的每一张图片都代表了0-9中的一个数字,图片大小都为28*28,数字都在图片的正中间。tensorflow提供了一个类来处理MNIST数据。

one-hot编码

One-Hot编码,又称为一位有效编码,主要是采用N位状态寄存器来对N个状态进行编码,并且在任意时候只有一位有效。One-Hot编码是分类变量作为二进制向量的表示。这首先要求将分类值映射到整数值。然后,每个整数值被表示为二进制向量,除了整数的索引之外,它都是零值,它被标记为1。

 1 from tensorflow.examples.tutorials.mnist import input_data
 2 
 3 mnist = input_data.read_data_sets(\'path/to/MNIST_data/\',one_hot=True) #得到一个mnist类
 4 
 5 print (\'Training data size:\',mnist.train.num_examples)
 6 print(\'Validation data size:\',mnist.validation.num_examples)
 7 print(\'Testing data size:\',mnist.test.num_examples)
 8 
 9 a = mnist.train.images[0]
10 #28*28=784=size,像素矩阵的取值范围为[0,1],代表颜色深浅,0代表白色背景,1代表黑色背景
11 print(\'Example training data:\',a,\'\n\',\'size:\',a.shape)   
12 
13 b = mnist.train.labels[0]
14 print(\'Example training data label:\',\'\n\',b,\'size:\',b.shape)

书上将滑动平均,衰减的学习率,正则化都用到了模型训练中,并进行了预测准确率对比。
此处我都没有使用这些技巧,进行了最简单的训练,代码如下:

 1 # -*- coding: utf-8 -*-
 2 """
 3 Created on Wed Oct 23 21:25:30 2019
 4 
 5 创建一个三层的全连接网络,进行数字识别。
 6 采用了衰减学习率、L2正则化、滑动平均模型来训练网络
 7 """
 8 import tensorflow as tf
 9 from tensorflow.examples.tutorials.mnist import input_data
10 
11 INPUT_NODE = 784
12 OUTPUT_NODE = 10
13 #隐藏层节点数
14 LAYER1_NODE = 500
15 BATCH_SIZE = 100
16 LEARNING_RATE_BASE = 0.8
17 LEARNING_RATE_DECAY = 0.99   
18 REGULARIZATION_RATE = 0.0001
19 TRAINING_STEPS = 30000
20 MOVING_AVERAGE_DECAY = 0.99
21 learning_rate = 0.001
22 #定义网络的前向传播,avg_class为滑动平均类
23 def feed_forward(input_tensor,avg_class,weights1,bias1,weights2,bias2):
24     if avg_class==None:
25         layer1 = tf.nn.relu(tf.matmul(input_tensor,weights1) + bias1)
26         return tf.matmul(layer1,weights2) + bias2
27     else:
28         layer1 = tf.nn.relu(tf.matmul(input_tensor,avg_class.average(weights1)) + avg_class.average(bias1))
29         return tf.matmul(layer1,avg_class.average(weights2)) + avg_class.average(bias2)
30     
31 #训练模型的过程
32 #def train(mnist):
33 x = tf.placeholder(tf.float32,[None,INPUT_NODE],name=\'x-input\')
34 y_ = tf.placeholder(tf.float32,[None,OUTPUT_NODE],name=\'Y-output\')
35 #生成隐藏层的参数  使用tf.random_normal 
36 #    tf.truncated_normal的不同之处在于其平均值大于 2 个标准差的值将被丢弃并重新选择
37 weights1 = tf.Variable(tf.random_normal([INPUT_NODE,LAYER1_NODE],dtype=tf.float32))
38 bias1 = tf.Variable(tf.constant(0.1,shape=[LAYER1_NODE]))
39 weights2 = tf.Variable(tf.random_normal([LAYER1_NODE,OUTPUT_NODE],dtype=tf.float32))
40 bias2 = tf.Variable(tf.constant(0.1,shape=[OUTPUT_NODE]))
41 y = feed_forward(x,None,weights1,bias1,weights2,bias2)
42 global_step = tf.Variable(0,trainable=False)
43 
44 #滑动平均部分,滑动平均类.apply(variable)计算滑动平均值,滑动平均类.average(variable)调用滑动平均值?
45 #tf.argmax(tensor,axis)  返回axis维上tensor最大值的下标
46 
47 #计算L2正则化损失函数,一般只计算神经网络边上权重的正则化损失,不计算偏置项的
48 #labels的秩应该等于logits的秩-1
49 cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y_,1),logits=y)
50 loss = cross_entropy
51 
52 train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step = global_step)
53 
54 #若采用滑动平均,每过一遍数据既需要通过反向传播来更新神经网络中的参数,又要更新每一个参数的滑动平均值。
55 #为了一次完成多个操作,tensorflow提供了tf.group和tf.contol_dependencies机制
56 
57 #tf.cast(True,tf.float32)=1.0  tf.cast(False,tf.float32)=0.0  
58 correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
59 accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
60 
61 with tf.Session() as sess:
62     tf.initialize_all_variables().run()
63     
64     validate_feed = {x:mnist.validation.images,y_:mnist.validation.labels}
65     test_feed = {x:mnist.test.images,y_:mnist.test.labels}
66     
67     for i in range(TRAINING_STEPS):
68         if i % 1000 == 0:
69             validation_acc = sess.run(accuracy,feed_dict=validate_feed)
70             print(\'after %d training steps,validation accuracy is %g\'%(i,validation_acc))
71         #mnist类提供了next_batch函数
72         xs,ys = mnist.train.next_batch(BATCH_SIZE)
73         sess.run(train_step,feed_dict={x:xs,y_:ys})
74     
75     #训练结束之后,在测试集上进行测试
76     test_acc = sess.run(accuracy,feed_dict=test_feed)
77     print(\'after %d training steps,test accuracy is %g\'%(TRAINING_STEPS,test_acc))
78         
79 #主程序
80 #def main(argv=None):
81 #    mnist = input_data.read_data_sets(\'/mnistdata\',one_hot = True)
82 #train(mnist)    
83 
84 #if __name__ == \'__main__\':
85 #    tf.app.run()
86