pytorch固定BN层参数的操作

关于PyTorch中如何固定BN层的参数,通常有两种方法:

  1. 冻结BN层

在PyTorch中,可以通过requires_grad属性来决定一个参数是否需要被训练。为了固定BN层的参数,我们可以将其requires_grad属性设置为False,这样就不会更新其参数了。具体步骤如下:

import torch.nn as nn

bn_layer = nn.BatchNorm2d(3)  # 创建一个BatchNorm2d层
bn_layer.eval()              # 将其设为评估模式(不影响参数值,只是改变了其前向传播的行为)
for param in bn_layer.parameters():
    param.requires_grad = False  # 将所有参数的requires_grad属性设为False,即冻结了该层所有参数
  1. 使用hook

hook可以在模型中的指定位置注入自己的代码,捕获到该位置的输入,输出和参数等信息。利用hook,我们可以直接修改BN层的参数。具体步骤如下:

import torch.nn.functional as F
from torch.nn import Module, Parameter

class FrozenBN(Module):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True):
        super(FrozenBN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine

        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

        self.reset_parameters()
        self.eval()

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)

        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, x):
        return F.batch_norm(
            x, self.running_mean, self.running_var, self.weight, self.bias,
            self.training, self.momentum, self.eps)

    def set_freeze(self, freeze: bool = True):
        if freeze:
            self.eval()
            self.requires_grad_ = False
        else:
            self.train()
            self.requires_grad_ = True

以上是两种常用的固定BN层的方法,下面展示两个具体的应用场景:

  1. 预训练时冻结BN层

前面提到过,冻结BN层是有利于训练的,因为在预训练阶段,一些层的权重通常已经很好了,而为每一个batch计算定量的归一化操作(BN层)会浪费很多时间。因此,在预训练过程中冻结BN层是一个很好的激进优化方法。

import torchvision
import torch.optim as optim

# load backbone network
model = torchvision.models.resnet18(pretrained=True)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 10)

# freeze all bn layer
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
        m.requires_grad_ = False

# training
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(epochs):
    for i, (inputs, labels) in enumerate(train_loader):
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
  1. 从预训练模型中恢复BN层的参数

在使用PyTorch中的预训练模型时,有时需要将之前已经训练好的模型参数迁移到新模型中。然而,由于新模型中的BN层维度不同,因此无法直接将参数赋值过去。此时需要将参数逐个拷贝过去,但这个过程很麻烦。下面我们可以将FrozenBN层封装一下,通过hook的方式自动拷贝参数。

import torch
import torch.nn as nn

def copy_bn(src_bn, dst_bn):
    dst_bn.weight.data.copy_(src_bn.weight.data)
    dst_bn.bias.data.copy_(src_bn.bias.data)
    dst_bn.running_mean.data.copy_(src_bn.running_mean.data)
    dst_bn.running_var.data.copy_(src_bn.running_var.data)

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.bn = FrozenBN(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

def load_pretrained_model(src_net, dst_net):
    src_dict = src_net.state_dict()
    dst_dict = dst_net.state_dict()

    for key in src_dict:

        if key in dst_dict and src_dict[key].shape == dst_dict[key].shape:
            # copy the values directly if the shape is the same
            dst_dict[key] = src_dict[key]  

        elif key.startswith('conv1') and 'conv1' in dst_dict:
            # special handling for the first layer
            src_conv_weight = src_dict[key]
            dst_conv_weight = dst_dict['conv1.weight']
            dst_dict['conv1.weight'] = src_conv_weight[:, :3]
            src_bn = src_dict['bn1']
            dst_bn = dst_dict['bn1']
            copy_bn(src_bn, dst_bn)

        elif key.startswith('bn') and '2' in key:
            # handle BN layer at stride=2
            src_bn = src_dict[key]
            dst_bn = dst_dict[key.replace('2', '3')]
            copy_bn(src_bn, dst_bn)

    dst_net.load_state_dict(dst_dict)

以上代码定义了一个ConvBlock,包括一个卷积层、一个FrozenBN层和一个ReLU激活层。当从另一个模型中恢复参数时,因为模型中的BN层维度不同,我们需要将源模型中的所有BN层替换为我们定义的FrozenBN层。此时,只需在BN层后添加一个hook,在每次前向传播时,自动进行参数拷贝即可。

import torchvision

model1 = torchvision.models.resnet18(pretrained=True)
model2 = torchvision.models.resnet18()

for i, (model1_module, model2_module) in enumerate(zip(model1.modules(), model2.modules())):
    if isinstance(model1_module, nn.BatchNorm2d):
        model2_module_new = FrozenBN(model1_module.num_features)
        model2_module_new.set_freeze(freeze=True)
        model2_module.register_forward_pre_hook(model2_module_new.forward_pre_hook)
model2.eval()

# copy the parameter of model1 to model2
load_pretrained_model(model1, model2)

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch固定BN层参数的操作 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • Python实现构建一个仪表板的示例代码

    Python实现构建一个仪表板的示例代码可以通过以下步骤实现: 1. 安装必要的库 为了构建一个仪表板,我们需要使用一些Python库。最常用的库是Dash,它是一个基于Flask和ReactJS的Python Web框架。使用Dash,可以轻松地构建数据可视化仪表板。Dash需要配合Plotly和Pandas等其他库一起使用。 !pip install d…

    人工智能概论 2023年5月25日
    00
  • redis 限制内存使用大小的实现

    Redis是一个使用内存作为数据存储方式的高性能key-value数据库。由于内存资源的限制,设置使用Redis时需要对其进行一定的内存限制,以避免Redis使用过多内存导致服务器宕机。 下面将详细讲解Redis限制内存使用大小的实现攻略。 使用maxmemory配置项 Redis提供了maxmemory配置项,用于设置Redis所使用的内存上限。该配置项的…

    人工智能概览 2023年5月25日
    00
  • SpringBoot轻松整合MongoDB的全过程记录

    SpringBoot轻松整合MongoDB的全过程记录 简介 MongoDB是一个NoSQL数据库,以文档形式储存数据。Spring Boot作为一个快速开发框架,可以轻松整合MongoDB数据库。本文将介绍如何使用Spring Boot轻松地整合MongoDB。 步骤 步骤1:添加Maven依赖 在pom.xml文件中添加以下依赖: <depende…

    人工智能概论 2023年5月25日
    00
  • Django使用rest_framework写出API

    下面是关于“Django使用rest_framework写出API”的完整攻略。 1. 安装Django和rest_framework 在开始使用Django中的rest_framework库编写API之前,需要安装Django和rest_framework库,我们可以通过以下命令进行安装: pip install django pip install dj…

    人工智能概论 2023年5月25日
    00
  • Apache如何部署django项目

    下面是 Apache 如何部署 Django 项目的完整攻略: 一、在 Apache 中配置 mod_wsgi 模块 Apache 是一款广泛使用的 Web 服务器,而 mod_wsgi 是一款可以在 Apache 上运行 Python 代码的模块。因此,为了部署 Django 项目,我们首先需要在 Apache 中配置 mod_wsgi 模块。 安装 mo…

    人工智能概览 2023年5月25日
    00
  • django的settings中设置中文支持的实现

    当我们使用 Django 开发网站时,如果需要支持中文,需要在 Django 的 settings.py 文件中进行相应的配置。下面是实现中文支持的具体步骤: 在 Django 项目的 settings.py 文件中,找到 LANGUAGE_CODE 和 TIME_ZONE 两个选项,分别设置成你需要的语言和时区。比如: “` LANGUAGE_CODE …

    人工智能概览 2023年5月25日
    00
  • python中的随机数种子seed()用法说明

    Python中的随机数种子seed()用法说明 什么是随机数种子 在计算机科学中,随机数生成算法是一种用于生成随机数的算法,这个过程也被称为随机数生成器。随机数生成器的输入被称为“种子”,产生的输出被成为随机数。 随机数、伪随机数生成器产生随机或伪随机数字序列的质量取决于选择种子(输入)。如果使用相同的种子调用随机数生成器两次,它将会产生相同的数字序列。 一…

    人工智能概览 2023年5月25日
    00
  • Mongodb3.0.5 副本集搭建及spring和java连接副本集配置详细介绍

    Mongodb3.0.5 副本集搭建及spring和java连接副本集配置详细介绍: 搭建副本集 准备工作 在三台服务器上安装 MongoDB,建议都使用相同的版本 为每台服务器创建并开放 MongoDB 的端口(默认端口为 27017) 配置每台服务器的主机名并添加到 /etc/hosts 文件中,例如: 192.168.1.101 mongo1 192.…

    人工智能概论 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部