PyTorch实现手写数字识别的示例代码

以下是“PyTorch实现手写数字识别的示例代码”的完整攻略,包含两个示例说明。

PyTorch实现手写数字识别的示例代码

手写数字识别是计算机视觉中的一个经典问题,它可以用于识别手写数字的图像。在PyTorch中,我们可以使用MNIST数据集来训练一个手写数字识别模型。下面是PyTorch实现手写数字识别的示例代码:

示例1:使用全连接层实现手写数字识别

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# 加载数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)

# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

# 创建模型实例
net = Net()

# 进行模型训练
for epoch in range(10):
    for i, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

    # 打印训练日志
    print('Epoch: %d, Loss: %.4f' % (epoch+1, loss.item()))

# 进行模型测试
correct = 0
total = 0
with torch.no_grad():
    for inputs, targets in test_loader:
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

# 打印测试结果
print('Accuracy: %.2f%%' % (100 * correct / total))

在这个示例中,我们首先加载了MNIST数据集,并使用DataLoader创建了数据加载器。然后,我们定义了一个包含两个全连接层的模型,并使用交叉熵损失函数和随机梯度下降优化器进行模型训练。最后,我们使用测试集对模型进行测试,并计算了模型的准确率。

示例2:使用卷积神经网络实现手写数字识别

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# 加载数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)

# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
        self.fc1 = nn.Linear(7*7*64, 1024)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(-1, 7*7*64)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

# 创建模型实例
net = Net()

# 进行模型训练
for epoch in range(10):
    for i, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

    # 打印训练日志
    print('Epoch: %d, Loss: %.4f' % (epoch+1, loss.item()))

# 进行模型测试
correct = 0
total = 0
with torch.no_grad():
    for inputs, targets in test_loader:
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

# 打印测试结果
print('Accuracy: %.2f%%' % (100 * correct / total))

在这个示例中,我们首先加载了MNIST数据集,并使用DataLoader创建了数据加载器。然后,我们定义了一个包含两个卷积层和两个全连接层的卷积神经网络,并使用交叉熵损失函数和随机梯度下降优化器进行模型训练。最后,我们使用测试集对模型进行测试,并计算了模型的准确率。

总结

本文介绍了PyTorch实现手写数字识别的示例代码,包括使用全连接层和卷积神经网络两种方法,并提供了两个示例说明。在实现过程中,我们使用了MNIST数据集,并使用DataLoader创建了数据加载器。然后,我们定义了模型、损失函数和优化器,并进行了模型训练和测试。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch实现手写数字识别的示例代码 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • python调用pytorch实现deeplabv3+图像语义分割——以分割动漫人物为例

    图像语义分割就是把图像分成若干个特定的、具有独特性质的区域并提出感兴趣目标的技术和过程。本文提供了一个可进行自定义数据集训练基于pytorch的deeplabv3+图像分割模型的方法,训练了一个动漫人物分割模型,不过数据集较小,仅供学习使用 程序输入:动漫图片 程序输出:分割好的动漫人物图片 目录 程序简介 程序/数据集下载 数据集准备 训练步骤 预测演示步…

    2023年4月8日
    00
  • pytorch动态网络以及权重共享实例

    以下是关于“PyTorch 动态网络以及权重共享实例”的完整攻略,其中包含两个示例说明。 示例1:动态网络 步骤1:导入必要库 在定义动态网络之前,我们需要导入一些必要的库,包括torch。 import torch 步骤2:定义动态网络 在这个示例中,我们使用动态网络来演示如何定义动态网络。 # 定义动态网络 class DynamicNet(torch.…

    PyTorch 2023年5月16日
    00
  • Python实现softmax反向传播的示例代码

    Python实现softmax反向传播的示例代码 softmax函数是一种常用的激活函数,它可以将输入转换为概率分布。在神经网络中,softmax函数通常用于多分类问题。本文将提供一个完整的攻略,介绍如何使用Python实现softmax反向传播。我们将提供两个示例,分别是使用softmax反向传播进行多分类和使用softmax反向传播进行图像分类。 sof…

    PyTorch 2023年5月15日
    00
  • pytorch in vscode (Module ‘xx’ has no ‘xx’ member pylint(no-member))

    在VSCode setting中搜索python.linting.pylintPath改为pylint的路径,如/home/xxx/.local/lib/python3.5/site-packages/pylint

    PyTorch 2023年4月6日
    00
  • Win10操作系统中PyTorch虚拟环境配置+PyCharm配置

    Win10操作系统中PyTorch虚拟环境配置+PyCharm配置 在使用PyTorch进行深度学习开发时,我们通常需要搭建一个适合自己的开发环境。本文将介绍如何在Win10操作系统中配置PyTorch虚拟环境,并使用PyCharm进行开发,并演示两个示例。 示例一:使用Anaconda创建PyTorch虚拟环境 下载并安装Anaconda:从Anacond…

    PyTorch 2023年5月15日
    00
  • PyTorch LSTM,batch_first=True对初始化h0和c0的影响

    batch_first=True会对LSTM的输入输出的维度顺序有影响,但是对初始化h0和c0的维度顺序没有影响,也就是说,不管batch_first=True还是False,h0和c0的维度顺序都是:     关于LSTM的输入输出,可参考这篇博客。  

    2023年4月7日
    00
  • pytorch实践:dog VS cat

    猫狗分类,练手级代码,与手写数字识别相比,主要修改的地方是输出全连接层,将输出通道由10(十个数字)改成2(猫狗二分类)。还有一个是对数据集处理,因pytorch没有内置数据集函数,因此图片要自己处理。 数据要用opencv处理,归一化。 数据集:data __train__Cat       |     |__Dog       |__test__Cat …

    PyTorch 2023年4月8日
    00
  • pytorch之维度变化view/reshape;squeeze/unsqueeze;Transpose/permute;Expand/repeat

    ————恢复内容开始———— 概括:      一. view/reshape      作用几乎一模一样,保证size不变:意思就是各维度相乘之积相等(numel()),且具有物理意义,别瞎变,要不然破坏数据污染数据;     数据的存储、维度顺序非常重要,需要时刻记住            size没有保持固定住,报错  …

    PyTorch 2023年4月7日
    00
合作推广
合作推广
分享本页
返回顶部