pytorch SENet实现案例

yizhihongxing

SENet是一种用于图像分类的深度神经网络,它通过引入Squeeze-and-Excitation模块来增强模型的表达能力。本文将深入浅析PyTorch中SENet的实现方法,并提供两个示例说明。

1. PyTorch中SENet的实现方法

PyTorch中SENet的实现方法如下:

import torch.nn as nn
import torch.nn.functional as F

class SEBlock(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
        self.fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc1(y)
        y = self.relu(y)
        y = self.fc2(y)
        y = self.sigmoid(y).view(b, c, 1, 1)
        return x * y

其中,in_channels是输入通道数,reduction_ratio是压缩比例。在SEBlock中,我们首先使用nn.AdaptiveAvgPool2d对输入进行全局平均池化,然后使用两个全连接层对特征进行压缩和扩张,最后使用nn.Sigmoid对特征进行缩放。在模型中使用SEBlock时,只需要将其作为一个子模块添加到模型中即可。

以下是一个示例代码,展示如何在PyTorch中使用SENet实现CIFAR-10数据集的分类任务:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# 定义SEBlock
class SEBlock(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
        self.fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc1(y)
        y = self.relu(y)
        y = self.fc2(y)
        y = self.sigmoid(y).view(b, c, 1, 1)
        return x * y

# 定义SENet
class SENet(nn.Module):
    def __init__(self):
        super(SENet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.se1 = SEBlock(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.se2 = SEBlock(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(256)
        self.se3 = SEBlock(256)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(512)
        self.se4 = SEBlock(512)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.se1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.se2(x)
        x = F.max_pool2d(x, 2)
        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.se3(x)
        x = self.conv4(x)
        x = self.bn4(x)
        x = F.relu(x)
        x = self.se4(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# 加载数据集
transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.RandomCrop(32, padding=4),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

# 定义模型、损失函数和优化器
net = SENet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

# 训练模型
for epoch in range(200):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('[%d] loss: %.3f' %
          (epoch + 1, running_loss / len(trainloader)))

# 测试模型
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

在上面的示例代码中,我们首先定义了一个SEBlock和一个SENet模型。然后,我们使用CIFAR-10数据集对模型进行训练和测试。在训练过程中,我们使用了随机水平翻转和随机裁剪等数据增强技术,并使用了带权重衰减的随机梯度下降优化器进行优化。在测试过程中,我们计算了模型在测试集上的准确率。

2. PyTorch中SENet的注意事项

在使用PyTorch中SENet时,需要注意以下几点:

  • 在SEBlock中,我们首先使用nn.AdaptiveAvgPool2d对输入进行全局平均池化,然后使用两个全连接层对特征进行压缩和扩张,最后使用nn.Sigmoid对特征进行缩放。
  • 在SENet中,我们将SEBlock作为一个子模块添加到模型中,并在模型中使用SEBlock对特征进行缩放。
  • 在训练过程中,我们可以使用数据增强技术来提高模型的泛化能力。
  • 在优化器中,我们可以使用带权重衰减的随机梯度下降优化器来防止模型过拟合。

以下是一个示例代码,展示了如何在PyTorch中使用SENet实现ImageNet数据集的分类任务:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# 定义SEBlock
class SEBlock(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
        self.fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc1(y)
        y = self.relu(y)
        y = self.fc2(y)
        y = self.sigmoid(y).view(b, c, 1, 1)
        return x * y

# 定义SENet
class SENet(nn.Module):
    def __init__(self):
        super(SENet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.se1 = SEBlock(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.se2 = SEBlock(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(256)
        self.se3 = SEBlock(256)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(512)
        self.se4 = SEBlock(512)
        self.conv5 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn5 = nn.BatchNorm2d(1024)
        self.se5 = SEBlock(1024)
        self.conv6 = nn.Conv2d(1024, 2048, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn6 = nn.BatchNorm2d(2048)
        self.se6 = SEBlock(2048)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, 1000)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = nn.ReLU(inplace=True)(x)
        x = self.se1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = nn.ReLU(inplace=True)(x)
        x = self.se2(x)
        x = nn.MaxPool2d(kernel_size=2, stride=2)(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = nn.ReLU(inplace=True)(x)
        x = self.se3(x)
        x = self.conv4(x)
        x = self.bn4(x)
        x = nn.ReLU(inplace=True)(x)
        x = self.se4(x)
        x = nn.MaxPool2d(kernel_size=2, stride=2)(x)
        x = self.conv5(x)
        x = self.bn5(x)
        x = nn.ReLU(inplace=True)(x)
        x = self.se5(x)
        x = self.conv6(x)
        x = self.bn6(x)
        x = nn.ReLU(inplace=True)(x)
        x = self.se6(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# 加载数据集
transform = transforms.Compose(
    [transforms.RandomResizedCrop(224),
     transforms.RandomHorizontalFlip(),
     transforms.ToTensor(),
     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
trainset = torchvision.datasets.ImageNet(root='./data', split='train',
                                         download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256,
                                          shuffle=True, num_workers=8)
testset = torchvision.datasets.ImageNet(root='./data', split='val',
                                        download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=256,
                                         shuffle=False, num_workers=8)

# 定义模型、损失函数和优化器
net = SENet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

# 训练模型
for epoch in range(90):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('[%d] loss: %.3f' %
          (epoch + 1, running_loss / len(trainloader)))

# 测试模型
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 50000 test images: %d %%' % (
    100 * correct / total))

在上面的示例代码中,我们首先定义了一个SEBlock和一个SENet模型。然后,我们使用ImageNet数据集对模型进行训练和测试。在训练过程中,我们使用了随机裁剪和随机水平翻转等数据增强技术,并使用了带权重衰减的随机梯度下降优化器进行优化。在测试过程中,我们计算了模型在测试集上的准确率。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch SENet实现案例 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • PyTorch教程【二】Python编辑器的选择、安装及配置(PyCharm、Jupyter)

    详细步骤参考博客:PyCharm安装教程 二、PyCharm环境配置 可参考博客:在Pycharm中设置Anaconda环境(不完全一样) 三、PyCharm实用功能 Python Console 四、Jupyter的安装 安装了Anaconda后,默认里面就安装了Jupyter。安装Anaconda的方法可参考博客:Anaconda的安装 五、在新环境中安…

    PyTorch 2023年4月7日
    00
  • pytorch踩坑记

    因为我有数学物理背景,所以清楚卷积的原理。但是在看pytorch文档的时候感到非常头大,罗列的公式以及各种令人眩晕的下标让入门新手不知所云…最初我以为torch.nn.conv1d的参数in_channel/out_channel表示图像的通道数,经过运行错误提示之后,才知道[in_channel,kernel_size]构成了卷积核。  loss函数中…

    2023年4月6日
    00
  • Mac中PyCharm配置Anaconda环境的方法

    在Mac中,可以使用PyCharm配置Anaconda环境,以便在开发Python应用程序时使用Anaconda提供的库和工具。本文提供一个完整的攻略,以帮助您配置Anaconda环境。 步骤1:安装Anaconda 在这个示例中,我们将使用Anaconda3作为Python环境。您可以从Anaconda官网下载适用于Mac的Anaconda3安装程序,并按…

    PyTorch 2023年5月15日
    00
  • 取出预训练模型中间层的输出(pytorch)

    1 遍历子模块直接提取 对于简单的模型,可以采用直接遍历子模块的方法,取出相应name模块的输出,不对模型做任何改动。该方法的缺点在于,只能得到其子模块的输出,而对于使用nn.Sequensial()中包含很多层的模型,无法获得其指定层的输出。 示例 resnet18取出layer1的输出 from torchvision.models import res…

    2023年4月5日
    00
  • notMNIST 数据集pyTorch分类

    简介 notMNIST数据集 是于2011公布的,可以认为是MNIST数据集地一个加强版本。数据集包含了从A到J十个字母,由large与small两个子集组成。其中samll数据集是经过手工清理的,包含19k个图片,误分类率越为0.5%,large数据集是未经过手工清理的,包含500k张图片,误分类率约为6.5%。 作者推荐在large数据集上训练网络,在s…

    PyTorch 2023年4月6日
    00
  • 在Pytorch中使用Mask R-CNN进行实例分割操作

    在PyTorch中使用Mask R-CNN进行实例分割操作的完整攻略如下,包括两个示例说明。 1. 示例1:使用预训练模型进行实例分割 在PyTorch中,可以使用预训练的Mask R-CNN模型进行实例分割操作。以下是使用预训练模型进行实例分割的步骤: 安装必要的库 python !pip install torch torchvision !pip in…

    PyTorch 2023年5月15日
    00
  • 使用pytorch加载并读取COCO数据集的详细操作

    COCO(Common Objects in Context)数据集是一个广泛使用的计算机视觉数据集,其中包含超过33万张图像和超过200万个标注。在本文中,我们将介绍如何使用PyTorch加载并读取COCO数据集。 步骤1:下载COCO数据集 首先,我们需要从COCO数据集的官方网站下载数据集。可以从以下链接下载: COCO 2017 Train imag…

    PyTorch 2023年5月15日
    00
  • pytorch 学习–60分钟入个门

    pytorch视频教程 标量(Scalar)是只有大小,没有方向的量,如1,2,3等向量(Vector)是有大小和方向的量,其实就是一串数字,如(1,2)矩阵(Matrix)是好几个向量拍成一排合并而成的一堆数字,如[1,2;3,4]其实标量,向量,矩阵它们三个也是张量,标量是零维的张量,向量是一维的张量,矩阵是二维的张量。 简单相加 a+b torch.a…

    PyTorch 2023年4月8日
    00
合作推广
合作推广
分享本页
返回顶部