最近项目上需要实现直接卷积,就看相关的教程中实现的都是信号和电子领域的卷积,结果和计算机领域的不一致,原因大家可以自己搜一下,计算机图像领域的卷积其实不是真正的卷积。

其算法示意如下图所示:

直接卷积理解

相关代码参考于他人代码,但是目前找不到了,欢迎作者联系我补充。代码有所修改。

输入:imput[IC][IH][IW]
IC = input.channels
IH = input.height
IW = input.width

卷积核: kernel[KC1][KC2][KH][KW]
KC1 = OC
KC2 = IC
KH = kernel.height
KW = kernel.width

输出:output[OC][OH][OW]
OC = output.channels
OH = output.height
OW = output.width

其中,padding = VALID,stride = 1,
OH = IH - KH + 1
OW = IW - KW + 1


for(int ch=0;ch<output.channels;ch++)
{
    for(int oh=0;oh<output.height;oh++)
    {
        for(int ow=0;ow<output.width;ow++)
        {
            float sum=0;
            for(int kc=0;kc<kernel.channels;kc++)
            {
                for(int kh=0;kh<kernel.height;kh++)
                {
                    for(int kw=0;kw<kernel.width;kw++)
                    {
                        sum += input[kc][oh+kh][ow+kw]*kernel[ch][kc][kh][kw];
                    }
                }
            }
            //if(bias) sum +=bias[]
            output[ch][oh][ow]=sum;
        }
    }
}

 后边两种是我根据上边的理解,加入stride的影响,写的。欢迎参考

方案1,根据input

for(int ch=0;ch<output.channels;ch++)
{    
    for(int ih=0;ih<input.height;ih += stride_h)
    {
        for(int iw=0;iw<input.width;iw += stride_w)
        {
            float sum = 0;
            for(int kc=0;kc<kernel.channels;kc++) //kernel's channel = input_data's channel
            {
                for(int kh=0;kh<kernel.height;kh++)
                {
                    for(int kw=0;kw<kernel.width;kw++)
                    {
                        sum += input[kc][ih+kh][iw+kw]*kernel[ch][kc][kh][kw];
                    }
                }
            }
            output[ch][ih+kh/2][iw+kw/2] = sum;
        }
    }
}
for(int ch=0;ch<output.channels;ch++)
{    
    for(int oh=0;oh<output.height;oh++)
    {
        for(int ow=0;ow<output.width;ow++)
        {
            float sum = 0;
            for(int kc=0;kc<kernel.channels;kc++) //kernel's channel = input_data's channel
            {
                for(int kh=0;kh<kernel.height;kh++)
                {
                    for(int kw=0;kw<kernel.width;kw++)
                    {
                        sum += input[kc][oh*stride_h+kh][ow*stride_w+kw]*kernel[ch][kc][kh][kw];
                    }
                }
            }
            output[ch][oh][ow] += sum;
        }
        //if(bias) tem +=bias[]
    }
}