关于PyTorch中如何固定BN层的参数,通常有两种方法:
- 冻结BN层
在PyTorch中,可以通过requires_grad
属性来决定一个参数是否需要被训练。为了固定BN层的参数,我们可以将其requires_grad
属性设置为False
,这样就不会更新其参数了。具体步骤如下:
import torch.nn as nn
bn_layer = nn.BatchNorm2d(3) # 创建一个BatchNorm2d层
bn_layer.eval() # 将其设为评估模式(不影响参数值,只是改变了其前向传播的行为)
for param in bn_layer.parameters():
param.requires_grad = False # 将所有参数的requires_grad属性设为False,即冻结了该层所有参数
- 使用hook
hook可以在模型中的指定位置注入自己的代码,捕获到该位置的输入,输出和参数等信息。利用hook,我们可以直接修改BN层的参数。具体步骤如下:
import torch.nn.functional as F
from torch.nn import Module, Parameter
class FrozenBN(Module):
def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True):
super(FrozenBN, self).__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
if self.affine:
self.weight = Parameter(torch.Tensor(num_features))
self.bias = Parameter(torch.Tensor(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
self.eval()
def reset_parameters(self):
self.running_mean.zero_()
self.running_var.fill_(1)
if self.affine:
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
return F.batch_norm(
x, self.running_mean, self.running_var, self.weight, self.bias,
self.training, self.momentum, self.eps)
def set_freeze(self, freeze: bool = True):
if freeze:
self.eval()
self.requires_grad_ = False
else:
self.train()
self.requires_grad_ = True
以上是两种常用的固定BN层的方法,下面展示两个具体的应用场景:
- 预训练时冻结BN层
前面提到过,冻结BN层是有利于训练的,因为在预训练阶段,一些层的权重通常已经很好了,而为每一个batch计算定量的归一化操作(BN层)会浪费很多时间。因此,在预训练过程中冻结BN层是一个很好的激进优化方法。
import torchvision
import torch.optim as optim
# load backbone network
model = torchvision.models.resnet18(pretrained=True)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 10)
# freeze all bn layer
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
m.requires_grad_ = False
# training
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss()
for epoch in range(epochs):
for i, (inputs, labels) in enumerate(train_loader):
outputs = model(inputs)
loss = loss_fn(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
- 从预训练模型中恢复BN层的参数
在使用PyTorch中的预训练模型时,有时需要将之前已经训练好的模型参数迁移到新模型中。然而,由于新模型中的BN层维度不同,因此无法直接将参数赋值过去。此时需要将参数逐个拷贝过去,但这个过程很麻烦。下面我们可以将FrozenBN层封装一下,通过hook的方式自动拷贝参数。
import torch
import torch.nn as nn
def copy_bn(src_bn, dst_bn):
dst_bn.weight.data.copy_(src_bn.weight.data)
dst_bn.bias.data.copy_(src_bn.bias.data)
dst_bn.running_mean.data.copy_(src_bn.running_mean.data)
dst_bn.running_var.data.copy_(src_bn.running_var.data)
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
super(ConvBlock, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
self.bn = FrozenBN(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
def load_pretrained_model(src_net, dst_net):
src_dict = src_net.state_dict()
dst_dict = dst_net.state_dict()
for key in src_dict:
if key in dst_dict and src_dict[key].shape == dst_dict[key].shape:
# copy the values directly if the shape is the same
dst_dict[key] = src_dict[key]
elif key.startswith('conv1') and 'conv1' in dst_dict:
# special handling for the first layer
src_conv_weight = src_dict[key]
dst_conv_weight = dst_dict['conv1.weight']
dst_dict['conv1.weight'] = src_conv_weight[:, :3]
src_bn = src_dict['bn1']
dst_bn = dst_dict['bn1']
copy_bn(src_bn, dst_bn)
elif key.startswith('bn') and '2' in key:
# handle BN layer at stride=2
src_bn = src_dict[key]
dst_bn = dst_dict[key.replace('2', '3')]
copy_bn(src_bn, dst_bn)
dst_net.load_state_dict(dst_dict)
以上代码定义了一个ConvBlock,包括一个卷积层、一个FrozenBN层和一个ReLU激活层。当从另一个模型中恢复参数时,因为模型中的BN层维度不同,我们需要将源模型中的所有BN层替换为我们定义的FrozenBN层。此时,只需在BN层后添加一个hook,在每次前向传播时,自动进行参数拷贝即可。
import torchvision
model1 = torchvision.models.resnet18(pretrained=True)
model2 = torchvision.models.resnet18()
for i, (model1_module, model2_module) in enumerate(zip(model1.modules(), model2.modules())):
if isinstance(model1_module, nn.BatchNorm2d):
model2_module_new = FrozenBN(model1_module.num_features)
model2_module_new.set_freeze(freeze=True)
model2_module.register_forward_pre_hook(model2_module_new.forward_pre_hook)
model2.eval()
# copy the parameter of model1 to model2
load_pretrained_model(model1, model2)
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch固定BN层参数的操作 - Python技术站