pytorch cnn 识别手写的字实现自建图片数据

下面是详细的攻略:

简介

卷积神经网络(CNN)是一种在图像识别、语音识别和自然语言处理等领域广泛应用的深度学习算法。本文将介绍如何使用PyTorch实现一个CNN模型来识别手写字,并且展示如何通过自建图片数据进行训练和测试。

步骤

1. 准备自建图片数据

首先需要准备自建图片数据作为我们的训练集和测试集。这些图片应该是手写的数字,且需要分类为数字0到9的10个类别。每个类别应该包含足够数量的图片,以便模型可以充分学习区分不同数字的特征。

可以使用任何图片编辑软件来创建这些手写数字图片,例如Paint或GIMP。注意要将图片大小和分辨率保持一致,以便后续处理。

将这些图片按照不同数字分类,分别存储到对应文件夹内。例如,所有数字0的图片可以存储在名为“0”的文件夹中,所有数字1的图片可以存储在名为“1”的文件夹中,以此类推。

2. 加载和转换自建图片数据

使用torchvision库中的ImageFolder来加载自建的图片数据,该函数会自动将每个文件夹内的图片视为同一类别。可以使用transform来对图片进行预处理,例如缩放和剪裁。

下面是一个示例:

import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

data_transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor()
])

train_dataset = ImageFolder('./train', transform=data_transform)
test_dataset = ImageFolder('./test', transform=data_transform)

这里使用了Resize将图片尺寸调整为28×28像素,ToTensor将图片转换为PyTorch中的张量。

3. 创建CNN模型

接下来需要创建一个CNN模型。可以使用PyTorch中的nn.Module类来构建模型。这里可以简单地使用两个卷积层和一个全连接层,具体如下:

import torch.nn as nn

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 64 * 7 * 7)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

这个模型包含两个卷积层和一个全连接层。第一个卷积层有32个输出通道,第二个卷积层有64个输出通道。每个卷积层都使用大小为3×3的卷积核,每个像素的步幅为1,并使用1个像素的填充。全连接层有128个隐藏单元,最终输出10个类别的概率。

4. 训练模型

现在可以开始训练模型了。首先需要定义损失函数和优化器。这里使用交叉熵作为损失函数,使用随机梯度下降作为优化器。

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

然后可以使用PyTorch中的DataLoader来加载自建图片数据,并使用上述定义的损失函数和优化器来训练模型。训练模型的代码如下:

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)

model = CNNModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:
            print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, running_loss/100))
            running_loss = 0.0

训练模型需要遍历所有训练数据多次,每次遍历称为一个epoch。这里共训练了10个epoch。

5. 测试模型

训练完模型后,可以使用测试集来评估模型的性能。这里使用PyTorch中的accuracy_score函数来计算模型在测试集上的准确率。

from sklearn.metrics import accuracy_score

correct = 0
total = 0

with torch.no_grad():
    for data in test_loader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))

最后,可以使用模型对单个手写数字图片进行预测。预测过程包括两个步骤:首先将图片转换为PyTorch张量,然后使用训练好的模型对该张量进行预测:

from PIL import Image

img = Image.open("5.png")
img = data_transform(img).unsqueeze(0)

outputs = model(img)
_, predicted = torch.max(outputs.data, 1)

print('The predicted digit is:', predicted[0])

这里使用了PIL库来打开单个图片,并使用上述定义的data_transform函数将其转换为PyTorch张量。需要注意,使用unsqueeze将单个张量转换为批量大小为1的张量,在预测时需要用到。

示例说明

下面给出两个示例,分别演示如何创建CNN模型和如何对单个手写数字进行预测。

示例1:创建CNN模型

import torch.nn as nn

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 64 * 7 * 7)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

该模型包含两个卷积层和一个全连接层。第一个卷积层有32个输出通道,第二个卷积层有64个输出通道。每个卷积层都使用大小为3×3的卷积核,每个像素的步幅为1,并使用1个像素的填充。全连接层有128个隐藏单元,最终输出10个类别的概率。

示例2:对单个手写数字进行预测

from PIL import Image

img = Image.open("5.png")
img = data_transform(img).unsqueeze(0)

outputs = model(img)
_, predicted = torch.max(outputs.data, 1)

print('The predicted digit is:', predicted[0])

这个示例使用了PIL库打开了一个名为“5.png”的图片,并使用上述定义的data_transform函数将其转换为PyTorch张量。将该张量转换为批量大小为1的张量,并使用已训练好的CNN模型对其进行预测,最终输出预测的数字。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch cnn 识别手写的字实现自建图片数据 - Python技术站

(0)
上一篇 2023年6月6日
下一篇 2023年6月6日

相关文章

  • python如何将一个四位数反向输出

    确切说法是“如何反向输出一个四位数的数字”,下面是操作步骤。 将要翻转的数字转换成字符串。 num = 1234 str_num = str(num) 使用字符串的切片操作与步长来实现反转。 reverse_str_num = str_num[::-1] 这里的[::-1]表示从字符串结尾到开头,步长为-1,即倒序输出。 将反转后的字符串转回数字类型。 re…

    python 2023年6月5日
    00
  • Python字符串的转义字符

    Python字符串是由多个字符组成的数据类型,字符串中的字符可以使用单引号、双引号或者三重引号括起来。在Python字符串中,可以使用转义字符来表示一些特殊的字符或字符序列,例如换行符、制表符等。 下面是一些常用的Python字符串转义字符及其含义: \n:表示换行符; \t:表示制表符; \’: 表示单引号; \”: 表示双引号; \:表示反斜杠。 在Py…

    python 2023年6月5日
    00
  • 详解Python自动化之文件自动化处理

    详解Python自动化之文件自动化处理 本文将讲解利用Python进行文件自动化处理的完整攻略,包含以下几个步骤: 控制文件路径 文件读写操作 批量操作文件 文件重命名 文件复制与移动 文件压缩 以下将详细讲解每个步骤。 1. 控制文件路径 在Python中,我们可以使用os模块来控制文件路径。该模块提供了一些用于处理文件路径的函数,如获取当前工作目录os.…

    python 2023年5月19日
    00
  • 如何使用Python将一个CSV文件中的数据导入到数据库中?

    以下是如何使用Python将一个CSV文件中的数据导入到数据库中的完整使用攻略。 使用Python将一个CSV文件中的数据导入到数据库中的前提条件 在Python将一个CSV文件中的数据导入到数据库中前,需要确保已经安装并启动了支持导入数据的数据库,例如MySQL或PostgreSQL,并且需要安装Python的相应数据库驱动程序,例如mysql-conne…

    python 2023年5月12日
    00
  • python实现串口自动触发工作的示例

    下面是“python实现串口自动触发工作的示例”的完整攻略。 1. 前置条件 在进行串口自动触发工作之前,你需要先了解操作系统中串口的基本知识,并且需要安装相应的串口模拟器软件。在这里以windows操作系统为例,推荐使用PuTTY和Realterm两款软件。 2. 实现步骤 2.1 安装相关模块 在python中实现串口通讯,我们需要使用到pyserial…

    python 2023年5月19日
    00
  • python实现定制交互式命令行的方法

    实现定制交互式命令行,可以使用Python标准库中的cmd模块。下面是该过程的完整攻略: 步骤一:创建一个命令行解析器类 导入cmd模块 创建一个继承自cmd.Cmd的类,该类将作为命令行解析器 在该类中重写欢迎信息、提示符和默认的帮助信息的方法 示例代码: import cmd class MyCmd(cmd.Cmd): # 定义欢迎信息 def do_h…

    python 2023年6月2日
    00
  • Python – 选择出现在第二个数据框中的数据框中的行

    【问题标题】:Python – Select lines in dataframe that appear in a second data framePython – 选择出现在第二个数据框中的数据框中的行 【发布时间】:2023-04-02 11:24:01 【问题描述】: 我有两个 Pandas 数据框,列数相同,行数不同。 dfA = pd.Data…

    Python开发 2023年4月8日
    00
  • Python中的chr()函数与ord()函数解析

    Python中的chr()函数与ord()函数解析 chr()函数 在 Python 中,chr() 函数用于将 Unicode 码点转换为字符。Unicode 码点是一个整数,用于表示字符的独特标识符。此函数的语法为: chr(i) 其中 i 代表 Unicode 码点。 下面是一个示例,演示了如何使用 chr() 函数将 Unicode 码点转换为字符:…

    python 2023年5月31日
    00
合作推广
合作推广
分享本页
返回顶部