PyTorch梯度裁剪避免训练loss nan的操作

PyTorch梯度裁剪是一种用于避免训练过程中出现lossnan的问题,其通过限制模型的参数梯度范围来提高训练稳定性和收敛效果。以下是PyTorch梯度裁剪的完整攻略:

什么是梯度裁剪

梯度裁剪是一种通过限制参数梯度范围的方法,防止训练过程中出现梯度爆炸或梯度消失的情况。这种现象常常发生在深层神经网络中,尤其是在使用长短时记忆网络(LSTM)等循环神经网络时更加明显。

常见方法

常见的梯度裁剪方法包括全局范围裁剪和逐层范围裁剪两种。

全局范围裁剪:对所有参数的梯度进行裁剪,即裁剪的范围是所有参数的梯度范围。

逐层范围裁剪:对每个层的参数进行裁剪,裁剪的范围是该层参数的梯度范围。这种方法可更切合实际应用,因为不同层的参数梯度范围差异较大。

操作步骤

PyTorch内置了对梯度裁剪的支持,以下是梯度裁剪的操作步骤:

  1. 定义模型,如:
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(100, 10)

    def forward(self, x):
        return self.linear(x)
  1. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  1. 对模型进行梯度裁剪

全局裁剪:

clipped_gradients = []
max_grad_norm = 1.0 # 定义最大梯度范围

for group in optimizer.param_groups: # 遍历每个参数组
    for param in group['params']: # 遍历每个参数
        # 计算梯度范数
        grad_norm = param.grad.norm(2)
        # 若梯度范数超过指定范围,则进行梯度裁剪
        if grad_norm > max_grad_norm:
            clipped_gradients.append(param.grad)
            param.grad.div_(grad_norm / max_grad_norm)

逐层裁剪:

clipped_gradients = []
max_grad_norm = 1.0 # 定义最大梯度范围

for group in optimizer.param_groups: # 遍历每个参数组
    for param in group['params']: # 遍历每个参数
        # 若当前参数在线性层中,则计算梯度范数
        if hasattr(param, 'weight'): 
            grad = param.grad
            if grad is None:
                continue
            weight = param.detach()
            if weight.grad is not None:
                weight.grad.detach_()
            else:
                weight.grad = torch.zeros_like(weight)
            # 计算梯度范数
            grad_norm = grad.norm(2)
            # 若梯度范数超过指定范围,则进行梯度裁剪
            if grad_norm > max_grad_norm:
                clipped_gradients.append(grad)
                grad.div_(grad_norm / max_grad_norm)
  1. 更新模型参数
optimizer.step()
  1. 清空梯度
optimizer.zero_grad()

示例说明

下面举两个示例说明梯度裁剪的操作方法。

示例一:全局梯度裁剪

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(100, 10)

    def forward(self, x):
        return self.linear(x)

model = Model()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 数据集
x = torch.randn(32, 100)
y = torch.randint(10, (32,))

# 训练
for i in range(100):
    optimizer.zero_grad()
    pred = model(x)
    loss = criterion(pred, y)
    loss.backward()
    # 对梯度进行裁剪
    clipped_gradients = []
    max_grad_norm = 1.0
    for group in optimizer.param_groups:
        for param in group['params']:
            # 计算梯度范数
            grad_norm = param.grad.norm(2)
            # 若梯度范数超过指定范围,则进行梯度裁剪
            if grad_norm > max_grad_norm:
                clipped_gradients.append(param.grad)
                param.grad.div_(grad_norm / max_grad_norm)
    # 更新参数
    optimizer.step()

示例二:逐层梯度裁剪

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(100, 10)
        self.linear2 = nn.Linear(10, 2)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

model = Model()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 数据集
x = torch.randn(32, 100)
y = torch.randint(2, (32,))

# 训练
for i in range(100):
    optimizer.zero_grad()
    pred = model(x)
    loss = criterion(pred, y)
    loss.backward()
    # 对梯度进行裁剪
    clipped_gradients = []
    max_grad_norm = 1.0
    for group in optimizer.param_groups:
        for param in group['params']:
            # 若当前参数在线性层中,则计算梯度范数
            if hasattr(param, 'weight'): 
                grad = param.grad
                if grad is None:
                    continue
                weight = param.detach()
                if weight.grad is not None:
                    weight.grad.detach_()
                else:
                    weight.grad = torch.zeros_like(weight)
                # 计算梯度范数
                grad_norm = grad.norm(2)
                # 若梯度范数超过指定范围,则进行梯度裁剪
                if grad_norm > max_grad_norm:
                    clipped_gradients.append(grad)
                    grad.div_(grad_norm / max_grad_norm)
    # 更新参数
    optimizer.step()

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch梯度裁剪避免训练loss nan的操作 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • 一次nginx 504 Gateway Time-out错误排查、解决记录

    一次NGINX 504 Gateway Time-out错误排查和解决可能涉及到多个原因和步骤,下面我将详细介绍一下完整的攻略。 1. 什么是504 Gateway Time-out错误 当我们访问一个Web站点的时候,我们的浏览器会向Web服务器发送请求。Web服务器通常与一个应用服务器连接,如PHP-FPM、Django等,以处理请求和生成响应。在一些情…

    人工智能概览 2023年5月25日
    00
  • nginx负载均衡配置,宕机自动切换方式

    下面是详细讲解nginx负载均衡配置,宕机自动切换方式的完整攻略过程。 1. 安装nginx 首先需要在服务器上安装nginx,可以使用包管理器如apt-get或yum进行安装,也可以在官网下载源码进行编译安装。 2. 配置负载均衡 在nginx的配置文件中,可以使用upstream指令来定义后端服务器的列表,然后使用proxy_pass指令将请求转发到后端…

    人工智能概览 2023年5月25日
    00
  • VS2022+libtorch+Cuda11.3安装测试教程详解(调用cuda)

    下面给您讲解“VS2022+libtorch+Cuda11.3安装测试教程详解(调用cuda)”的完整攻略。 步骤一:安装VS2022 下载VS2022安装包,可以从微软官网或者其他可靠的下载网站下载。 双击安装包进行安装,根据提示进行操作即可。 步骤二:安装Cuda11.3 下载Cuda11.3安装包,可以从NVIDIA官网或者其他可靠的下载网站下载。 双…

    人工智能概览 2023年5月25日
    00
  • pandas库中 DataFrame的用法小结

    下面是“pandas库中 DataFrame的用法小结”的完整攻略,分为以下几个部分: 1. 什么是DataFrame DataFrame是pandas库中的一种数据结构,类似于Excel中的数据表。DataFrame有行和列,行代表样本,列代表特征。DataFrame可以由多种数据源创建,包括Numpy数组、Python字典、CSV文件等。 2. 创建Da…

    人工智能概论 2023年5月25日
    00
  • 详解nginx 配置文件解读

    下面我来详细讲解“详解nginx 配置文件解读”的攻略。 什么是Nginx Nginx是一款高性能的Web服务软件,支持负载均衡和反向代理等功能,同时也是一款高可靠性的服务器,被广泛应用于各种Web服务应用场景中。 Nginx配置文件的结构 Nginx配置文件一般包括了以下五个部分 配置全局块 配置http块,包括http全局块和http server块 配…

    人工智能概览 2023年5月25日
    00
  • Go Ginrest实现一个RESTful接口

    Go Ginrest是基于Go语言和Gin框架开发的一个简化RESTful接口开发的工具库,可以大大缩短开发时间和减少代码量。下面我将介绍如何使用Go Ginrest来实现一个RESTful接口。 步骤一:安装Go Ginrest 在终端中执行以下命令: go get github.com/gin-rest-framework/gin-rest 步骤二:创建…

    人工智能概览 2023年5月25日
    00
  • 深入了解JavaScript发布订阅模式

    深入了解JavaScript发布订阅模式 什么是发布订阅模式? 发布订阅模式 是一种解耦的设计模式,它把服 务端提供的所有服务都抽象成订阅事件,客户端只需要订阅自己关注的服务即可,而不需要提前知道服务提供端的具体实现方式。服务端则维护着需要订阅的事件,同时维护了客户端列表,当某个事件被触发时,服务端向关注该事件的所有客户端发送通知。 实现发布订阅模式的步骤 …

    人工智能概览 2023年5月25日
    00
  • Nginx日志自定义记录以及启用日志缓冲区详解

    下面是关于Nginx日志自定义记录以及启用日志缓冲区的完整攻略。 什么是Nginx日志自定义记录以及启用日志缓冲区? 在使用Nginx作为Web服务器时,日志记录是非常重要的。Nginx提供了自定义记录日志的功能,以便我们可以根据需要选择需要记录的信息。同时,Nginx还有一个叫做日志缓冲区的功能,在高并发情况下,可以提高日志的写入效率。 如何在Nginx中…

    人工智能概览 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部