# Paper: https://arxiv.org/pdf/2007.04242.pdf
#
# GitHub: https://github.com/zhuogege1943/dgc/
#
# 动态分组卷积 (Dynamic Group Convolution):
# "Dynamic Group Convolution for Accelerating Convolutional Neural Networks"

 

from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F

class DynamicMultiHeadConv(nn.Module):
    """Multi-head convolution whose heads each gate the input channels.

    The input is batch-normalised and passed through ReLU, then handed to
    ``heads`` independent ``HeadConv`` modules. Each head produces its own
    channel-masked copy of the input; the copies are concatenated and
    convolved in a single grouped convolution (one group per head) using the
    heads' own convolution weights. The output channels of the heads are
    finally interleaved.

    Args:
        in_channels: number of input channels; must be divisible by ``heads``.
        out_channels: number of output channels; must be divisible by
            ``heads`` (each head produces ``out_channels // heads``).
        kernel_size: kernel size forwarded to every head's convolution.
        stride: convolution stride.
        padding: convolution padding.
        dilation: convolution dilation.
        heads: number of heads.
        squeeze_rate: reduction rate forwarded to every head.
        gate_factor: value in ``(.., 1.0]``; once ``global_progress`` reaches
            3/4, ``round(in_channels * (1 - gate_factor))`` input channels
            are deactivated per head.

    Raises:
        AssertionError: if the channel counts are not divisible by ``heads``
            or ``gate_factor`` is greater than 1.
    """

    # Shared by all instances and read in forward() while training.
    # NOTE(review): presumably the training loop sets this to the fraction of
    # training completed (0.0 .. 1.0) -- confirm against the caller.
    global_progress = 0.0

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, heads=4, squeeze_rate=16, gate_factor=0.25):
        super(DynamicMultiHeadConv, self).__init__()
        self.norm = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.squeeze_rate = squeeze_rate
        self.gate_factor = gate_factor
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.is_pruned = True
        # Stored as a buffer so the current number of deactivated channels is
        # saved and restored together with the module's state_dict.
        self.register_buffer('_inactive_channels', torch.zeros(1))

        ### Check if arguments are valid
        assert self.in_channels % self.heads == 0, \
            "head number can not be divided by input channels"
        assert self.out_channels % self.heads == 0, \
            "head number can not be divided by output channels"
        assert self.gate_factor <= 1.0, "gate factor is greater than 1"

        # Heads are registered under 'headconv_<i>' so that parameter names in
        # the state_dict stay stable.
        for i in range(self.heads):
            setattr(self, 'headconv_%1d' % i,
                    HeadConv(in_channels, out_channels // self.heads, squeeze_rate,
                             kernel_size, stride, padding, dilation, 1, gate_factor))

    def _head(self, index):
        """Return the ``HeadConv`` submodule registered for head ``index``."""
        return getattr(self, 'headconv_%1d' % index)

    def forward(self, x):
        """Apply the gated multi-head convolution to ``x``.

        The code here is just a coarse implementation.
        The forward process can be quite slow and memory consuming, need to be optimized.

        Returns:
            A two-element list ``[out, lasso_loss]`` where ``out`` is the
            convolution result and ``lasso_loss`` is the sum of the heads'
            mean gate values.
        """
        if self.training:
            progress = DynamicMultiHeadConv.global_progress
            # gradually deactivate input channels: nothing changes up to
            # progress 1/12, a linear ramp follows until 3/4, after which the
            # full target is used.
            if 1.0 / 12 < progress < 3.0 / 4:
                self.inactive_channels = round(
                    self.in_channels * (1 - self.gate_factor) * 3.0 / 2 * (progress - 1.0 / 12))
            elif progress >= 3.0 / 4:
                self.inactive_channels = round(self.in_channels * (1 - self.gate_factor))

        # Read the buffer once instead of once per head.
        inactive = self.inactive_channels

        x = self.norm(x)
        x = self.relu(x)
        x_averaged = self.avg_pool(x)

        _lasso_loss = 0.0
        x_mask = []
        weight = []
        for i in range(self.heads):
            head = self._head(i)
            i_x, i_lasso_loss = head(x, x_averaged, inactive)
            x_mask.append(i_x)
            weight.append(head.conv.weight)
            _lasso_loss = _lasso_loss + i_lasso_loss

        x_mask = torch.cat(x_mask, dim=1)  # batch_size, heads x C_in, H, W
        weight = torch.cat(weight, dim=0)  # out_channels, C_in, k, k

        # One group per head: every head convolves its own masked copy of x.
        out = F.conv2d(x_mask, weight, None, self.stride,
                       self.padding, self.dilation, self.heads)

        # Interleave the heads' output channels.
        b, c, h, w = out.size()
        out = out.view(b, self.heads, c // self.heads, h, w)
        out = out.transpose(1, 2).contiguous().view(b, c, h, w)
        return [out, _lasso_loss]

    @property
    def inactive_channels(self):
        """Number of input channels currently deactivated, as an ``int``."""
        return int(self._inactive_channels[0])

    @inactive_channels.setter
    def inactive_channels(self, val):
        # Written in place so the registered buffer object is preserved.
        self._inactive_channels.fill_(val)

class HeadConv(nn.Module):
    """One head: a convolution plus a small gating network over input channels.

    The gating network is two linear layers (``fc1`` -> ReLU -> ``fc2`` ->
    ReLU) that map the per-channel averages of the input to one non-negative
    gate value per input channel. ``forward`` only applies the gate; the
    convolution ``self.conv`` is held here but is not called in this class.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels of ``self.conv``.
        squeeze_rate: reduction rate of the hidden gating layer; it is halved
            when ``in_channels < 80``.
        kernel_size, stride, padding, dilation: forwarded to ``nn.Conv2d``.
        groups: accepted for signature compatibility but not used; the
            convolution is always built with ``groups=1``.
        gate_factor: stored as ``target_pruning_rate``.
    """

    def __init__(self, in_channels, out_channels, squeeze_rate, kernel_size, stride=1,
            padding=0, dilation=1, groups=1, gate_factor=0.25):
        super(HeadConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                              padding, dilation, groups=1, bias=False)
        self.target_pruning_rate = gate_factor

        # Narrow inputs get a milder squeeze so the hidden layer stays wider.
        reduction = squeeze_rate // 2 if in_channels < 80 else squeeze_rate
        hidden = in_channels // reduction

        self.fc1 = nn.Linear(in_channels, hidden, bias=False)
        self.relu_fc1 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(hidden, in_channels, bias=True)
        self.relu_fc2 = nn.ReLU(inplace=True)

        for layer in (self.fc1, self.fc2):
            nn.init.kaiming_normal_(layer.weight)
        # Bias of the last gating layer starts at 1.0.
        nn.init.constant_(self.fc2.bias, 1.0)

    def forward(self, x, x_averaged, inactive_channels):
        """Scale the channels of ``x`` by the gate and zero the weakest ones.

        Args:
            x: 4-D input; its first two dimensions are used as ``(b, c)``.
            x_averaged: tensor that can be viewed as ``(b, c)``.
            inactive_channels: number of lowest-gate channels to zero out per
                sample; ``0`` disables the zeroing.

        Returns:
            Tuple ``(gated_x, lasso_loss)`` where ``lasso_loss`` is the mean
            of the gate values before any channel is zeroed.
        """
        batch, channels, _, _ = x.size()

        hidden = self.relu_fc1(self.fc1(x_averaged.view(batch, channels)))
        gate = self.relu_fc2(self.fc2(hidden))  # b, c
        lasso = gate.mean()

        if inactive_channels > 0:
            # The selection is done on a detached copy, so choosing which
            # channels to drop contributes no gradient.
            scores = gate.detach()
            lowest, _ = scores.topk(inactive_channels, dim=1, largest=False, sorted=False)
            threshold, _ = lowest.max(dim=1, keepdim=True)
            # Every gate at or below the per-sample threshold is dropped.
            gate = gate.masked_fill(scores.le(threshold), 0)

        gated = x * gate.view(batch, channels, 1, 1)
        return gated, lasso


class Conv(nn.Sequential):
    """BatchNorm2d -> ReLU -> Conv2d, applied in that order.

    The three stages are registered under the names ``norm``, ``relu`` and
    ``conv``. The convolution has no bias term.

    Args:
        in_channels: channels of the input (also the BatchNorm feature count).
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        groups: number of convolution groups.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, groups=1):
        super(Conv, self).__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride,
                            padding=padding, bias=False, groups=groups)
        # Registration order defines the execution order of nn.Sequential.
        stages = (
            ('norm', nn.BatchNorm2d(in_channels)),
            ('relu', nn.ReLU(inplace=True)),
            ('conv', nn.Conv2d(in_channels, out_channels, **conv_options)),
        )
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)