PyTorch实现手写数字识别的示例代码

yizhihongxing

以下是“PyTorch实现手写数字识别的示例代码”的完整攻略,包含两个示例说明。

PyTorch实现手写数字识别的示例代码

手写数字识别是计算机视觉中的一个经典问题,它可以用于识别手写数字的图像。在PyTorch中,我们可以使用MNIST数据集来训练一个手写数字识别模型。下面是PyTorch实现手写数字识别的示例代码:

示例1:使用全连接层实现手写数字识别

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# 加载数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)

# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

# 创建模型实例
net = Net()

# 进行模型训练
for epoch in range(10):
    for i, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

    # 打印训练日志
    print('Epoch: %d, Loss: %.4f' % (epoch+1, loss.item()))

# 进行模型测试
correct = 0
total = 0
with torch.no_grad():
    for inputs, targets in test_loader:
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

# 打印测试结果
print('Accuracy: %.2f%%' % (100 * correct / total))

在这个示例中,我们首先加载了MNIST数据集,并使用DataLoader创建了数据加载器。然后,我们定义了一个包含两个全连接层的模型,并使用交叉熵损失函数和随机梯度下降优化器进行模型训练。最后,我们使用测试集对模型进行测试,并计算了模型的准确率。

示例2:使用卷积神经网络实现手写数字识别

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# 加载数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)

# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
        self.fc1 = nn.Linear(7*7*64, 1024)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(-1, 7*7*64)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

# 创建模型实例
net = Net()

# 进行模型训练
for epoch in range(10):
    for i, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

    # 打印训练日志
    print('Epoch: %d, Loss: %.4f' % (epoch+1, loss.item()))

# 进行模型测试
correct = 0
total = 0
with torch.no_grad():
    for inputs, targets in test_loader:
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

# 打印测试结果
print('Accuracy: %.2f%%' % (100 * correct / total))

在这个示例中,我们首先加载了MNIST数据集,并使用DataLoader创建了数据加载器。然后,我们定义了一个包含两个卷积层和两个全连接层的卷积神经网络,并使用交叉熵损失函数和随机梯度下降优化器进行模型训练。最后,我们使用测试集对模型进行测试,并计算了模型的准确率。

总结

本文介绍了PyTorch实现手写数字识别的示例代码,包括使用全连接层和卷积神经网络两种方法,并提供了两个示例说明。在实现过程中,我们使用了MNIST数据集,并使用DataLoader创建了数据加载器。然后,我们定义了模型、损失函数和优化器,并进行了模型训练和测试。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch实现手写数字识别的示例代码 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • pytorch seq2seq模型中加入teacher_forcing机制

    在循环内加的teacher forcing机制,这种为目标确定的时候,可以这样加。 目标不确定,需要在循环外加。 decoder.py 中的修改 “”” 实现解码器 “”” import torch.nn as nn import config import torch import torch.nn.functional as F import numpy…

    PyTorch 2023年4月8日
    00
  • pytorch 计算Parameter和FLOP的操作

    计算PyTorch模型参数和浮点操作(FLOP)是模型优化和性能调整的重要步骤。下面是关于如何计算PyTorch模型参数和FLOP的完整攻略: 计算模型参数 PyTorch中模型参数的数量是模型设计的基础部分。可以使用下面的代码计算PyTorch模型中的总参数数量: import torch.nn as nn def model_parameters(mod…

    PyTorch 2023年5月17日
    00
  • pytorch: cudnn.benchmark=True

    import torch.backends.cudnn as cudnn cudnn.benchmark = True 设置这个 flag 可以让内置的 cuDNN 的 auto-tuner 自动寻找最适合当前配置的高效算法,来达到优化运行效率的问题。如果网络的输入数据维度或类型上变化不大,也就是每次训练的图像尺寸都是一样的时候,设置 torch.backe…

    PyTorch 2023年4月8日
    00
  • PyTorch中Torch.arange函数详解

    在本文中,我们将介绍PyTorch中的torch.arange()函数。torch.arange()函数是一个用于创建等差数列的函数,可以方便地生成一组数字序列。本文将详细介绍torch.arange()函数的用法和示例。 torch.arange()函数的用法 torch.arange()函数的语法如下: torch.arange(start=0, end…

    PyTorch 2023年5月15日
    00
  • PyTorch中topk函数的用法详解

    PyTorch中topk函数的用法详解 在PyTorch中,topk函数是一种用于获取张量中最大值或最小值的函数。在本文中,我们将介绍PyTorch中topk函数的用法,并提供两个示例说明。 示例1:获取张量中最大的k个值 以下是一个获取张量中最大的k个值的示例代码: import torch # Create input tensor x = torch.…

    PyTorch 2023年5月16日
    00
  • Pytorch 入门之Siamese网络

    首次体验Pytorch,本文参考于:github and  PyTorch 中文网人脸相似度对比         本文主要熟悉Pytorch大致流程,修改了读取数据部分。没有采用原作者的ImageFolder方法:   ImageFolder(root, transform=None, target_transform=None, loader=defaul…

    2023年4月8日
    00
  • Pytorch划分数据集的方法:torch.utils.data.Subset

        Pytorch提供的对数据集进行操作的函数详见:https://pytorch.org/docs/master/data.html#torch.utils.data.SubsetRandomSampler torch的这个文件包含了一些关于数据集处理的类: class torch.utils.data.Dataset: 一个抽象类, 所有其他类的数据…

    PyTorch 2023年4月6日
    00
  • pytorch 多gpu训练

    pytorch 多gpu训练 用nn.DataParallel重新包装一下 数据并行有三种情况 前向过程 device_ids=[0, 1, 2] model = model.cuda(device_ids[0]) model = nn.DataParallel(model, device_ids=device_ids) 只要将model重新包装一下就可以。…

    PyTorch 2023年4月6日
    00
合作推广
合作推广
分享本页
返回顶部