pytorch + visdom CNN处理自建图片数据集的方法

对于使用PyTorch训练CNN的过程,一般情况下需要进行图片的预处理、数据集的加载,以及训练过程的可视化等步骤。其中,使用visdom进行训练过程的可视化非常方便,其支持的图形工具非常丰富。

下面,我们将围绕着“pytorch + visdom CNN处理自建图片数据集的方法”,从以下几个方面进行详细讲解。

1.数据集的准备

对于训练CNN所需的数据集,一般情况下需要进行以下几个步骤的准备:

(1)将数据集分为训练集、测试集,一般比例为8:2或7:3;

(2)将数据集中的图片进行缩放,使其大小统一,便于CNN模型的训练;

(3)将数据集中的图片进行标记,对每张图片给出对应的标签。

在PyTorch中,我们可以使用 torchvision.datasets.ImageFolder 对自定义数据集进行加载,它会自动为数据集中的每个样本生成标签,并将其存储在分类文件夹中。示例代码如下:

import torchvision.transforms as transforms
import torchvision.datasets as datasets

# 数据增强
data_transforms = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ])

# 数据集加载
image_datasets = datasets.ImageFolder(root='./data', transform=data_transforms)

# 数据集分割
train_size = int(0.8 * len(image_datasets))
test_size = len(image_datasets) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(image_datasets, [train_size, test_size])

2.CNN模型的定义

在PyTorch中定义CNN模型,可以直接继承 nn. Module 类,然后在子类中实现 init 和 forward 方法。其中,init 方法一般用来定义所需的层, forward 方法则用来实现前向传播。示例代码如下:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 29 * 29, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 29 * 29)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

3.训练过程的实现

在PyTorch中,我们可以使用 torch.optim 模块中的优化器对CNN模型进行训练。其中,常用的优化器包括 Adam(), SGD() 和 RMSprop() 等。示例代码如下:

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

4.训练过程的可视化

使用visdom库可以直观的显示训练过程中的Loss值、准确率等变量。示例如下:

import visdom

def plot_loss_acc(loss, acc):
    x = torch.arange(len(loss))
    y1 = loss
    y2 = acc
    assert len(x) == len(y1) == len(y2)
    y1_opts = dict(title='loss', xlabel='epoch', ylabel='Loss')
    y2_opts = dict(title='acc', xlabel='epoch', ylabel='Acc.')
    win = 'result'
    if not vis.win_exists(win):
        vis.line(X=x,Y=y1, win=win,opts=y1_opts, name='train-loss')
        vis.line(X=x,Y=y2, win=win,opts=y2_opts, name='train-accuracy')
    else:
        vis.line(X=x,Y=y1, win=win,opts=y1_opts, update='append', name='train-loss')
        vis.line(X=x,Y=y2, win=win,opts=y2_opts, update='append', name='train-accuracy')

整合以上四个步骤后,即可方便地使用PyTorch和visdom对自定数据集进行CNN模型训练、训练过程的可视化。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch + visdom CNN处理自建图片数据集的方法 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • Yii2框架中一些折磨人的坑

    下面我就来详细讲解Yii2框架中一些折磨人的坑和解决方案。 一、数据库操作中的坑 1.1 坑:使用Query对象时,忘记使用createCommand方法生成实际的SQL语句 在Yii2框架中,我们可以使用Query对象来构建和执行SQL语句。但是,在使用Query对象时,需要注意生成实际的SQL语句时需要使用createCommand方法。如果忘记了使用c…

    人工智能概论 2023年5月25日
    00
  • 易语言修改指定网页为浏览器主页的代码

    以下是详细讲解“易语言修改指定网页为浏览器主页的代码”的完整攻略。 1. 确认浏览器主页的配置文件路径 首先,我们需要确认浏览器主页的配置文件路径。以Chrome为例,Windows系统下Chrome的主页配置文件存放在C:\Users\{user}\AppData\Local\Google\Chrome\User Data\Default\Preferen…

    人工智能概论 2023年5月25日
    00
  • Nginx部署vue项目和配置代理的问题解析

    下面就是Nginx部署Vue项目的完整攻略,包括如何配置代理。 1. 准备工作 在开始部署Vue项目之前,首先需要安装和配置好Nginx,以及确保Vue项目的构建已经完成,生成了静态文件。 2. 部署Vue项目 2.1 将Vue项目的静态文件放入Nginx的服务目录中 假设Vue项目的静态文件都在dist目录下,将此目录拷贝到Nginx的服务目录下,比如在U…

    人工智能概览 2023年5月25日
    00
  • 捷速OCR文字识别如何把PDF转为txt?捷速OCR文字识别把PDF转为txt教程

    这里介绍使用捷速OCR文字识别工具将PDF文件转换为txt简单易学的教程。 步骤一:准备工作 首先,我们需要下载并安装捷速OCR文字识别工具,安装完成后,打开软件。 步骤二:导入PDF文件 在捷速OCR文字识别软件中,我们需要导入PDF文件。在“OCR文字识别”界面,选择“导入”按钮,然后选择需要转换的PDF文件。 步骤三:选择转换类型和语言 选择需要转换的…

    人工智能概览 2023年5月25日
    00
  • 如何优雅的在一台vps(云主机)上面部署vue+mongodb+express项目

    下面我将为你详细讲解如何优雅地在一台vps上面部署vue+mongodb+express项目的完整攻略。 确认vps环境 首先,需要确认你购买的vps已经安装好了Node.js和MongoDB。如果没有安装,需要先安装它们。具体安装方式可参考 Node.js 和 MongoDB 的官方文档。 部署Vue项目 在vps上创建一个专门存放代码的文件夹,例如/ho…

    人工智能概论 2023年5月25日
    00
  • Centos6.4 编译安装 nginx php的方法

    Centos6.4 编译安装 Nginx + PHP 的方法 本文主要讲解如何在 CentOS 6.4 系统上,使用源码编译的方式安装 Nginx 和 PHP,以便于自定义编译选项和版本。下面是具体的操作步骤。 1. 安装编译环境 在编译 Nginx 和 PHP 之前,需要先安装编译环境。 $ yum install -y gcc gcc-c++ make …

    人工智能概览 2023年5月25日
    00
  • CentOS6.3添加nginx系统服务的实例详解

    CentOS6.3添加nginx系统服务的实例详解 问题描述 在安装完CentOS6.3以及nginx服务器后,如何将nginx服务加入系统服务,实现系统启动时自启动nginx服务? 解决方法 第一步:创建nginx服务管理脚本 在CentOS系统中,使用init.d脚本管理系统服务。因此,我们需要创建一个nginx服务管理脚本,将其放入/etc/init.…

    人工智能概览 2023年5月25日
    00
  • IOS开发之由身份证号码提取性别的实现代码

    下面我将为大家介绍IOS开发中如何通过提取身份证号码中的信息来获取性别的实现代码攻略。 步骤一:获取身份证号码 在IOS中我们需要通过UI控件来获取用户输入的身份证号码,这里以UITextfield为例: @IBOutlet weak var idNumberInputField: UITextField! let idNumber = idNumberIn…

    人工智能概论 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部