PyTorch一小时掌握之图像识别实战篇

PyTorch一小时掌握之图像识别实战篇

本文将介绍如何使用PyTorch进行图像识别任务。我们将提供两个示例,分别是手写数字识别和猫狗分类。

手写数字识别

手写数字识别是一个经典的图像识别任务。以下是一个简单的手写数字识别示例:

import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms

# 定义超参数
input_size = 784
hidden_size = 500
num_classes = 10
num_epochs = 5
batch_size = 100
learning_rate = 0.001

# 加载MNIST数据集
train_dataset = dsets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = dsets.MNIST(root='./data', train=False, transform=transforms.ToTensor())

# 加载数据集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# 定义模型
model = nn.Sequential(
    nn.Linear(input_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, num_classes)
)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# 训练模型
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, 28*28)

        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))

# 测试模型
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, 28*28)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

在这个示例中,我们首先定义了超参数,然后加载了MNIST数据集。接下来,我们定义了一个名为model的模型,并定义了一个名为criterion的损失函数和一个名为optimizer的优化器。然后,我们使用训练数据对模型进行了训练,并在每个epoch结束时输出损失值。最后,我们使用测试数据对模型进行了测试,并输出了模型的准确率。

猫狗分类

猫狗分类是一个常见的图像识别任务。以下是一个简单的猫狗分类示例:

import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torchvision.models as models

# 定义超参数
num_epochs = 5
batch_size = 100
learning_rate = 0.001

# 加载数据集
train_dataset = dsets.ImageFolder(root='./data/train', transform=transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]))
test_dataset = dsets.ImageFolder(root='./data/test', transform=transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]))
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# 加载预训练模型
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# 训练模型
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))

# 测试模型
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))

在这个示例中,我们首先定义了超参数,然后加载了猫狗分类数据集。接下来,我们加载了一个预训练的ResNet18模型,并将其输出层替换为一个名为fc的线性层。然后,我们定义了一个名为criterion的损失函数和一个名为optimizer的优化器。然后,我们使用训练数据对模型进行了训练,并在每个epoch结束时输出损失值。最后,我们使用测试数据对模型进行了测试,并输出了模型的准确率。

结论

在本文中,我们介绍了如何使用PyTorch进行图像识别任务,并提供了两个示例说明。如果您遵循这些步骤和示例,您应该能够在PyTorch中实现手写数字识别和猫狗分类任务。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch一小时掌握之图像识别实战篇 - Python技术站

(0)
上一篇 2023年5月16日
下一篇 2023年5月16日

相关文章

  • Python利用Pytorch实现绘制ROC与PR曲线图

    当我们需要评估二分类模型的性能时,ROC曲线和PR曲线是两个常用的工具。在Python中,我们可以使用PyTorch库来绘制这些曲线。下面是绘制ROC曲线和PR曲线的完整攻略,包括两个示例说明。 1. 绘制ROC曲线 ROC曲线是一种用于评估二分类模型性能的工具,它显示了真阳性率(TPR)与假阳性率(FPR)之间的关系。以下是使用PyTorch绘制ROC曲线…

    PyTorch 2023年5月15日
    00
  • Pytorch:权重初始化方法

    pytorch在torch.nn.init中提供了常用的初始化方法函数,这里简单介绍,方便查询使用。 介绍分两部分: 1. Xavier,kaiming系列; 2. 其他方法分布   Xavier初始化方法,论文在《Understanding the difficulty of training deep feedforward neural network…

    PyTorch 2023年4月6日
    00
  • pytorch实现加载保存查看checkpoint文件

    在PyTorch中,我们可以使用checkpoint文件来保存和加载模型的状态。checkpoint文件包含了模型的权重、优化器的状态以及其他相关信息。在本文中,我们将详细介绍如何使用PyTorch来加载、保存和查看checkpoint文件。 加载checkpoint文件 在PyTorch中,我们可以使用torch.load函数来加载checkpoint文件…

    PyTorch 2023年5月15日
    00
  • 基于pytorch神经网络模型参数的加载及自定义

    最近在训练MobileNet时经常会对其模型参数进行各种操作,或者替换其中的几层之类的,故总结一下用到的对神经网络参数的各种操作方法。 1.将matlab的.mat格式参数整理转换为tensor类型的模型参数 import torch import torch.nn as nn import torch.nn.functional as F import s…

    PyTorch 2023年4月8日
    00
  • Pytorch释放显存占用方式

    下面是关于Pytorch如何释放显存占用的完整攻略,包含两条示例说明。 1. 使用with torch.no_grad()释放显存 在Pytorch中,通过with语句使用torch.no_grad()上下文管理器可以释放显存,这个操作对于训练中不需要梯度计算的代码非常有用。 代码示例: import torch # 创建一个3000 * 3000的矩阵 t…

    PyTorch 2023年5月17日
    00
  • pytorch 膨胀算法实现大眼效果

    PyTorch 膨胀算法实现大眼效果 膨胀算法是一种常用的图像处理算法,可以用于实现大眼效果。在本文中,我们将详细介绍如何使用 PyTorch 实现膨胀算法,并提供两个示例来说明其用法。 1. 膨胀算法的原理 膨胀算法是一种基于形态学的图像处理算法,它可以将图像中的物体膨胀或扩张。在实现大眼效果时,我们可以使用膨胀算法将眼睛的轮廓进行扩张,从而实现大眼效果。…

    PyTorch 2023年5月15日
    00
  • 关于Pytorch报警告:Warning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead

    在使用Pytorch的时候,遇到警告的日志打印: [W IndexingUtils.h:20] Warning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (function expandTensors)[W ..aten…

    2023年4月6日
    00
  • pytorch 多GPU 训练

    import osos.environ[‘CUDA_VISIBLE_DEVICES’] = ‘0, 1, 2’import torch   #注意以上两行先后顺序不可弄错   device = torch.device(‘cuda’) model = DataParallel(model)model.to(device)   这样模型就会在gpu 0, 1,…

    PyTorch 2023年4月7日
    00
合作推广
合作推广
分享本页
返回顶部