tensorflow使用CNN分析mnist手写体数字数据集

yizhihongxing

TensorFlow使用CNN分析MNIST手写数字数据集的完整攻略

本文将介绍如何使用TensorFlow和卷积神经网络(CNN)来分析MNIST手写数字数据集。本文重点介绍以下内容:

  • MNIST数据集的介绍
  • 构建CNN模型
  • 训练模型
  • 测试模型

MNIST数据集的介绍

MNIST数据集是一个手写数字数据集,包含60000张训练图像和10000张测试图像。每个图像都是28x28像素的灰度图像,表示0到9的数字。

我们可以使用TensorFlow的官方包keras来加载MNIST数据集。以下是加载代码示例:

from tensorflow import keras

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

构建CNN模型

我们可以使用keras构建CNN模型。以下是代码示例:

from tensorflow import keras

model = keras.Sequential([
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    keras.layers.MaxPool2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation='softmax')
])

上述模型包括:

  • 一个32个滤波器的卷积层
  • 一个2x2的最大池化层
  • 一个展开层
  • 一个10个神经元的输出层

训练模型

我们可以使用以下代码训练模型:

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

上述代码中:

  • 我们使用Adam优化器和分类交叉熵损失函数编译模型
  • 我们使用训练集数据训练模型,每个epoch重复5次
  • 我们在测试集上验证我们的模型的表现

测试模型

使用以下代码测试模型,并计算测试准确率:

test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test accuracy:', test_acc)

示例1

我们可以看一个基于上述代码的完整例子:

from tensorflow import keras

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# 将图像数据从0-255转换为0-1之间的浮点数
x_train = x_train / 255.0
x_test = x_test / 255.0

# 将数据重塑为适合CNN模型的形状
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

model = keras.Sequential([
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    keras.layers.MaxPool2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test accuracy:', test_acc)

上述代码会输出测试准确率,以及模型在验证集上的表现。

示例2

在上述例子中,我们使用了一个简单的CNN模型。我们也可以使用更深的CNN模型来提高准确率。例如,以下是一个更深的模型:

from tensorflow import keras

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# 将图像数据从0-255转换为0-1之间的浮点数
x_train = x_train / 255.0
x_test = x_test / 255.0

# 将数据重塑为适合CNN模型的形状
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

model = keras.Sequential([
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPool2D(pool_size=(2, 2)),
    keras.layers.Dropout(0.25),
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test accuracy:', test_acc)

上述代码中,我们使用了更深的CNN模型,包括:

  • 一个32个滤波器的卷积层
  • 一个64个滤波器的卷积层
  • 一个2x2的最大池化层
  • 一个dropout层
  • 一个展开层
  • 一个128个神经元的全连接层
  • 一个dropout层
  • 一个10个神经元的输出层

上述模型的测试准确率可以达到99%以上。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:tensorflow使用CNN分析mnist手写体数字数据集 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • Django框架的中的setting.py文件说明详解

    Django框架的settings.py文件是Django应用程序配置的核心文件之一。在该文件中,您可以设置各种设置,例如数据库连接、静态文件路径、中间件等等。 以下是对settings.py文件的详细说明: 应用程序配置 DEBUG: 此选项是用于在开发过程中启用或禁用调试模式。如果将其设置为True,则会显示有关代码错误的详细信息。在生产环境中,它应该始…

    人工智能概览 2023年5月25日
    00
  • nginx的zabbix 5.0安装部署的方法步骤

    下面我会详细讲解nginx的zabbix 5.0安装部署的方法步骤,包括安装nginx、安装zabbix server和zabbix agent,同时给出两条示例说明。 一、安装nginx 1. 安装依赖项 Nginx需要一些依赖项进行安装。 yum install -y gcc pcre-devel zlib-devel make openssl-deve…

    人工智能概览 2023年5月25日
    00
  • 什么是MEAN?JavaScript编程中的MEAN是什么意思?

    MEAN是JavaScript编程中的一个技术栈,它包含了四个技术领域的理念:MongoDB、Express.js、AngularJS、Node.js。下面我来详细讲解一下这四个技术领域对于MEAN的意义和重要作用。 MongoDB MongoDB是一个面向文档的数据库,可以帮助我们存储和管理数据。它非常灵活,可以处理非结构化数据和大规模数据。在MEAN技术…

    人工智能概论 2023年5月24日
    00
  • tensorflow学习笔记之mnist的卷积神经网络实例

    TensorFlow学习笔记之MNIST的卷积神经网络实例 随着深度学习的普及,卷积神经网络已成为图像和视觉任务中最常用的模型之一。在这篇文章中,我们将介绍如何使用Tensorflow创建一个基本的卷积神经网络(CNN)模型来处理MNIST数据集。 1. MNIST数据集 手写数字识别数据集MNIST是一个广泛使用的数据集,它包含60,000个训练样本和10…

    人工智能概论 2023年5月25日
    00
  • 在CentOS下使用Munin来监控服务器运行的方法

    下面是在CentOS下使用Munin来监控服务器运行的完整攻略: 1. 安装Munin Munin是一个开源的监控系统,可以监控服务器的资源使用情况。我们可以通过yum命令来安装Munin: sudo yum install munin munin-node -y 2. 配置Munin Munin的配置文件位于/etc/munin目录下,我们可以在此目录下找…

    人工智能概览 2023年5月25日
    00
  • Python定时任务工具之APScheduler使用方式

    下面给你讲解 “Python定时任务工具之APScheduler使用方式” 的完整攻略。 一、概述 在Python中,可以使用APScheduler来进行定时任务的管理和调度。APScheduler支持多种任务触发器,例如:间隔时间触发器、定时时间触发器、日期时间触发器等。同时,APScheduler还支持多种任务执行器,例如:进程池执行器、线程池执行器、协…

    人工智能概览 2023年5月25日
    00
  • python Web开发你要理解的WSGI & uwsgi详解

    让我详细讲解一下“Python Web开发你要理解的WSGI & uWSGI详解”攻略。 WSGI简介 WSGI是Web服务器网关接口(Web Server Gateway Interface)的缩写。WSGI是一种Web服务器和Web应用程序(如Python程序)之间通信的标准接口。 WSGI规范定义了Web服务器和Web应用程序之间的接口,使得开…

    人工智能概览 2023年5月25日
    00
  • 对python中的six.moves模块的下载函数urlretrieve详解

    对python中的six.moves模块的下载函数urlretrieve详解 介绍 six.moves是由six模块提供的一个适用于Python 2和3的兼容性工具,致力于让开发者在Python 2/3之间轻松移植。常用的六个子模块:- builtins- configparser- http_client- urllib- queue- xrange si…

    人工智能概览 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部