SENet是一种用于图像分类的深度神经网络,它通过引入Squeeze-and-Excitation模块来增强模型的表达能力。本文将深入浅析PyTorch中SENet的实现方法,并提供两个示例说明。
1. PyTorch中SENet的实现方法
PyTorch中SENet的实现方法如下:
import torch.nn as nn
import torch.nn.functional as F
class SEBlock(nn.Module):
def __init__(self, in_channels, reduction_ratio=16):
super(SEBlock, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
self.fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
self.relu = nn.ReLU(inplace=True)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc1(y)
y = self.relu(y)
y = self.fc2(y)
y = self.sigmoid(y).view(b, c, 1, 1)
return x * y
其中,in_channels
是输入通道数,reduction_ratio
是压缩比例。在SEBlock中,我们首先使用nn.AdaptiveAvgPool2d
对输入进行全局平均池化,然后使用两个全连接层对特征进行压缩和扩张,最后使用nn.Sigmoid
对特征进行缩放。在模型中使用SEBlock时,只需要将其作为一个子模块添加到模型中即可。
以下是一个示例代码,展示如何在PyTorch中使用SENet实现CIFAR-10数据集的分类任务:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
# 定义SEBlock
class SEBlock(nn.Module):
def __init__(self, in_channels, reduction_ratio=16):
super(SEBlock, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
self.fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
self.relu = nn.ReLU(inplace=True)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc1(y)
y = self.relu(y)
y = self.fc2(y)
y = self.sigmoid(y).view(b, c, 1, 1)
return x * y
# 定义SENet
class SENet(nn.Module):
def __init__(self):
super(SENet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.se1 = SEBlock(64)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(128)
self.se2 = SEBlock(128)
self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(256)
self.se3 = SEBlock(256)
self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=False)
self.bn4 = nn.BatchNorm2d(512)
self.se4 = SEBlock(512)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(512, 10)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.se1(x)
x = self.conv2(x)
x = self.bn2(x)
x = F.relu(x)
x = self.se2(x)
x = F.max_pool2d(x, 2)
x = self.conv3(x)
x = self.bn3(x)
x = F.relu(x)
x = self.se3(x)
x = self.conv4(x)
x = self.bn4(x)
x = F.relu(x)
x = self.se4(x)
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
# 加载数据集
transform = transforms.Compose(
[transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
shuffle=False, num_workers=2)
# 定义模型、损失函数和优化器
net = SENet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# 训练模型
for epoch in range(200):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print('[%d] loss: %.3f' %
(epoch + 1, running_loss / len(trainloader)))
# 测试模型
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))
在上面的示例代码中,我们首先定义了一个SEBlock和一个SENet模型。然后,我们使用CIFAR-10数据集对模型进行训练和测试。在训练过程中,我们使用了随机水平翻转和随机裁剪等数据增强技术,并使用了带权重衰减的随机梯度下降优化器进行优化。在测试过程中,我们计算了模型在测试集上的准确率。
2. PyTorch中SENet的注意事项
在使用PyTorch中SENet时,需要注意以下几点:
- 在SEBlock中,我们首先使用
nn.AdaptiveAvgPool2d
对输入进行全局平均池化,然后使用两个全连接层对特征进行压缩和扩张,最后使用nn.Sigmoid
对特征进行缩放。 - 在SENet中,我们将SEBlock作为一个子模块添加到模型中,并在模型中使用SEBlock对特征进行缩放。
- 在训练过程中,我们可以使用数据增强技术来提高模型的泛化能力。
- 在优化器中,我们可以使用带权重衰减的随机梯度下降优化器来防止模型过拟合。
以下是一个示例代码,展示了如何在PyTorch中使用SENet实现ImageNet数据集的分类任务:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
# 定义SEBlock
class SEBlock(nn.Module):
def __init__(self, in_channels, reduction_ratio=16):
super(SEBlock, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
self.fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
self.relu = nn.ReLU(inplace=True)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc1(y)
y = self.relu(y)
y = self.fc2(y)
y = self.sigmoid(y).view(b, c, 1, 1)
return x * y
# 定义SENet
class SENet(nn.Module):
def __init__(self):
super(SENet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.se1 = SEBlock(64)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(128)
self.se2 = SEBlock(128)
self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(256)
self.se3 = SEBlock(256)
self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=False)
self.bn4 = nn.BatchNorm2d(512)
self.se4 = SEBlock(512)
self.conv5 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1, bias=False)
self.bn5 = nn.BatchNorm2d(1024)
self.se5 = SEBlock(1024)
self.conv6 = nn.Conv2d(1024, 2048, kernel_size=3, stride=1, padding=1, bias=False)
self.bn6 = nn.BatchNorm2d(2048)
self.se6 = SEBlock(2048)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(2048, 1000)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = nn.ReLU(inplace=True)(x)
x = self.se1(x)
x = self.conv2(x)
x = self.bn2(x)
x = nn.ReLU(inplace=True)(x)
x = self.se2(x)
x = nn.MaxPool2d(kernel_size=2, stride=2)(x)
x = self.conv3(x)
x = self.bn3(x)
x = nn.ReLU(inplace=True)(x)
x = self.se3(x)
x = self.conv4(x)
x = self.bn4(x)
x = nn.ReLU(inplace=True)(x)
x = self.se4(x)
x = nn.MaxPool2d(kernel_size=2, stride=2)(x)
x = self.conv5(x)
x = self.bn5(x)
x = nn.ReLU(inplace=True)(x)
x = self.se5(x)
x = self.conv6(x)
x = self.bn6(x)
x = nn.ReLU(inplace=True)(x)
x = self.se6(x)
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
# 加载数据集
transform = transforms.Compose(
[transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
trainset = torchvision.datasets.ImageNet(root='./data', split='train',
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256,
shuffle=True, num_workers=8)
testset = torchvision.datasets.ImageNet(root='./data', split='val',
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=256,
shuffle=False, num_workers=8)
# 定义模型、损失函数和优化器
net = SENet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
# 训练模型
for epoch in range(90):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print('[%d] loss: %.3f' %
(epoch + 1, running_loss / len(trainloader)))
# 测试模型
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 50000 test images: %d %%' % (
100 * correct / total))
在上面的示例代码中,我们首先定义了一个SEBlock和一个SENet模型。然后,我们使用ImageNet数据集对模型进行训练和测试。在训练过程中,我们使用了随机裁剪和随机水平翻转等数据增强技术,并使用了带权重衰减的随机梯度下降优化器进行优化。在测试过程中,我们计算了模型在测试集上的准确率。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch SENet实现案例 - Python技术站