pytorch固定BN层参数的操作

关于PyTorch中如何固定BN层的参数,通常有两种方法:

  1. 冻结BN层

在PyTorch中,可以通过requires_grad属性来决定一个参数是否需要被训练。为了固定BN层的参数,我们可以将其requires_grad属性设置为False,这样就不会更新其参数了。具体步骤如下:

import torch.nn as nn

bn_layer = nn.BatchNorm2d(3)  # 创建一个BatchNorm2d层
bn_layer.eval()              # 将其设为评估模式(不影响参数值,只是改变了其前向传播的行为)
for param in bn_layer.parameters():
    param.requires_grad = False  # 将所有参数的requires_grad属性设为False,即冻结了该层所有参数
  1. 使用hook

hook可以在模型中的指定位置注入自己的代码,捕获到该位置的输入,输出和参数等信息。利用hook,我们可以直接修改BN层的参数。具体步骤如下:

import torch.nn.functional as F
from torch.nn import Module, Parameter

class FrozenBN(Module):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True):
        super(FrozenBN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine

        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

        self.reset_parameters()
        self.eval()

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)

        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, x):
        return F.batch_norm(
            x, self.running_mean, self.running_var, self.weight, self.bias,
            self.training, self.momentum, self.eps)

    def set_freeze(self, freeze: bool = True):
        if freeze:
            self.eval()
            self.requires_grad_ = False
        else:
            self.train()
            self.requires_grad_ = True

以上是两种常用的固定BN层的方法,下面展示两个具体的应用场景:

  1. 预训练时冻结BN层

前面提到过,冻结BN层是有利于训练的,因为在预训练阶段,一些层的权重通常已经很好了,而为每一个batch计算定量的归一化操作(BN层)会浪费很多时间。因此,在预训练过程中冻结BN层是一个很好的激进优化方法。

import torchvision
import torch.optim as optim

# load backbone network
model = torchvision.models.resnet18(pretrained=True)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 10)

# freeze all bn layer
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
        m.requires_grad_ = False

# training
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(epochs):
    for i, (inputs, labels) in enumerate(train_loader):
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
  1. 从预训练模型中恢复BN层的参数

在使用PyTorch中的预训练模型时,有时需要将之前已经训练好的模型参数迁移到新模型中。然而,由于新模型中的BN层维度不同,因此无法直接将参数赋值过去。此时需要将参数逐个拷贝过去,但这个过程很麻烦。下面我们可以将FrozenBN层封装一下,通过hook的方式自动拷贝参数。

import torch
import torch.nn as nn

def copy_bn(src_bn, dst_bn):
    dst_bn.weight.data.copy_(src_bn.weight.data)
    dst_bn.bias.data.copy_(src_bn.bias.data)
    dst_bn.running_mean.data.copy_(src_bn.running_mean.data)
    dst_bn.running_var.data.copy_(src_bn.running_var.data)

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.bn = FrozenBN(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

def load_pretrained_model(src_net, dst_net):
    src_dict = src_net.state_dict()
    dst_dict = dst_net.state_dict()

    for key in src_dict:

        if key in dst_dict and src_dict[key].shape == dst_dict[key].shape:
            # copy the values directly if the shape is the same
            dst_dict[key] = src_dict[key]  

        elif key.startswith('conv1') and 'conv1' in dst_dict:
            # special handling for the first layer
            src_conv_weight = src_dict[key]
            dst_conv_weight = dst_dict['conv1.weight']
            dst_dict['conv1.weight'] = src_conv_weight[:, :3]
            src_bn = src_dict['bn1']
            dst_bn = dst_dict['bn1']
            copy_bn(src_bn, dst_bn)

        elif key.startswith('bn') and '2' in key:
            # handle BN layer at stride=2
            src_bn = src_dict[key]
            dst_bn = dst_dict[key.replace('2', '3')]
            copy_bn(src_bn, dst_bn)

    dst_net.load_state_dict(dst_dict)

以上代码定义了一个ConvBlock,包括一个卷积层、一个FrozenBN层和一个ReLU激活层。当从另一个模型中恢复参数时,因为模型中的BN层维度不同,我们需要将源模型中的所有BN层替换为我们定义的FrozenBN层。此时,只需在BN层后添加一个hook,在每次前向传播时,自动进行参数拷贝即可。

import torchvision

model1 = torchvision.models.resnet18(pretrained=True)
model2 = torchvision.models.resnet18()

for i, (model1_module, model2_module) in enumerate(zip(model1.modules(), model2.modules())):
    if isinstance(model1_module, nn.BatchNorm2d):
        model2_module_new = FrozenBN(model1_module.num_features)
        model2_module_new.set_freeze(freeze=True)
        model2_module.register_forward_pre_hook(model2_module_new.forward_pre_hook)
model2.eval()

# copy the parameter of model1 to model2
load_pretrained_model(model1, model2)

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch固定BN层参数的操作 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • 详解Django-auth-ldap 配置方法

    详解Django-auth-ldap 配置方法 简介 Django-auth-ldap 用于 Django 应用中和 LDAP 目录服务集成,提供用户认证和授权功能。在使用 Django-auth-ldap 前,需要在 Django 设置中配置 LDAP 访问,并根据您的需求配置认证、授权和同步等选项。 安装 您可以通过运行以下命令安装 Django-aut…

    人工智能概论 2023年5月25日
    00
  • 详解Centos7 源码编译安装 Nginx1.13

    详解Centos7 源码编译安装 Nginx1.13 本文详细讲解了如何在Centos7上通过源码编译的方式安装Nginx1.13,从而获得最新版本的Nginx并自定义配置启用各种功能,同时还能够加深对Nginx的理解,方便进一步进行二次开发。 环境准备 首先需要确保Centos7系统正常运行,并且已安装了必要的依赖包。如果没有,则需要提前安装。 yum i…

    人工智能概览 2023年5月25日
    00
  • django 实现电子支付功能的示例代码

    下面是 django 实现电子支付功能的示例代码的完整攻略: 1. 安装相关库 在 django 项目中实现电子支付功能,首先需要使用到相应的库。目前比较流行的有以下两个: django-payments:这是一个基于 Django 的支付应用,集成了多个第三方支付服务提供商的 SDK,可通过该应用快速实现主流的电子支付功能。 stripe:这是一家美国电子…

    人工智能概论 2023年5月24日
    00
  • .netcore 使用surging框架发布到docker

    环境准备 首先我们需要准备本地的开发环境,主要包括以下几个方面: 安装 Docker 安装 Docker Compose 安装 .NET Core SDK 创建 .NET Core 应用 我们需要创建一个 .NET Core 应用,使用 Surging 框架,这里提供一个简单的示例: 使用 Visual Studio Code 打开控制台,执行以下命令: d…

    人工智能概览 2023年5月25日
    00
  • Django框架之登录后自定义跳转页面的实现方法

    下面我会详细讲解“Django框架之登录后自定义跳转页面的实现方法”的完整攻略。 1、什么是Django框架 Django是一个基于Python语言的Web开发框架。它采用了MTV(Model-Template-View)的设计模式,使得开发者能够更轻松地开发高质量的Web应用。Django自带了Admin后台管理系统、ORM框架等,并具有高度灵活性和可扩展…

    人工智能概览 2023年5月25日
    00
  • Mac下关于PHP环境和扩展的安装详解

    Mac下关于PHP环境和扩展的安装详解 1. 安装Homebrew Homebrew 是 Mac OS 下的包管理工具,可以方便的安装一些必要的软件及扩展,通过命令行可以轻松实现。 安装 Homebrew 命令如下: /usr/bin/ruby -e “$(curl -fsSL https://raw.githubusercontent.com/Homebr…

    人工智能概览 2023年5月25日
    00
  • 深入理解nginx如何实现高性能和可扩展性

    深入理解nginx如何实现高性能和可扩展性 Nginx 是一个高性能、高可靠性的 Web 服务器和反向代理服务器。在处理高并发网络请求时,它可以同时保持较高的稳定性和扩展性。以下是 Nginx 实现高性能和可扩展性的攻略: 1.事件驱动模型 Nginx 使用了事件驱动的模型,在单个进程中处理多个并发连接,从而避免了每个连接都创建一个新进程或线程的模型。这种模…

    人工智能概览 2023年5月25日
    00
  • Anaconda下Python中GDAL模块的下载与安装过程

    下面是Anaconda下Python中GDAL模块的下载与安装过程的完整攻略: 1. 安装Anaconda 如果已经安装了Anaconda,可以跳到步骤2。 Anaconda是一个便捷的Python发行版,可以方便地安装和管理Python模块。可以从官方网站https://www.anaconda.com/products/individual下载对应版本的…

    人工智能概览 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部