pytorch cnn 识别手写的字实现自建图片数据

yizhihongxing

下面是详细的攻略:

简介

卷积神经网络(CNN)是一种在图像识别、语音识别和自然语言处理等领域广泛应用的深度学习算法。本文将介绍如何使用PyTorch实现一个CNN模型来识别手写字,并且展示如何通过自建图片数据进行训练和测试。

步骤

1. 准备自建图片数据

首先需要准备自建图片数据作为我们的训练集和测试集。这些图片应该是手写的数字,且需要分类为数字0到9的10个类别。每个类别应该包含足够数量的图片,以便模型可以充分学习区分不同数字的特征。

可以使用任何图片编辑软件来创建这些手写数字图片,例如Paint或GIMP。注意要将图片大小和分辨率保持一致,以便后续处理。

将这些图片按照不同数字分类,分别存储到对应文件夹内。例如,所有数字0的图片可以存储在名为“0”的文件夹中,所有数字1的图片可以存储在名为“1”的文件夹中,以此类推。

2. 加载和转换自建图片数据

使用torchvision库中的ImageFolder来加载自建的图片数据,该函数会自动将每个文件夹内的图片视为同一类别。可以使用transform来对图片进行预处理,例如缩放和剪裁。

下面是一个示例:

import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

data_transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor()
])

train_dataset = ImageFolder('./train', transform=data_transform)
test_dataset = ImageFolder('./test', transform=data_transform)

这里使用了Resize将图片尺寸调整为28×28像素,ToTensor将图片转换为PyTorch中的张量。

3. 创建CNN模型

接下来需要创建一个CNN模型。可以使用PyTorch中的nn.Module类来构建模型。这里可以简单地使用两个卷积层和一个全连接层,具体如下:

import torch.nn as nn

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 64 * 7 * 7)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

这个模型包含两个卷积层和一个全连接层。第一个卷积层有32个输出通道,第二个卷积层有64个输出通道。每个卷积层都使用大小为3×3的卷积核,每个像素的步幅为1,并使用1个像素的填充。全连接层有128个隐藏单元,最终输出10个类别的概率。

4. 训练模型

现在可以开始训练模型了。首先需要定义损失函数和优化器。这里使用交叉熵作为损失函数,使用随机梯度下降作为优化器。

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

然后可以使用PyTorch中的DataLoader来加载自建图片数据,并使用上述定义的损失函数和优化器来训练模型。训练模型的代码如下:

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)

model = CNNModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:
            print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, running_loss/100))
            running_loss = 0.0

训练模型需要遍历所有训练数据多次,每次遍历称为一个epoch。这里共训练了10个epoch。

5. 测试模型

训练完模型后,可以使用测试集来评估模型的性能。这里使用PyTorch中的accuracy_score函数来计算模型在测试集上的准确率。

from sklearn.metrics import accuracy_score

correct = 0
total = 0

with torch.no_grad():
    for data in test_loader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))

最后,可以使用模型对单个手写数字图片进行预测。预测过程包括两个步骤:首先将图片转换为PyTorch张量,然后使用训练好的模型对该张量进行预测:

from PIL import Image

img = Image.open("5.png")
img = data_transform(img).unsqueeze(0)

outputs = model(img)
_, predicted = torch.max(outputs.data, 1)

print('The predicted digit is:', predicted[0])

这里使用了PIL库来打开单个图片,并使用上述定义的data_transform函数将其转换为PyTorch张量。需要注意,使用unsqueeze将单个张量转换为批量大小为1的张量,在预测时需要用到。

示例说明

下面给出两个示例,分别演示如何创建CNN模型和如何对单个手写数字进行预测。

示例1:创建CNN模型

import torch.nn as nn

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 64 * 7 * 7)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

该模型包含两个卷积层和一个全连接层。第一个卷积层有32个输出通道,第二个卷积层有64个输出通道。每个卷积层都使用大小为3×3的卷积核,每个像素的步幅为1,并使用1个像素的填充。全连接层有128个隐藏单元,最终输出10个类别的概率。

示例2:对单个手写数字进行预测

from PIL import Image

img = Image.open("5.png")
img = data_transform(img).unsqueeze(0)

outputs = model(img)
_, predicted = torch.max(outputs.data, 1)

print('The predicted digit is:', predicted[0])

这个示例使用了PIL库打开了一个名为“5.png”的图片,并使用上述定义的data_transform函数将其转换为PyTorch张量。将该张量转换为批量大小为1的张量,并使用已训练好的CNN模型对其进行预测,最终输出预测的数字。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch cnn 识别手写的字实现自建图片数据 - Python技术站

(0)
上一篇 2023年6月6日
下一篇 2023年6月6日

相关文章

  • 如何使用Python进行爬虫开发?

    使用Python进行爬虫开发需要以下步骤: 安装Python和相应的第三方库(比如requests和beautifulsoup4) 选择目标网站,并使用requests库发送GET请求获取HTML页面 使用beautifulsoup4库解析HTML页面,提取需要的信息 将提取的信息存储到本地文件或数据库中 以下是两个示例说明: 示例1:爬取新闻网站的标题和链…

    python 2023年4月19日
    00
  • 浅谈Python中的bs4基础

    浅谈Python中的bs4基础 Python中的bs4是一个强大的HTML和XML解析库,可以帮助我们更好地解析网页和XML文档。本文将介绍bs4的基础知识和使用方法。 安装bs4 在使用bs4之前,需要先安装bs4库。可以使用pip命令进行安装: pip install beautifulsoup4 解析HTML文档 以下是一个示例代码,演示如何使用bs4…

    python 2023年5月15日
    00
  • python实现抽奖小程序

    下面是Python实现抽奖小程序的完整攻略: 需求分析 在开始编写程序前,我们需要先明确需求。这个抽奖程序需要实现以下功能:1. 输入参与抽奖人员名单2. 从名单中随机选取若干个人作为获奖者3. 输出获奖者名单 方案设计 知道了需求,我们就可以开始设计实现方案了。为了实现这个抽奖小程序,我们可以采用以下方案:1. 通过Python内置的random模块实现随…

    python 2023年5月23日
    00
  • OpenCV 绘制同心圆的示例代码

    绘制同心圆是计算机视觉中常见的任务,可以使用OpenCV通过简单的代码实现。以下是绘制两个同心圆的示例代码: import cv2 # 创建一个黑色的图像 img = np.zeros((512,512,3), np.uint8) # 确定两个圆的中心坐标与半径 center1 = (256, 256) radius1 = 100 center2 = (25…

    python 2023年5月18日
    00
  • Python文件简单操作及openpyxl操作excel文件详解

    Python文件简单操作及openpyxl操作excel文件详解 Python文件简单操作 文件的打开和关闭 使用open()函数可以打开指定的文件,该函数包含两个参数:第一个参数是文件的路径,第二个参数是文件的打开模式。常见的文件打开模式如下: r: 以只读方式打开文件 w: 以写入方式打开文件,如果文件不存在则创建文件,如果文件已存在则覆盖文件内容 a:…

    python 2023年6月3日
    00
  • Python 添加文件注释和函数注释操作

    添加文件注释和函数注释是Python编程中非常重要的一项操作,能够为开发者提供更好的代码可读性和维护性。下面将提供完整的攻略,帮助你了解如何在Python中添加文件注释和函数注释。 Python添加文件注释操作 在Python文件的开头,使用三个双引号或单引号来添加多行注释。以下是添加文件注释的示例代码: """ 这是一个Pyt…

    python 2023年6月5日
    00
  • python之pyinstaller组件打包命令和异常解析实战

    Python是一门非常流行的高级编程语言,而PyInstaller则是Python中一款常用的打包工具,可以将Python程序转换为可执行文件,以便在其他计算机上运行,而无需安装Python解释器环境。在实际使用中,PyInstaller打包命令和异常解析对我们来说是非常重要的。下面我们来详细讲解如何使用PyInstaller进行打包和解析异常。 PyIns…

    python 2023年5月13日
    00
  • Python实现修改文件内容的方法分析

    Python实现修改文件内容的方法分析 在Python中,可以利用内置的open函数和文件读写模式来实现对文件内容的修改,常见的做法有以下几种。 方法一:利用with语句和文件对象的write方法 with open(‘file.txt’,’r+’) as f: text = f.read() text = text.replace(‘old’, ‘new’…

    python 2023年6月3日
    00
合作推广
合作推广
分享本页
返回顶部