pytorch torch.nn.AdaptiveAvgPool2d()自适应平均池化函数详解

PyTorch torch.nn.AdaptiveAvgPool2d() 自适应池化函数详解

池化操作简介

在深度学习的卷积神经网络(Convolutional Neural Network,CNN)中,池化操作是常用的一种非线性操作,用于缩小特征图尺寸和提取主要特征。

普通池化操作

普通池化操作,也称为固定池化(Fixed Pooling),是指一种对特征图按照固定大小裁剪并执行固定步幅下采样的操作。普通池化操作一般有 Max PoolingMean Pooling 两种,其中 Max Pooling 取所有像素点中最大值,而 Mean Pooling 取像素点的平均值。

自适应池化操作

相比于固定池化操作,自适应池化操作则可动态地针对特征图的大小进行调整,以适应不同大小的输入特征图和不同的网络结构。

PyTorch 中的自适应池化函数 torch.nn.AdaptiveAvgPool2d() 即可实现自适应平均池化操作。

torch.nn.AdaptiveAvgPool2d()

介绍

torch.nn.AdaptiveAvgPool2d(output_size)

自适应平均池化操作,它会对输入的 feature map 进行自适应的裁剪与池化,选择输出大小为 output_size 的 feature map。

具体地,该函数的输入为一个四维的 Tensor,维度分别为 (batch_size, C, H, W),其中 batch_size 表示输入数据的 batch 大小,C 表示输入数据的通道数,HW 分别为输入数据的高度和宽度。

该函数的输出也是一个四维的 Tensor,其维度同样为 (batch_size, C, H, W),其中 HWoutput_size所指定的值。

参数

  • output_size (tuple[int]) – 输出的特征图的大小,格式为 (height, width) ,也可以是单个 int 值,表示输出特征图高度和宽度的相同长度。

用法示例1

我们来看一个使用 AdaptiveAvgPool2d() 的示例,这个示例代码用于实现任意大小输入图像的分类模型。

import torch.nn as nn
import torch

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1,1)),
        )
        self.fc = nn.Linear(32, 10)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

net = CNN()
x = torch.randn(1, 3, 32, 32)
y = net(x)
print(y.shape)

在上面的示例代码中,首先使用了一组卷积、池化来提取特征,最后使用了自适应池化层将得到的特征图转换成向量,再通过全连接层完成最终的分类任务。由于自适应池化层的输入大小可以是任意的,所以无论输入图像的大小是多少,都可以适应。

用法示例2

下面我们再来看一个使用 AdaptiveAvgPool2d() 的示例,这个示例代码用于提取图像的全局特征。

import torch
import torch.nn as nn

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1)

class Model(nn.Module):
    def __init__(self, input_shape=(3, 64, 64), output_dim=10):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
        )

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.flatten = Flatten()
        self.fc = nn.Linear(256, output_dim)

    def forward(self, x):
        z = self.features(x)
        z = self.avgpool(z)
        z = self.flatten(z)
        logits = self.fc(z)
        return logits

net = Model()
x = torch.randn(1, 3, 64, 64)
y = net(x)
print(y.shape)

在上面的示例代码中,首先使用一组卷积层提取图像的特征,随后使用 nn.AdaptiveAvgPool2d(output_size=1) 自适应平均池化层将其转换为全局特征。最后通过全连接层完成最终的分类任务。

总结

torch.nn.AdaptiveAvgPool2d() 函数是 PyTorch 提供的一个自适应平均池化函数,可以根据输入数据的大小动态地进行特征图的裁剪和池化,适用于各种大小的输入特征图和不同的网络结构。可以用于构建各种基于卷积神经网络模型的分类、分割、检测等任务。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch torch.nn.AdaptiveAvgPool2d()自适应平均池化函数详解 - Python技术站

(0)
上一篇 2023年6月13日
下一篇 2023年6月13日

相关文章

  • Python 生成器yield原理及用法

    当我们在编写 Python 程序时,如果需要对大量数据进行处理,一般会考虑使用迭代器。但是,如果我们使用列表等数据结构作为迭代器,会面临一些问题,如占用过多的内存资源等。这时,Python 提供了生成器可以解决这些问题。 生成器是一种特殊的迭代器,可以通过函数来实现,使用 yield 关键字实现迭代器的功能,并且在使用时能够节省大量的内存资源。下面依次讲解 …

    python 2023年6月13日
    00
  • pandas中fillna()函数填充NaN和None的实现

    在pandas中,fillna()函数被广泛用于填充数据中存在的NaN或None值,以便能够更方便地进行数据分析和处理。下面是该函数的详细攻略和两条示例说明。 1. 基本语法 DataFrame.fillna(value=None, method=None, axis=None, inplace=False, limit=None, downcast=Non…

    python 2023年6月13日
    00
  • 如何在向量化NumPy数组上进行移动窗口

    在NumPy中使用移动窗口是常见的数据处理操作。移动窗口可以用于计算滑动平均值、滑动方差及其他一些统计量。在NumPy中,执行这些计算的最有效的方法之一是向量化。 下面是如何在向量化NumPy数组上进行移动窗口的完整攻略: 准备数据 首先,我们需要准备要进行移动窗口计算的数据。我们可以使用rand函数生成一组随机数据。 import numpy as np …

    python 2023年6月13日
    00
  • 浅谈pandas中shift和diff函数关系

    浅谈pandas中shift和diff函数关系 简介 在Pandas中,shift和diff两个函数都是用于时间序列数据分析的常用函数,它们具有不同的作用。在本文中,我们将会详细讲解这两个函数,并说明它们之间的关系。 shift函数 shift函数用于将时间序列数据沿着时间轴移动指定的时间步长,可以用来计算相邻时间点之间的差异,或者用于实现滑动窗口操作等功能…

    python 2023年6月13日
    00
  • NodeJS使用Range请求实现下载功能的方法示例

    标题:NodeJS使用Range请求实现下载功能的方法示例 简介 NodeJS是一个基于事件驱动的异步I/O框架,可以轻松地实现文件的读写操作。在本文中,我们将介绍如何使用NodeJS的Range请求实现文件的分块下载功能。该功能可以使得下载大文件时更加快速且可靠,并且用户可以暂停和继续下载,而无需重新下载整个文件。 实现方法 文件分块下载通常是通过在HTT…

    python 2023年6月13日
    00
  • 详解matplotlib中pyplot和面向对象两种绘图模式之间的关系

    详解matplotlib中pyplot和面向对象两种绘图模式之间的关系 matplotlib绘图模式 matplotlib是Python进行数据可视化的重要库之一。在matplotlib中,数据可视化都是通过绘制图形来完成的,而绘制图形的方式则有两种:pyplot和面向对象两种方式。在pyplot方式下,我们可以直接调用函数来绘制出所需的图形,而在面向对象方…

    python 2023年6月13日
    00
  • Python趣味挑战之用pygame实现简单的金币旋转效果

    Python趣味挑战之用pygame实现简单的金币旋转效果教程如下: 课程介绍 Python是一个强大的编程语言,可以用于开发各种应用程序,包括图形界面、游戏、网站等。而pygame是一个基于Python的多媒体库,专门用于开发2D游戏。在这个课程中,我们将会用pygame实现简单的金币旋转效果,让你学会如何用Python和pygame开发2D游戏。 环境准…

    python 2023年6月13日
    00
  • 教你用Python matplotlib库制作简单的动画

    下面是关于“教你用Python matplotlib库制作简单的动画”的完整攻略: 1. 简介 matplotlib是Python中常用的绘图库,除了静态的图形外,它还可以制作动画效果。利用动画,我们可以更好的展示数据或者进行数据故事化呈现。 2. 准备工作 2.1 安装matplotlib 在开始前,需要确保你已经安装好了matplotlib库。如果没有安…

    python 2023年6月13日
    00
合作推广
合作推广
分享本页
返回顶部