pytorch cnn 识别手写的字实现自建图片数据

下面是详细的攻略:

简介

卷积神经网络(CNN)是一种在图像识别、语音识别和自然语言处理等领域广泛应用的深度学习算法。本文将介绍如何使用PyTorch实现一个CNN模型来识别手写字,并且展示如何通过自建图片数据进行训练和测试。

步骤

1. 准备自建图片数据

首先需要准备自建图片数据作为我们的训练集和测试集。这些图片应该是手写的数字,且需要分类为数字0到9的10个类别。每个类别应该包含足够数量的图片,以便模型可以充分学习区分不同数字的特征。

可以使用任何图片编辑软件来创建这些手写数字图片,例如Paint或GIMP。注意要将图片大小和分辨率保持一致,以便后续处理。

将这些图片按照不同数字分类,分别存储到对应文件夹内。例如,所有数字0的图片可以存储在名为“0”的文件夹中,所有数字1的图片可以存储在名为“1”的文件夹中,以此类推。

2. 加载和转换自建图片数据

使用torchvision库中的ImageFolder来加载自建的图片数据,该函数会自动将每个文件夹内的图片视为同一类别。可以使用transform来对图片进行预处理,例如缩放和剪裁。

下面是一个示例:

import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

data_transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor()
])

train_dataset = ImageFolder('./train', transform=data_transform)
test_dataset = ImageFolder('./test', transform=data_transform)

这里使用了Resize将图片尺寸调整为28×28像素,ToTensor将图片转换为PyTorch中的张量。

3. 创建CNN模型

接下来需要创建一个CNN模型。可以使用PyTorch中的nn.Module类来构建模型。这里可以简单地使用两个卷积层和一个全连接层,具体如下:

import torch.nn as nn

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 64 * 7 * 7)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

这个模型包含两个卷积层和一个全连接层。第一个卷积层有32个输出通道,第二个卷积层有64个输出通道。每个卷积层都使用大小为3×3的卷积核,每个像素的步幅为1,并使用1个像素的填充。全连接层有128个隐藏单元,最终输出10个类别的概率。

4. 训练模型

现在可以开始训练模型了。首先需要定义损失函数和优化器。这里使用交叉熵作为损失函数,使用随机梯度下降作为优化器。

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

然后可以使用PyTorch中的DataLoader来加载自建图片数据,并使用上述定义的损失函数和优化器来训练模型。训练模型的代码如下:

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)

model = CNNModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:
            print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, running_loss/100))
            running_loss = 0.0

训练模型需要遍历所有训练数据多次,每次遍历称为一个epoch。这里共训练了10个epoch。

5. 测试模型

训练完模型后,可以使用测试集来评估模型的性能。这里使用PyTorch中的accuracy_score函数来计算模型在测试集上的准确率。

from sklearn.metrics import accuracy_score

correct = 0
total = 0

with torch.no_grad():
    for data in test_loader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))

最后,可以使用模型对单个手写数字图片进行预测。预测过程包括两个步骤:首先将图片转换为PyTorch张量,然后使用训练好的模型对该张量进行预测:

from PIL import Image

img = Image.open("5.png")
img = data_transform(img).unsqueeze(0)

outputs = model(img)
_, predicted = torch.max(outputs.data, 1)

print('The predicted digit is:', predicted[0])

这里使用了PIL库来打开单个图片,并使用上述定义的data_transform函数将其转换为PyTorch张量。需要注意,使用unsqueeze将单个张量转换为批量大小为1的张量,在预测时需要用到。

示例说明

下面给出两个示例,分别演示如何创建CNN模型和如何对单个手写数字进行预测。

示例1:创建CNN模型

import torch.nn as nn

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 64 * 7 * 7)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

该模型包含两个卷积层和一个全连接层。第一个卷积层有32个输出通道,第二个卷积层有64个输出通道。每个卷积层都使用大小为3×3的卷积核,每个像素的步幅为1,并使用1个像素的填充。全连接层有128个隐藏单元,最终输出10个类别的概率。

示例2:对单个手写数字进行预测

from PIL import Image

img = Image.open("5.png")
img = data_transform(img).unsqueeze(0)

outputs = model(img)
_, predicted = torch.max(outputs.data, 1)

print('The predicted digit is:', predicted[0])

这个示例使用了PIL库打开了一个名为“5.png”的图片,并使用上述定义的data_transform函数将其转换为PyTorch张量。将该张量转换为批量大小为1的张量,并使用已训练好的CNN模型对其进行预测,最终输出预测的数字。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch cnn 识别手写的字实现自建图片数据 - Python技术站

(0)
上一篇 2023年6月6日
下一篇 2023年6月6日

相关文章

  • python_array[0][0]与array[0,0]的区别详解

    让我们先来看看两者的区别。 在Python中,可以使用多种方式来表示数组。其中,有一种方式是使用列表(List)创建多维数组,这种数组被称为Python List Array或Python内置数组(Python Built-in Array)。这种数组是Python标准库中“array”模块中提供的,其使用方式与列表类似。对于这种数组,我们可以使用下面两种方…

    python 2023年6月5日
    00
  • Python爬取数据保存为Json格式的代码示例

    下面我将为你详细讲解“Python爬取数据保存为Json格式的代码示例”的完整攻略。 一、前置知识 在介绍代码实现之前,我们需要了解一些前置知识: requests库:用于向网站发起HTTP请求并获取响应; json模块:用于将Python数据(如列表、字典)转换为Json格式的字符串,并将Json格式的字符串解析为Python对象; 爬虫基础知识:了解如何…

    python 2023年6月3日
    00
  • python实现文本界面网络聊天室

    Python实现文本界面网络聊天室攻略 介绍 网络聊天室是一种将多用户连接到同一聊天室中进行实时通信的应用程序。在这种聊天室中,用户可以向其他用户发送消息,并从其他用户处接收消息。在本文中,我们将使用Python编写一个文本界面网络聊天室。 步骤 第一步:创建服务端 服务端是聊天室中的核心组件,负责监听客户端的连接请求,并将消息转发给其他客户端。在Pytho…

    python 2023年5月30日
    00
  • 利用Python实现Windows定时关机功能

    利用Python实现Windows定时关机功能攻略 一、安装Python 首先我们需要在Windows系统中安装Python,可以从官方网站 https://www.python.org/downloads/ 下载,选择适合自己系统的版本,然后按照默认设置安装即可。 二、编写Python脚本 在安装完Python之后,我们可以使用任意文本编辑器,比如Note…

    python 2023年5月23日
    00
  • 浅谈Python几种常见的归一化方法

    浅谈Python几种常见的归一化方法 在机器学习中,归一化是一种常用的数据预处理技术,其目的是将不同量纲的特征值缩放到相同的范内,以便更好地进行模型训练和预测。本文将介绍Python中几种常见的归一化方法,并提供两个示例说明。 1. Min-Max归一化 Min-Max归一化是一种常用的线性归一化方法,其公式如下: $${norm} = \frac{x – …

    python 2023年5月14日
    00
  • python中的__dict__属性介绍

    当我们在Python中创建对象时,每个对象都有一个名为 dict 的属性,它是一个字典,其中存储了该对象的所有类属性和实例属性。我们可以使用该属性来访问、添加或修改对象中的属性。 __dict__属性的访问 我们可以使用以下方式访问任意对象的__dict__属性: obj.__dict__ 其中,obj是待访问的对象名。 例如,我们定义一个类 Person,…

    python 2023年5月13日
    00
  • Flask 上下文是什么 ?

    哈喽大家好,我是咸鱼。今天我们来聊聊什么是 Flask 上下文   咸鱼在刚接触到这个概念的时候脑子里蹦出的第一个词是 CPU 上下文   今天咸鱼希望通过这篇文章,让大家能够对 Flask 上下文设计的初衷以及应用有一个基本的了解   Flask 上下文 我们在使用 Flask 开发 web 程序的时候,通常会面临下面的情况     假设同一时间内有三台客…

    python 2023年4月22日
    00
  • Python自动重试HTTP连接装饰器

    一、什么是Python自动重试HTTP连接装饰器? Python自动重试HTTP连接装饰器即为一个能够在HTTP连接失败时自动重试的Python函数装饰器。该装饰器会在装饰的函数执行过程中,对HTTP请求返回的状态进行判断,并在需要时自动发起新的请求。这样,可以保证当HTTP连接出现故障时,程序不会因此而直接崩溃,而是能够进行自我修复,从而提高程序的稳定性和…

    python 2023年5月13日
    00
合作推广
合作推广
分享本页
返回顶部