tensorflow使用CNN分析mnist手写体数字数据集

TensorFlow使用CNN分析MNIST手写数字数据集的完整攻略

本文将介绍如何使用TensorFlow和卷积神经网络(CNN)来分析MNIST手写数字数据集。本文重点介绍以下内容:

  • MNIST数据集的介绍
  • 构建CNN模型
  • 训练模型
  • 测试模型

MNIST数据集的介绍

MNIST数据集是一个手写数字数据集,包含60000张训练图像和10000张测试图像。每个图像都是28x28像素的灰度图像,表示0到9的数字。

我们可以使用TensorFlow的官方包keras来加载MNIST数据集。以下是加载代码示例:

from tensorflow import keras

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

构建CNN模型

我们可以使用keras构建CNN模型。以下是代码示例:

from tensorflow import keras

model = keras.Sequential([
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    keras.layers.MaxPool2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation='softmax')
])

上述模型包括:

  • 一个32个滤波器的卷积层
  • 一个2x2的最大池化层
  • 一个展开层
  • 一个10个神经元的输出层

训练模型

我们可以使用以下代码训练模型:

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

上述代码中:

  • 我们使用Adam优化器和分类交叉熵损失函数编译模型
  • 我们使用训练集数据训练模型,每个epoch重复5次
  • 我们在测试集上验证我们的模型的表现

测试模型

使用以下代码测试模型,并计算测试准确率:

test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test accuracy:', test_acc)

示例1

我们可以看一个基于上述代码的完整例子:

from tensorflow import keras

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# 将图像数据从0-255转换为0-1之间的浮点数
x_train = x_train / 255.0
x_test = x_test / 255.0

# 将数据重塑为适合CNN模型的形状
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

model = keras.Sequential([
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    keras.layers.MaxPool2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test accuracy:', test_acc)

上述代码会输出测试准确率,以及模型在验证集上的表现。

示例2

在上述例子中,我们使用了一个简单的CNN模型。我们也可以使用更深的CNN模型来提高准确率。例如,以下是一个更深的模型:

from tensorflow import keras

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# 将图像数据从0-255转换为0-1之间的浮点数
x_train = x_train / 255.0
x_test = x_test / 255.0

# 将数据重塑为适合CNN模型的形状
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

model = keras.Sequential([
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPool2D(pool_size=(2, 2)),
    keras.layers.Dropout(0.25),
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test accuracy:', test_acc)

上述代码中,我们使用了更深的CNN模型,包括:

  • 一个32个滤波器的卷积层
  • 一个64个滤波器的卷积层
  • 一个2x2的最大池化层
  • 一个dropout层
  • 一个展开层
  • 一个128个神经元的全连接层
  • 一个dropout层
  • 一个10个神经元的输出层

上述模型的测试准确率可以达到99%以上。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:tensorflow使用CNN分析mnist手写体数字数据集 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • 浅谈Python3.10 和 Python3.9 之间的差异

    浅谈Python3.10 和 Python3.9 之间的差异 Python是一门高级编程语言,它在不断地发展中,不同版本之间会存在差异。本文将重点介绍Python3.10和Python3.9之间的差异。 新特性 Python3.10引入了很多新特性,以下是几个值得关注的特性。 格式字符串的新特性 Python3.10中,格式字符串支持未命名参数。例如: na…

    人工智能概览 2023年5月25日
    00
  • 一篇文章带你了解Python中的装饰器

    一篇文章带你了解Python中的装饰器 装饰器是什么? 装饰器(Decorator)是Python中非常有用的一个函数特性,其主要作用是用于在代码运行时增强函数或类的功能,而不需要对其代码进行修改。 简单来说,装饰器就是一个函数,其参数是另一个函数或者类,其主要目的是用于修改其他函数或者类的行为。 装饰器函数的定义 一个装饰器函数的定义通常遵循以下步骤: 定…

    人工智能概论 2023年5月25日
    00
  • Python实现对桌面进行实时捕捉画面的方法详解

    下面就为您详细讲解“Python实现对桌面进行实时捕捉画面的方法详解”的完整攻略。 1. 确认环境 在使用Python进行桌面画面捕捉之前,需要确认开发环境是否准备齐全。 首先,需要安装好Python开发环境。可以从官网 https://www.python.org/downloads/ 下载安装Python,建议选择最新的稳定版本,并勾选“Add Pyth…

    人工智能概论 2023年5月25日
    00
  • .Net Core如何对MongoDB执行多条件查询

    针对.Net Core如何对MongoDB执行多条件查询,我提供如下攻略: 1. 安装MongoDB.Driver 首先需要引入 MongoDB.Driver,可以通过NuGet安装,也可以手动引入。 2. 实例化MongoClient 其次需要实例化 MongoClient,并且可以连接相应的MongoDB。 var client = new MongoC…

    人工智能概论 2023年5月25日
    00
  • Spring Cloud Eureka服务治理的实现

    Spring Cloud Eureka服务治理的实现 Spring Cloud Eureka是SpringCloud的子项目之一,用于实现服务治理。服务治理是SpringCloud微服务核心思想之一,其主要目的是协调各个微服务之间的通信,以便于负载均衡、故障恢复、服务升级等。在此文档中,我们将详细讲解“Spring Cloud Eureka服务治理的实现”的…

    人工智能概览 2023年5月25日
    00
  • AngularJS轻松实现双击排序的功能

    下面是“AngularJS轻松实现双击排序的功能”的完整攻略: 1. 概述 在AngularJS中实现双击排序的功能可以通过使用ng-repeat、ng-click和双击事件结合起来实现。其中ng-repeat用于循环生成视图,ng-click用于处理排序事件,双击事件用于响应用户的行为。 2. 示例说明 下面是两个示例,分别演示了如何使用AngularJS…

    人工智能概论 2023年5月24日
    00
  • Python中asyncio与aiohttp入门教程

    那么让我们开始吧! Python中asyncio与aiohttp入门教程 什么是异步编程? 在传统的同步编程中,程序在执行某个操作时需要等待其完成才能进行下一步操作。而在异步编程中,程序在执行某个操作时可以先转而去做其他事情,等到该操作完成后再回来继续执行原来的操作。这种非阻塞式的执行方式可以让程序更高效地利用时间。 Python提供了一个用于异步编程的标准…

    人工智能概论 2023年5月25日
    00
  • Django实现发送邮件找回密码功能

    下面我将为您详细讲解“Django实现发送邮件找回密码功能”的完整攻略。 1.安装依赖Django自带有邮件发送功能,但是需要SMTP的支持,因此需要在项目中安装django-smtp-ssl库来发送邮件。可以使用以下命令进行安装: pip install django-smtp-ssl 配置邮件发送 在settings.py文件中添加SMTP的配置信息,代…

    人工智能概览 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部