PyTorch实现手写数字识别的示例代码

以下是“PyTorch实现手写数字识别的示例代码”的完整攻略,包含两个示例说明。

PyTorch实现手写数字识别的示例代码

手写数字识别是计算机视觉中的一个经典问题,它可以用于识别手写数字的图像。在PyTorch中,我们可以使用MNIST数据集来训练一个手写数字识别模型。下面是PyTorch实现手写数字识别的示例代码:

示例1:使用全连接层实现手写数字识别

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# 加载数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)

# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

# 创建模型实例
net = Net()

# 进行模型训练
for epoch in range(10):
    for i, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

    # 打印训练日志
    print('Epoch: %d, Loss: %.4f' % (epoch+1, loss.item()))

# 进行模型测试
correct = 0
total = 0
with torch.no_grad():
    for inputs, targets in test_loader:
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

# 打印测试结果
print('Accuracy: %.2f%%' % (100 * correct / total))

在这个示例中,我们首先加载了MNIST数据集,并使用DataLoader创建了数据加载器。然后,我们定义了一个包含两个全连接层的模型,并使用交叉熵损失函数和随机梯度下降优化器进行模型训练。最后,我们使用测试集对模型进行测试,并计算了模型的准确率。

示例2:使用卷积神经网络实现手写数字识别

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# 加载数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)

# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
        self.fc1 = nn.Linear(7*7*64, 1024)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(-1, 7*7*64)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

# 创建模型实例
net = Net()

# 进行模型训练
for epoch in range(10):
    for i, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

    # 打印训练日志
    print('Epoch: %d, Loss: %.4f' % (epoch+1, loss.item()))

# 进行模型测试
correct = 0
total = 0
with torch.no_grad():
    for inputs, targets in test_loader:
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

# 打印测试结果
print('Accuracy: %.2f%%' % (100 * correct / total))

在这个示例中,我们首先加载了MNIST数据集,并使用DataLoader创建了数据加载器。然后,我们定义了一个包含两个卷积层和两个全连接层的卷积神经网络,并使用交叉熵损失函数和随机梯度下降优化器进行模型训练。最后,我们使用测试集对模型进行测试,并计算了模型的准确率。

总结

本文介绍了PyTorch实现手写数字识别的示例代码,包括使用全连接层和卷积神经网络两种方法,并提供了两个示例说明。在实现过程中,我们使用了MNIST数据集,并使用DataLoader创建了数据加载器。然后,我们定义了模型、损失函数和优化器,并进行了模型训练和测试。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch实现手写数字识别的示例代码 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • pytorch提取神经网络模型层结构和参数初始化

    torch.nn.Module()类有一些重要属性,我们可用其下面几个属性来实现对神经网络层结构的提取: torch.nn.Module.children() torch.nn.Module.modules() torch.nn.Module.named_children() torch.nn.Module.named_moduless() 为方面说明,我们…

    2023年4月8日
    00
  • Pytorch:单卡多进程并行训练

    在深度学习的项目中,我们进行单机多进程编程时一般不直接使用multiprocessing模块,而是使用其替代品torch.multiprocessing模块。它支持完全相同的操作,但对其进行了扩展。Python的multiprocessing模块可使用fork、spawn、forkserver三种方法来创建进程。但有一点需要注意的是,CUDA运行时不支持使用…

    2023年4月6日
    00
  • Pytorch关于Dataset 的数据处理

    PyTorch关于Dataset的数据处理 在PyTorch中,Dataset是一个抽象类,用于表示数据集。它提供了一种统一的方式来处理数据,使得我们可以轻松地加载和处理数据。在本文中,我们将详细介绍如何使用PyTorch中的Dataset类来处理数据,并提供两个示例来说明其用法。 1. 创建自定义Dataset 要创建自定义Dataset,需要继承PyTo…

    PyTorch 2023年5月15日
    00
  • PyTorch数据处理,datasets、DataLoader及其工具的使用

    torchvision是PyTorch的一个视觉工具包,提供了很多图像处理的工具。 datasets使用ImageFolder工具(默认PIL Image图像),获取定制化的图片并自动生成类别标签。如裁剪、旋转、标准化、归一化等(使用transforms工具)。 DataLoader可以把datasets数据集打乱,分成batch,并行加速等。 一、data…

    2023年4月8日
    00
  • pytorch中tensor张量数据基础入门

    pytorch张量数据类型入门1、对于pytorch的深度学习框架,其基本的数据类型属于张量数据类型,即Tensor数据类型,对于python里面的int,float,int array,flaot array对应于pytorch里面即在前面加一个Tensor即可——intTensor ,Float tensor,IntTensor of size [d1,…

    2023年4月8日
    00
  • pytorch中的model.eval()和BN层的使用

    PyTorch中的model.eval()和BN层的使用 在深度学习中,模型的训练和测试是两个不同的过程。在测试过程中,我们需要使用model.eval()函数来将模型设置为评估模式。此外,批量归一化(Batch Normalization,BN)层是一种常用的技术,可以加速模型的训练过程。本文将提供一个完整的攻略,介绍如何使用PyTorch中的model.…

    PyTorch 2023年5月15日
    00
  • 【笔记】PyTorch快速入门:基础部分合集

    一天时间快速上手PyTorch PyTorch快速入门 Tensors Tensors贯穿PyTorch始终 和多维数组很相似,一个特点是可以硬件加速 Tensors的初始化 有很多方式 直接给值 data = [[1,2],[3,4]] x_data = torch.tensor(data) 从NumPy数组转来 np_arr = np.array(dat…

    2023年4月8日
    00
  • pytorch中:使用bert预训练模型进行中文语料任务,bert-base-chinese下载。

    1.网址:https://huggingface.co/bert-base-chinese?text=%E5%AE%89%E5%80%8D%E6%98%AF%E5%8F%AA%5BMASK%5D%E7%8B%97 2.下载: 下载 在这里插入图片描述

    PyTorch 2023年4月6日
    00
合作推广
合作推广
分享本页
返回顶部