pytorch + visdom CNN处理自建图片数据集的方法

yizhihongxing

对于使用PyTorch训练CNN的过程,一般情况下需要进行图片的预处理、数据集的加载,以及训练过程的可视化等步骤。其中,使用visdom进行训练过程的可视化非常方便,其支持的图形工具非常丰富。

下面,我们将围绕着“pytorch + visdom CNN处理自建图片数据集的方法”,从以下几个方面进行详细讲解。

1.数据集的准备

对于训练CNN所需的数据集,一般情况下需要进行以下几个步骤的准备:

(1)将数据集分为训练集、测试集,一般比例为8:2或7:3;

(2)将数据集中的图片进行缩放,使其大小统一,便于CNN模型的训练;

(3)将数据集中的图片进行标记,对每张图片给出对应的标签。

在PyTorch中,我们可以使用 torchvision.datasets.ImageFolder 对自定义数据集进行加载,它会自动为数据集中的每个样本生成标签,并将其存储在分类文件夹中。示例代码如下:

import torchvision.transforms as transforms
import torchvision.datasets as datasets

# 数据增强
data_transforms = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ])

# 数据集加载
image_datasets = datasets.ImageFolder(root='./data', transform=data_transforms)

# 数据集分割
train_size = int(0.8 * len(image_datasets))
test_size = len(image_datasets) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(image_datasets, [train_size, test_size])

2.CNN模型的定义

在PyTorch中定义CNN模型,可以直接继承 nn. Module 类,然后在子类中实现 init 和 forward 方法。其中,init 方法一般用来定义所需的层, forward 方法则用来实现前向传播。示例代码如下:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 29 * 29, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 29 * 29)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

3.训练过程的实现

在PyTorch中,我们可以使用 torch.optim 模块中的优化器对CNN模型进行训练。其中,常用的优化器包括 Adam(), SGD() 和 RMSprop() 等。示例代码如下:

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

4.训练过程的可视化

使用visdom库可以直观的显示训练过程中的Loss值、准确率等变量。示例如下:

import visdom

def plot_loss_acc(loss, acc):
    x = torch.arange(len(loss))
    y1 = loss
    y2 = acc
    assert len(x) == len(y1) == len(y2)
    y1_opts = dict(title='loss', xlabel='epoch', ylabel='Loss')
    y2_opts = dict(title='acc', xlabel='epoch', ylabel='Acc.')
    win = 'result'
    if not vis.win_exists(win):
        vis.line(X=x,Y=y1, win=win,opts=y1_opts, name='train-loss')
        vis.line(X=x,Y=y2, win=win,opts=y2_opts, name='train-accuracy')
    else:
        vis.line(X=x,Y=y1, win=win,opts=y1_opts, update='append', name='train-loss')
        vis.line(X=x,Y=y2, win=win,opts=y2_opts, update='append', name='train-accuracy')

整合以上四个步骤后,即可方便地使用PyTorch和visdom对自定数据集进行CNN模型训练、训练过程的可视化。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch + visdom CNN处理自建图片数据集的方法 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • Spring Boot + Thymeleaf + Activiti 快速开发平台项目 附源码

    下面就是Spring Boot + Thymeleaf + Activiti快速开发平台项目的完整攻略。 项目简介 该项目是一个使用Spring Boot和Thymeleaf作为前端模板引擎,Activiti作为工作流引擎的快速开发平台项目,通过该项目可以快速搭建企业级应用程序。项目的主要功能包括:用户登陆、用户管理、角色管理、菜单权限管理、部门管理、工作流…

    人工智能概览 2023年5月25日
    00
  • Django中外键使用总结

    那么我会针对“Django中外键使用总结”进行一个完整的攻略。 什么是外键? 在数据库中,一个表可能会有外键(foreign key),外键通常被用作表之间的关联。外键就是用来关联两张表的字段,关联关系的建立可以在数据库层面来实现,也可以在业务逻辑层面实现。 Django中的外键 在Django中,外键是一个非常重要的概念,它用于建立模型类之间的关联。在Dj…

    人工智能概论 2023年5月25日
    00
  • Django项目中添加ldap登陆认证功能的实现

    让我来详细解释“Django项目中添加LDAP登录认证功能的实现”的完整攻略。 一、什么是LDAP LDAP全称是Lightweight Directory Access Protocol,简称LDAP,它是一个客户端-服务器协议,用于访问一个目录服务。目录是一个关键的网络组件,它提供了一种将名称(如用户、组织、网络服务等)与资源(如文件、印表机等)联系在一…

    人工智能概览 2023年5月25日
    00
  • Qt生成随机数的方法

    生成随机数是很多计算机程序都需要的功能之一。在 Qt 中,我们可以通过以下几种方式来生成随机数: 1. 使用 Qt 提供的 QRandomGenerator 类 QRandomGenerator 类可以生成质量较高的随机数序列。它在 Qt 5.10 中引入,在 Qt 6 中成为标准类。我们可以通过 QRandomGenerator::global() 来获取…

    人工智能概览 2023年5月25日
    00
  • 关于服务网关Spring Cloud Zuul(Finchley版本)

    让我为您详细讲解一下关于服务网关Spring Cloud Zuul(Finchley版本)的攻略。 什么是Spring Cloud Zuul? Spring Cloud Zuul是一个基于Netflix的开源项目Zuul的API Gateway服务,用于微服务架构中的服务网关,为服务提供代理、路由、过滤、安全等功能。 安装Spring Cloud Zuul …

    人工智能概览 2023年5月25日
    00
  • Python实战之手势识别控制电脑音量

    Python实战之手势识别控制电脑音量 在本文中,我们将讲解如何使用Python实现手势识别控制电脑音量的功能。我们将会用到Python的OpenCV和MediaPipe库,以及PyAutoGUI模块。整个流程分为以下几个步骤: 安装必要的库和模块 使用摄像头捕获图像 调用MediaPipe的HandTracking模块进行手势识别 根据识别出的手势对电脑音…

    人工智能概览 2023年5月25日
    00
  • k8s之ingress-nginx详解和部署方案

    k8s之ingress-nginx详解和部署方案 介绍 Ingress是一个Kubernetes对象,用于管理和公开Kubernetes集群中服务的路由规则。 Ingress不会提供自己的实际负载均衡,相反,它需要一个后端负载均衡器来实现实际路由。 Nginx是一个流行的Web服务器和反向代理服务器。nginx-ingress-controller是一个开源…

    人工智能概览 2023年5月25日
    00
  • Android使用phonegap从相册里面获取照片(代码分享)

    以下是关于 “Android使用phonegap从相册里面获取照片(代码分享)”的完整攻略: 1. 什么是PhoneGap PhoneGap是一种移动端开发框架,它基于HTML、CSS、JavaScript和一些原生API的实现,针对不同的移动平台,在原生应用和web应用之间构建一座桥梁。通过PhoneGap,开发者可以用Web技术来开发适用于多个移动平台的…

    人工智能概论 2023年5月24日
    00
合作推广
合作推广
分享本页
返回顶部