torch.optim优化算法理解之optim.Adam()解读

yizhihongxing

下面是对于“torch.optim优化算法理解之optim.Adam()解读”的完整攻略。

1. 优化算法概述

在神经网络训练的过程中,我们需要选择一个好的优化算法来更新模型中的参数,这个过程就是优化算法。优化算法可以通过最小化损失函数来更新参数,以便更好地拟合数据。

目前常用的优化算法有SGD、Adam、RMSprop等,每个算法都有自己的优缺点,选用不同的算法会对训练效果产生不同的影响。

2. Adam算法介绍

Adam算法是一种自适应优化算法,可以根据每个参数的历史梯度和动量对学习率进行自适应调整,同时避免了梯度下降中的局部最小值问题。

Adam算法的更新公式如下:

$v_t=\beta_1v_{t-1}+(1-\beta_1)g_t$

$s_t=\beta_2s_{t-1}+(1-\beta_2)g_t^2$

$\hat{v}_t=\frac{v_t}{1-\beta_1^t}$

$\hat{s}_t=\frac{s_t}{1-\beta_2^t}$

$\theta_{t+1}=\theta_t-\frac{\alpha\hat{v}_t}{\sqrt{\hat{s}_t}+\epsilon}$

其中,$g_t$为梯度,$\theta_t$为当前参数,$\alpha$为学习率,$\beta_1$和$\beta_2$为两个可调参数,$v_t$和$s_t$是梯度的一阶矩估计和二阶矩估计,通过对这两个估计的校正可以减小偏差,防止梯度估计的不准确。$\epsilon$是为了防止分母为0而添加的一个极小常数。

3. optim.Adam()函数用法

在Pytorch中,可以使用torch.optim中的Adam函数来实现Adam算法的优化器。Adam函数的定义如下:

torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

其中,params是需要优化的参数,lr为学习率,betas是计算一阶和二阶矩估计的指数衰减率,eps是分母中的极小常数,用于避免除数为0的情况,weight_decay是L2正则化项的权重,amsgrad是一种变体Adam算法,可以减小训练过程中的波动。

4. 案例说明

在下面的两个案例中,我们将使用optim.Adam()来优化两个不同的模型。

案例1:手写数字识别

在这个案例中,我们将使用Pytorch内置的手写数字数据集来训练一个简单的卷积神经网络。我们的目标是优化模型的准确率。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.dropout = nn.Dropout()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = nn.functional.relu(nn.functional.max_pool2d(self.conv1(x), 2))
        x = nn.functional.relu(nn.functional.max_pool2d(self.dropout(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return nn.functional.log_softmax(x, dim=1)

# 数据预处理
train_transforms = transforms.Compose([transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))])

# 加载数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=train_transforms, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# 设置优化器和损失函数
model = Net()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.NLLLoss()

# 训练模型
for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        inputs, labels = data
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:    
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

在上面的代码中,我们定义了一个简单的卷积神经网络,使用Adam优化器进行模型训练。最终的结果可以达到95%左右的准确率。

案例2:LSTM情感分析

在这个案例中,我们使用LSTM对IMDb情感分析数据集进行分类,从而比较SGD和Adam两种优化算法的差异。

import torch
import torch.nn as nn
import torch.optim as optim
import torchtext.datasets as datasets
import torchtext.data as data

# 定义LSTM模型
class LSTM(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=2, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)

    def forward(self, text):
        # text: [sent len, batch size]
        embedded = self.embedding(text)
        # embedded: [sent len, batch size, emb dim]
        output, (hidden, cell) = self.rnn(embedded)
        # output: [sent len, batch size, hid dim * num directions]
        # hidden: [num layers * num directions, batch size, hid dim]
        # cell: [num layers * num directions, batch size, hid dim]
        hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)
        # hidden: [batch size, hid dim * num directions]
        return self.fc(hidden)

# 加载数据集
TEXT = data.Field(tokenize='spacy')
LABEL = data.LabelField(dtype=torch.float)

train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)
TEXT.build_vocab(train_data, max_size=25000, vectors="glove.6B.100d")
LABEL.build_vocab(train_data)

train_loader, test_loader = data.BucketIterator.splits((train_data, test_data), batch_size=32)

# 定义模型和优化器
input_dim = len(TEXT.vocab)
embedding_dim = 100
hidden_dim = 256
output_dim = 1
model = LSTM(input_dim, embedding_dim, hidden_dim, output_dim)
optimizer1 = optim.Adam(model.parameters(), lr=1e-3)
optimizer2 = optim.SGD(model.parameters(), lr=1e-3)

# 定义损失函数
criterion = nn.BCEWithLogitsLoss()

# 训练模型
for epoch in range(5):
    train_loss1, train_loss2, train_acc1, train_acc2 = 0.0, 0.0, 0.0, 0.0
    model.train()
    for batch in train_loader:
        optimizer1.zero_grad()
        optimizer2.zero_grad()
        text, label = batch.text, batch.label
        preds = model(text).squeeze(1)
        loss1 = criterion(preds, label)
        loss2 = criterion(preds, label)
        loss1.backward()
        loss2.backward()
        optimizer1.step()
        optimizer2.step()
        train_loss1 += loss1.item()
        train_loss2 += loss2.item()
        train_acc1 += ((preds>0).float() == label).sum().item()
        train_acc2 += ((preds>0).float() == label).sum().item()
    train_loss1 /= len(train_loader)
    train_loss2 /= len(train_loader)
    train_acc1 /= len(train_loader)
    train_acc2 /= len(train_loader)
    print(f'Epoch {epoch+1}: Adam loss {train_loss1:.3f} / SGD loss {train_loss2:.3f}, Adam acc {train_acc1*100:.2f}% / SGD acc {train_acc2*100:.2f}%')

在上面的代码中,我们定义了一个LSTM模型,并使用Adam和SGD两种优化算法进行比较。最终的结果表明,Adam算法的收敛速度比较快,而SGD算法的准确率稍高。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:torch.optim优化算法理解之optim.Adam()解读 - Python技术站

(2)
上一篇 2023年6月6日
下一篇 2023年6月6日

相关文章

  • Python实现数据集划分(训练集和测试集)

    Python实现数据集划分(训练集和测试集)是机器学习中非常重要的一部分。数据集划分可以帮助我们评估模型的准确性、提高模型的效率和避免过拟合等问题。下面是实现数据集划分的完整攻略: 步骤一:准备数据集 首先,我们需要准备数据集。数据集是机器学习中重要的组成部分,一般将数据集划分为训练集和测试集,其中训练集用于训练模型,测试集用于测试模型的准确性和泛化能力。 …

    python 2023年6月3日
    00
  • Python命令行参数定义及需要注意的地方

    Python命令行参数是指在运行Python程序时,通过命令行传入的参数信息,它们可以从sys模块的argv列表中获取到。可以使用argparse模块来处理和定义命令行参数。在这个攻略中,我们将详细介绍如何定义和处理Python命令行参数以及需要注意的地方。 使用argparse模块定义Python命令行参数 argparse是Python标准库中定义命令行…

    python 2023年6月3日
    00
  • Python标准库defaultdict模块使用示例

    下面是关于Python标准库defaultdict模块使用的详细攻略: 什么是defaultdict模块 Python标准库中的defaultdict是一个内置模块,它是一个类,它继承自普通的字典(dict),同时添加了一个名为default_factory的方法。default_factory可以将默认值设置为任意类型,其可以是int、list、set、s…

    python 2023年5月13日
    00
  • Python爬虫框架之Scrapy中Spider的用法

    Python爬虫框架之Scrapy中Spider的用法 简介 Scrapy是一个用于爬取网站数据的Python框架,是Python爬虫工具中的一种,其提供了高效、快捷和可扩展的数据获取方式。其中Spider是Scrapy框架中最基本的爬虫,用于定制和控制Scrapy框架的爬取过程。 Spider的基本用法 创建Spider 在Scrapy框架中,我们通过创建…

    python 2023年5月14日
    00
  • 深入讲解Python中的迭代器和生成器

    标题:深入讲解Python中的迭代器和生成器 什么是迭代器? Python中的迭代器是一种访问集合元素的对象,可以使用for循环遍历集合中的元素,同时也可以使用next()函数逐个访问集合中的元素。 迭代器的定义 迭代器对象从一个集合中取出一个元素后,依次再取出下一个元素,直到取出集合中的所有元素为止。迭代器的定义需要满足以下条件: 实现 next() 方法…

    python 2023年6月3日
    00
  • Python和其他编程语言有什么区别?

    Python是一种高级、面向对象的编程语言,与其他编程语言相比,它具有以下几点差别: 1. 语法简单 Python的语法非常简单,易于学习和记忆,像英语一样的语法,加上优雅和简洁的语法风格,使得Python查错和调试变得容易。 示例代码:以下是Python代码和Java代码实现Hello World的对比。 Python代码: print("Hel…

    python 2023年4月19日
    00
  • Python实现获取汉字偏旁部首的方法示例【测试可用】

    获取汉字偏旁部首是中文文本处理中的一个重要问题。本攻略将介绍Python实现获取汉字偏旁部首的方法,包括基于Unicode编码和基于康熙字典的方法。 基于Unicode编码的方法 Unicode编码为每个汉字分配了一个唯一的代码点,可以使用Python内置的ord函数获取汉字的Unicode编码。汉字的偏旁部首通常位于Unicode编码的高位,可以通过位运算…

    python 2023年5月15日
    00
  • 如何使用Python在MySQL中使用外键?

    在MySQL中,可以使用外键来建立表之间的关系。在Python中,可以使用MySQL连接来执行外键查询。以下是在Python中使用外键的完整攻略,包括外键的基本语法、使用外键的示例及如何在Python中使用外键。 外键的基本语法 在MySQL中,可以使用FOREIGN KEY关键字来创建外键以下是创建外键的基本语法: CREATE TABLE table_n…

    python 2023年5月12日
    00
合作推广
合作推广
分享本页
返回顶部