tensorflow使用CNN分析mnist手写体数字数据集

TensorFlow使用CNN分析MNIST手写数字数据集的完整攻略

本文将介绍如何使用TensorFlow和卷积神经网络(CNN)来分析MNIST手写数字数据集。本文重点介绍以下内容:

  • MNIST数据集的介绍
  • 构建CNN模型
  • 训练模型
  • 测试模型

MNIST数据集的介绍

MNIST数据集是一个手写数字数据集,包含60000张训练图像和10000张测试图像。每个图像都是28x28像素的灰度图像,表示0到9的数字。

我们可以使用TensorFlow的官方包keras来加载MNIST数据集。以下是加载代码示例:

from tensorflow import keras

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

构建CNN模型

我们可以使用keras构建CNN模型。以下是代码示例:

from tensorflow import keras

model = keras.Sequential([
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    keras.layers.MaxPool2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation='softmax')
])

上述模型包括:

  • 一个32个滤波器的卷积层
  • 一个2x2的最大池化层
  • 一个展开层
  • 一个10个神经元的输出层

训练模型

我们可以使用以下代码训练模型:

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

上述代码中:

  • 我们使用Adam优化器和分类交叉熵损失函数编译模型
  • 我们使用训练集数据训练模型,每个epoch重复5次
  • 我们在测试集上验证我们的模型的表现

测试模型

使用以下代码测试模型,并计算测试准确率:

test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test accuracy:', test_acc)

示例1

我们可以看一个基于上述代码的完整例子:

from tensorflow import keras

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# 将图像数据从0-255转换为0-1之间的浮点数
x_train = x_train / 255.0
x_test = x_test / 255.0

# 将数据重塑为适合CNN模型的形状
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

model = keras.Sequential([
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    keras.layers.MaxPool2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test accuracy:', test_acc)

上述代码会输出测试准确率,以及模型在验证集上的表现。

示例2

在上述例子中,我们使用了一个简单的CNN模型。我们也可以使用更深的CNN模型来提高准确率。例如,以下是一个更深的模型:

from tensorflow import keras

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# 将图像数据从0-255转换为0-1之间的浮点数
x_train = x_train / 255.0
x_test = x_test / 255.0

# 将数据重塑为适合CNN模型的形状
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

model = keras.Sequential([
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPool2D(pool_size=(2, 2)),
    keras.layers.Dropout(0.25),
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test accuracy:', test_acc)

上述代码中,我们使用了更深的CNN模型,包括:

  • 一个32个滤波器的卷积层
  • 一个64个滤波器的卷积层
  • 一个2x2的最大池化层
  • 一个dropout层
  • 一个展开层
  • 一个128个神经元的全连接层
  • 一个dropout层
  • 一个10个神经元的输出层

上述模型的测试准确率可以达到99%以上。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:tensorflow使用CNN分析mnist手写体数字数据集 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • 如何将tensorflow训练好的模型移植到Android (MNIST手写数字识别)

    关于如何将 TensorFlow 训练好的模型移植到 Android 上,我将分以下几个步骤进行介绍: 导出模型 在使用 TensorFlow 进行模型训练并完成后,需要将模型导出,以便在 Android 上进行使用。导出模型时,需要定义保存路径和需要导出的节点信息,示例代码如下: from tensorflow.python.framework impor…

    人工智能概论 2023年5月24日
    00
  • 详解Pytorch+PyG实现GCN过程示例

    详解Pytorch+PyG实现GCN过程示例 这篇攻略将会详细讲解如何使用PyTorch和PyG实现图卷积网络(GCN)。我们将通过两条示例说明如何使用PyG和PyTorch来实现GCN,并对代码进行详细分析。 简介 图形数据(或称为网络数据或图形数据)由许多顶点和边组成,这些组成通常是不规则的,图形中顶点之间的拓扑关系也是不规则的。GCN是一种用于处理图形…

    人工智能概论 2023年5月25日
    00
  • SpringCloud-Hystrix组件使用方法

    SpringCloud Hystrix 组件使用方法攻略 概述 SpringCloud Hystrix 组件是一个用于服务容错和限流的工具,用于帮助我们处理分布式系统的各种问题,提升系统的可用性、稳定性和弹性。本文将详细讲解 Hystrix 组件的使用方法,包括如何在项目中配置 Hystrix、如何编写 Hystrix Command、如何在 Feign 中…

    人工智能概览 2023年5月25日
    00
  • 易语言修改指定网页为浏览器主页的代码

    以下是详细讲解“易语言修改指定网页为浏览器主页的代码”的完整攻略。 1. 确认浏览器主页的配置文件路径 首先,我们需要确认浏览器主页的配置文件路径。以Chrome为例,Windows系统下Chrome的主页配置文件存放在C:\Users\{user}\AppData\Local\Google\Chrome\User Data\Default\Preferen…

    人工智能概论 2023年5月25日
    00
  • Bootstrap框架建立树形菜单(Tree)的实例代码

    Bootstrap框架提供了用于创建平台可用的用户界面组件的组合工具。其中之一就是树形菜单(Tree)。通过使用Bootstrap,我们可以从头开始创建一个完整的树形菜单,并将其集成到我们的网站或应用程序中。 以下是建立树形菜单的步骤: 1. 引入Bootstrap库和jQuery库 在标签中引入Bootstrap库和jQuery库。 <head&gt…

    人工智能概论 2023年5月25日
    00
  • Python变量作用域LEGB用法解析

    Python变量作用域LEGB用法解析 在Python中,变量作用域非常重要,它决定了变量的可见性和生命周期。在Python中,变量作用域可以分为四种:局部变量、嵌套作用域变量、全局变量和内置变量。在理解Python变量作用域之前,我们首先需要了解LEGB模型。 LEGB模型是Python中关于查找变量的顺序,其中LEGB分别代表(Local, Enclos…

    人工智能概览 2023年5月25日
    00
  • SpringBoot创建RSocket服务器的全过程记录

    下面是关于Spring Boot创建RSocket服务器的全过程记录。 RSocket简介 RSocket是一种基于Reactive Streams规范并且支持多种传输协议的全双工网络通信协议,可以实现高效、可扩展、快速启动的微服务通信。它由Netty、Reactor和Spring团队合作开发,提供Java、Kotlin和其他语言的客户端和服务器端实现,是S…

    人工智能概览 2023年5月25日
    00
  • Python Opencv中获取卷积核的实现代码

    获取卷积核可以通过在Python Opencv中使用getStructuringElement函数来实现。该函数用于获取不同形状和大小的结构元素或卷积核。 具体实现如下: 1. 获取矩形卷积核 如下为实现获取3*3矩形卷积核的代码示例: import cv2 kernel_rect = cv2.getStructuringElement(cv2.MORPH_…

    人工智能概论 2023年5月24日
    00
合作推广
合作推广
分享本页
返回顶部