pytorch 6 batch_train 批训练操作

下面是关于pytorch 6 batch_train 批训练的完整攻略。

什么是批训练操作

在深度学习中,一般将训练数据分成一个个的batch,每个batch都可以看做是一个小的数据集。在批训练操作中,模型将对每个batch进行一次前向传播和反向传播,在更新梯度的过程中,使用所有batch的梯度的平均值。这样可以有效地加速训练进程,减小了内存占用和梯度更新的波动。

在pytorch中,批训练操作可以通过使用torch.utils.data.DataLoadertorch.optim.SGD两个类实现。

如何进行批训练操作

1. 导入相关库

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

2. 准备数据

使用torch.utils.data.Dataset类创建数据集,并使用torch.utils.data.DataLoader类创建数据加载器。

class MyDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.x = torch.randn(100, 10)
        self.y = torch.randint(2, (100,))

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)

dataset = MyDataset()
loader = DataLoader(dataset=dataset, batch_size=10, shuffle=True)

3. 准备模型和损失函数

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = MyModel()
criterion = nn.CrossEntropyLoss()

4. 准备优化器

optimizer = optim.SGD(model.parameters(), lr=0.1)

5. 训练模型

for epoch in range(num_epochs):
    for i, (inputs, labels) in enumerate(loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i+1) % log_step == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(loader)}], Loss: {loss.item():.4f}')

批训练操作的示例

示例1:使用MNIST数据集进行分类

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

test_dataset = MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(32)
        self.fc1 = nn.Linear(32 * 4 * 4, 10)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = nn.functional.max_pool2d(x, 2)
        x = torch.relu(self.bn2(self.conv2(x)))
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 32 * 4 * 4)
        x = self.fc1(x)
        return x

model = MyModel()
criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

num_epochs = 10
log_step = 100

for epoch in range(num_epochs):
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i+1) % log_step == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

    # 每个epoch结束后,对模型进行一个测试
    with torch.no_grad():
        total = 0
        correct = 0
        for inputs, labels in test_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
        print(f'Test Accuracy: {accuracy:.2f}%')

示例2:使用CIFAR10数据集进行分类

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms
from torch.utils.data import DataLoader

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = CIFAR10(root='./data', train=True, transform=transform_train, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)

test_dataset = CIFAR10(root='./data', train=False, transform=transform_test, download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.fc1 = nn.Linear(512 * 4 * 4, 512)
        self.bn5 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = nn.functional.max_pool2d(x, 2)
        x = torch.relu(self.bn2(self.conv2(x)))
        x = nn.functional.max_pool2d(x, 2)
        x = torch.relu(self.bn3(self.conv3(x)))
        x = nn.functional.max_pool2d(x, 2)
        x = torch.relu(self.bn4(self.conv4(x)))
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 512 * 4 * 4)
        x = torch.relu(self.bn5(self.fc1(x)))
        x = self.fc2(x)
        return x

model = MyModel()
criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

num_epochs = 10
log_step = 100

for epoch in range(num_epochs):
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i+1) % log_step == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

    # 每个epoch结束后,对模型进行一个测试
    with torch.no_grad():
        total = 0
        correct = 0
        for inputs, labels in test_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
        print(f'Test Accuracy: {accuracy:.2f}%')

以上就是关于pytorch 6 batch_train 批训练操作的攻略和两个示例的说明。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch 6 batch_train 批训练操作 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • KOA+egg.js集成kafka消息队列的示例

    下面是关于KOA+egg.js集成kafka消息队列的完整攻略。 一、什么是Kafka Kafka是一个高吞吐量的分布式队列系统,被广泛应用于大规模数据处理和处理高并发请求的场景。 二、集成kafka消息队列方案 KOA+egg.js集成kafka消息队列,需要用到kafka-node和egg-kafkanode插件。 其中,kafka-node是kafka…

    人工智能概览 2023年5月25日
    00
  • Spring Cloud Ribbon实现客户端负载均衡的示例

    下面是“Spring Cloud Ribbon实现客户端负载均衡的示例”的完整攻略。 一、什么是Spring Cloud Ribbon Spring Cloud Ribbon是Netflix Ribbon的一个集成,通过使用Spring Cloud的注解和Spring Cloud的默认配置,可以方便地实现客户端负载均衡。 二、Spring Cloud Rib…

    人工智能概览 2023年5月25日
    00
  • 关于python中remove的一些坑小结

    关于Python中remove的一些坑小结 问题简介 在Python中使用remove()方法移除列表中的元素时,经常会遇到一些问题。例如,移除列表中特定的元素却没有成功移除,在移除元素时却出现了IndexError等错误。本文将详细解释这些问题的产生原因,并提供解决方案。 问题解决 使用remove()方法移除列表中元素时,需要注意以下两点: 问题1:re…

    人工智能概览 2023年5月25日
    00
  • Perl5 OOP学习笔记第1/2页

    如果想学习 Perl5 面向对象编程(OOP),可以参考下面的攻略: 第1页 什么是面向对象编程? 对象是什么? 对象是程序中的一个实体,它包括一些属性和可以对这些属性执行的操作。 面向对象编程(OOP)是什么? OOP 是一种编程范式,使用面向对象的方式描述和解决问题。在 OOP 中,程序被组织成对象,对象之间可以互相交互来完成任务。 这里还需要注意 OO…

    人工智能概论 2023年5月25日
    00
  • node.js基于mongodb的搜索分页示例

    node.js是一个基于Chrome V8引擎的JavaScript运行环境,可以轻松地构建高效的Web应用程序。而mongodb是一个功能强大的文档数据库,是node.js的好搭档。搜索分页是Web应用程序中常见的需求之一,本文将为您详细讲解如何使用node.js和mongodb构建搜索分页示例。 1. 安装和配置mongodb 首先,在本地安装mongo…

    人工智能概论 2023年5月25日
    00
  • OpenCV4.1.0+VisualStudio2019开发环境搭建(超级简单)

    下面我将为您详细讲解“OpenCV4.1.0+VisualStudio2019开发环境搭建(超级简单)”的完整攻略。 第一步 安装Visual Studio 2019 首先,我们需要安装Visual Studio 2019,可以在微软官网下载安装包进行安装。具体步骤可以参考下面的链接:Visual Studio 2019安装教程 第二步 安装CMake Op…

    人工智能概览 2023年5月25日
    00
  • Windows设置nginx开机自启动的方法

    当我们使用 Windows 操作系统来配置 Nginx 服务器时,每次重启系统时都需要手动启动 Nginx,非常麻烦。因此,设置 Nginx 开机自启动是非常必要的。下面是 Windows 设置 Nginx 开机自启动的完整攻略: 第一步:创建一个 Nginx 开机启动的 .bat 文件 在任何一个地方创建一个新的文本文件,比如说在桌面上。 将下面这行命令复…

    人工智能概览 2023年5月25日
    00
  • python和js交互调用的方法

    Python和JavaScript是两种不同的编程语言,它们在特性和运行环境上有一些显著的差异。但是,在一些现代Web开发场景中,我们常常会需要使用这两种语言协同工作,以实现需要在浏览器和服务器上公用的某些功能。 下面,我们将详细讲解Python和JavaScript之间的交互与调用方法,包括在前端和后端如何使用JavaScript调用Python,以及如何…

    人工智能概览 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部