pytorch固定BN层参数的操作

yizhihongxing

关于PyTorch中如何固定BN层的参数,通常有两种方法:

  1. 冻结BN层

在PyTorch中,可以通过requires_grad属性来决定一个参数是否需要被训练。为了固定BN层的参数,我们可以将其requires_grad属性设置为False,这样就不会更新其参数了。具体步骤如下:

import torch.nn as nn

bn_layer = nn.BatchNorm2d(3)  # 创建一个BatchNorm2d层
bn_layer.eval()              # 将其设为评估模式(不影响参数值,只是改变了其前向传播的行为)
for param in bn_layer.parameters():
    param.requires_grad = False  # 将所有参数的requires_grad属性设为False,即冻结了该层所有参数
  1. 使用hook

hook可以在模型中的指定位置注入自己的代码,捕获到该位置的输入,输出和参数等信息。利用hook,我们可以直接修改BN层的参数。具体步骤如下:

import torch.nn.functional as F
from torch.nn import Module, Parameter

class FrozenBN(Module):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True):
        super(FrozenBN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine

        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

        self.reset_parameters()
        self.eval()

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)

        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, x):
        return F.batch_norm(
            x, self.running_mean, self.running_var, self.weight, self.bias,
            self.training, self.momentum, self.eps)

    def set_freeze(self, freeze: bool = True):
        if freeze:
            self.eval()
            self.requires_grad_ = False
        else:
            self.train()
            self.requires_grad_ = True

以上是两种常用的固定BN层的方法,下面展示两个具体的应用场景:

  1. 预训练时冻结BN层

前面提到过,冻结BN层是有利于训练的,因为在预训练阶段,一些层的权重通常已经很好了,而为每一个batch计算定量的归一化操作(BN层)会浪费很多时间。因此,在预训练过程中冻结BN层是一个很好的激进优化方法。

import torchvision
import torch.optim as optim

# load backbone network
model = torchvision.models.resnet18(pretrained=True)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 10)

# freeze all bn layer
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
        m.requires_grad_ = False

# training
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(epochs):
    for i, (inputs, labels) in enumerate(train_loader):
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
  1. 从预训练模型中恢复BN层的参数

在使用PyTorch中的预训练模型时,有时需要将之前已经训练好的模型参数迁移到新模型中。然而,由于新模型中的BN层维度不同,因此无法直接将参数赋值过去。此时需要将参数逐个拷贝过去,但这个过程很麻烦。下面我们可以将FrozenBN层封装一下,通过hook的方式自动拷贝参数。

import torch
import torch.nn as nn

def copy_bn(src_bn, dst_bn):
    dst_bn.weight.data.copy_(src_bn.weight.data)
    dst_bn.bias.data.copy_(src_bn.bias.data)
    dst_bn.running_mean.data.copy_(src_bn.running_mean.data)
    dst_bn.running_var.data.copy_(src_bn.running_var.data)

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.bn = FrozenBN(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

def load_pretrained_model(src_net, dst_net):
    src_dict = src_net.state_dict()
    dst_dict = dst_net.state_dict()

    for key in src_dict:

        if key in dst_dict and src_dict[key].shape == dst_dict[key].shape:
            # copy the values directly if the shape is the same
            dst_dict[key] = src_dict[key]  

        elif key.startswith('conv1') and 'conv1' in dst_dict:
            # special handling for the first layer
            src_conv_weight = src_dict[key]
            dst_conv_weight = dst_dict['conv1.weight']
            dst_dict['conv1.weight'] = src_conv_weight[:, :3]
            src_bn = src_dict['bn1']
            dst_bn = dst_dict['bn1']
            copy_bn(src_bn, dst_bn)

        elif key.startswith('bn') and '2' in key:
            # handle BN layer at stride=2
            src_bn = src_dict[key]
            dst_bn = dst_dict[key.replace('2', '3')]
            copy_bn(src_bn, dst_bn)

    dst_net.load_state_dict(dst_dict)

以上代码定义了一个ConvBlock,包括一个卷积层、一个FrozenBN层和一个ReLU激活层。当从另一个模型中恢复参数时,因为模型中的BN层维度不同,我们需要将源模型中的所有BN层替换为我们定义的FrozenBN层。此时,只需在BN层后添加一个hook,在每次前向传播时,自动进行参数拷贝即可。

import torchvision

model1 = torchvision.models.resnet18(pretrained=True)
model2 = torchvision.models.resnet18()

for i, (model1_module, model2_module) in enumerate(zip(model1.modules(), model2.modules())):
    if isinstance(model1_module, nn.BatchNorm2d):
        model2_module_new = FrozenBN(model1_module.num_features)
        model2_module_new.set_freeze(freeze=True)
        model2_module.register_forward_pre_hook(model2_module_new.forward_pre_hook)
model2.eval()

# copy the parameter of model1 to model2
load_pretrained_model(model1, model2)

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch固定BN层参数的操作 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • Django 导出 Excel 代码的实例详解

    下面是“Django 导出 Excel 代码的实例详解”。 目录 前言 安装依赖 代码实现 准备工作 HttpResponse 类型 FileResponse 类型 示例说明 示例1:HttpResponse 类型 示例2:FileResponse 类型 总结 前言 在Django开发中,有时我们需要将数据导出为Excel格式的文件,方便数据的分享和查看。本…

    人工智能概论 2023年5月24日
    00
  • Vue项目History模式404问题解决方法

    下面是“Vue项目History模式404问题解决方法”的完整攻略: 问题背景 在Vue项目中,我们可以选择使用History模式路由,以去除URL中的#符号。但是,在使用History模式路由时,如果浏览器直接访问某个路由或者刷新当前页面,就会出现404错误。 问题原因 在使用History模式路由时,当用户在浏览器中输入某个路由地址,或者在浏览器中刷新页…

    人工智能概览 2023年5月25日
    00
  • python Web flask 视图内容和模板实现代码

    Python Web 中,Flask 框架的视图函数和模板是实现动态 Web 应用的核心。下面我将为您提供完整的攻略。 一、Flask 视图实现 在 Flask 中,视图函数是用于处理 Web 请求并生成 Web 响应的函数。视图函数通常使用 Flask 提供的装饰器 @app.route() 来将函数绑定到一个 URL 路径上,例如: from flask…

    人工智能概论 2023年5月25日
    00
  • Django 解决新建表删除后无法重新创建等问题

    下面是基于Django的解决新建表删除后无法重新创建等问题的完整攻略。 问题描述 在使用Django开发时,有时候我们会遇到新建数据表之后,再次删除数据表时会出现无法重新创建数据表的情况。 这种情况通常出现在我们删除数据表之后,模型元数据表中仍然保留着该数据表的记录。如果我们重新创建同名数据表,Django会发现元数据表中已经保存了同名数据表的信息,进而拒绝…

    人工智能概论 2023年5月25日
    00
  • 深入学习spring cloud gateway 限流熔断

    深入学习Spring Cloud Gateway 限流熔断攻略 什么是Spring Cloud Gateway Spring Cloud Gateway是一个构建在Spring Framework 5,Project Reactor和Spring Boot 2之上的网关,可以作为所有基于HTTP路由的API的入口点。它提供了一种简单而有效的方式来传递客户端请…

    人工智能概览 2023年5月25日
    00
  • MongoDB中连接字符串的编写

    MongoDB中连接字符串是用于连接MongoDB数据库的字符串,通常由多个参数组成,包括主机名、端口号、认证信息等,构成一条完整的URL连接。下面是MongoDB连接字符串编写的完整攻略: 编写连接字符串的基本格式 MongoDB连接字符串的基本格式为: mongodb://[username:password@]host1[:port1][,host2[…

    人工智能概论 2023年5月25日
    00
  • django 微信网页授权认证api的步骤详解

    下面就来详细讲解“django 微信网页授权认证api的步骤详解”: 1. 概述 网页授权是通过OAuth2.0机制实现的,即用户打开第三方网页时,第三方网页要获取用户的微信基本信息(如昵称、头像等信息)时,需要用户授权才能获取到。本文将介绍如何在Django中使用微信网页授权认证API。 2. 步骤 2.1 获取用户授权链接 第一步是获取用户授权链接。用户…

    人工智能概览 2023年5月25日
    00
  • 详解angularjs的数组传参方式的简单实现

    首先,我们需要了解AngularJS中数组参数的传递方式。在AngularJS中,数组可以通过以下两种方式来传递参数: 1. 通过$scope 我们可以在控制器(Controller)中定义一个数组,并将其赋值给$scope对象。然后,我们可以在HTML视图(View)中使用ng-repeat指令来遍历该数组。下面是一个示例代码: // 在控制器中定义一个数…

    人工智能概览 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部