使用python/numpy实现im2col的学习心得

系统 1828 0

使用python/numpy实现im2col的学习心得

  • 背景
    • 书上的程序
    • 分析
      • 首先是:
      • 其次:
      • 写在最后

背景

最近在看深度学习的东西。使用的参考书是《深度学习入门——基于python的理论与实现》。在看到7.4时,里面引入了一个im2col的函数,从而方便讲不断循环进行地相乘相加操作变成矩阵的运算,通过空间资源换取时间效率。

为什么要这么操作和操作以后col矩阵的样子比较好理解。由于对python和numpy不太熟悉,理解书上给出的程序实现想了很久。终于有点感觉了,记录下来。

书上的程序

            
              
                def
              
              
                im2col
              
              
                (
              
              input_data
              
                ,
              
               filter_h
              
                ,
              
               filter_w
              
                ,
              
               stride
              
                =
              
              
                1
              
              
                ,
              
               pad
              
                =
              
              
                0
              
              
                )
              
              
                :
              
              
                """

    Parameters
    ----------
    input_data : 由(数据量, 通道, 高, 长)的4维数组构成的输入数据
    filter_h : 滤波器的高
    filter_w : 滤波器的长
    stride : 步幅
    pad : 填充

    Returns
    -------
    col : 2维数组
    """
              
              
    N
              
                ,
              
               C
              
                ,
              
               H
              
                ,
              
               W 
              
                =
              
               input_data
              
                .
              
              shape
    out_h 
              
                =
              
              
                (
              
              H 
              
                +
              
              
                2
              
              
                *
              
              pad 
              
                -
              
               filter_h
              
                )
              
              
                //
              
              stride 
              
                +
              
              
                1
              
              
    out_w 
              
                =
              
              
                (
              
              W 
              
                +
              
              
                2
              
              
                *
              
              pad 
              
                -
              
               filter_w
              
                )
              
              
                //
              
              stride 
              
                +
              
              
                1
              
              

    img 
              
                =
              
               np
              
                .
              
              pad
              
                (
              
              input_data
              
                ,
              
              
                [
              
              
                (
              
              
                0
              
              
                ,
              
              
                0
              
              
                )
              
              
                ,
              
              
                (
              
              
                0
              
              
                ,
              
              
                0
              
              
                )
              
              
                ,
              
              
                (
              
              pad
              
                ,
              
               pad
              
                )
              
              
                ,
              
              
                (
              
              pad
              
                ,
              
               pad
              
                )
              
              
                ]
              
              
                ,
              
              
                'constant'
              
              
                )
              
              
    col 
              
                =
              
               np
              
                .
              
              zeros
              
                (
              
              
                (
              
              N
              
                ,
              
               C
              
                ,
              
               filter_h
              
                ,
              
               filter_w
              
                ,
              
               out_h
              
                ,
              
               out_w
              
                )
              
              
                )
              
              
                for
              
               y 
              
                in
              
              
                range
              
              
                (
              
              filter_h
              
                )
              
              
                :
              
              
        y_max 
              
                =
              
               y 
              
                +
              
               stride
              
                *
              
              out_h
        
              
                for
              
               x 
              
                in
              
              
                range
              
              
                (
              
              filter_w
              
                )
              
              
                :
              
              
            x_max 
              
                =
              
               x 
              
                +
              
               stride
              
                *
              
              out_w
            col
              
                [
              
              
                :
              
              
                ,
              
              
                :
              
              
                ,
              
               y
              
                ,
              
               x
              
                ,
              
              
                :
              
              
                ,
              
              
                :
              
              
                ]
              
              
                =
              
               img
              
                [
              
              
                :
              
              
                ,
              
              
                :
              
              
                ,
              
               y
              
                :
              
              y_max
              
                :
              
              stride
              
                ,
              
               x
              
                :
              
              x_max
              
                :
              
              stride
              
                ]
              
              

    col 
              
                =
              
               col
              
                .
              
              transpose
              
                (
              
              
                0
              
              
                ,
              
              
                4
              
              
                ,
              
              
                5
              
              
                ,
              
              
                1
              
              
                ,
              
              
                2
              
              
                ,
              
              
                3
              
              
                )
              
              
                .
              
              reshape
              
                (
              
              N
              
                *
              
              out_h
              
                *
              
              out_w
              
                ,
              
              
                -
              
              
                1
              
              
                )
              
              
                return
              
               col

            
          

分析

首先只考虑一个数据,即此时 N = 1 N=1 N = 1 。并且假设数据只有一层,比如灰度图,即 C = 1 C=1 C = 1 。假设数据的高和长分别为4,4。即 H = 4 H=4 H = 4 , W = 4 W=4 W = 4 。滤波器的长和高分别为2,2。即 f i l t e r _ h = 2 filter\_h=2 f i l t e r _ h = 2 , f i l t e r _ w = 2 filter\_w=2 f i l t e r _ w = 2 。更进一步地,将Pad简化为0。

此时,img就是一个 4 ∗ 4 4*4 4 4 的矩阵,假设如下:
使用python/numpy实现im2col的学习心得_第1张图片
滤波器是 2 ∗ 2 2*2 2 2 的矩阵,假设为
使用python/numpy实现im2col的学习心得_第2张图片
因此,卷积层输出是 3 ∗ 3 3*3 3 3 的矩阵。

有了这些预备设定,就可以开始理解程序了。我们重点关注两句话。

首先是:

            
              col
              
                [
              
              
                :
              
              
                ,
              
              
                :
              
              
                ,
              
               y
              
                ,
              
               x
              
                ,
              
              
                :
              
              
                ,
              
              
                :
              
              
                ]
              
              
                =
              
               img
              
                [
              
              
                :
              
              
                ,
              
              
                :
              
              
                ,
              
               y
              
                :
              
              y_max
              
                :
              
              stride
              
                ,
              
               x
              
                :
              
              x_max
              
                :
              
              stride
              
                ]
              
            
          

y和x分别代表滤波器的尺寸,由于设定 N = 1 N=1 N = 1 C = 1 C=1 C = 1 。因此可以先只看后面四个维度。那么 ( y , x , : , : ) (y,x,:,:) ( y , x , : , : ) 意味着矩阵前两维和滤波器尺寸一致,即 2 ∗ 2 2*2 2 2 ,后面的两个冒号,就代表了在卷积运算(滤波)时,第y行第x列的滤波器参数,需要和img中运算的数的矩阵。
解释一下:当y=0,x=0时,对应的a的位置,如图
使用python/numpy实现im2col的学习心得_第3张图片
此时,完成整个卷积运算的时候,a分别需要做(3*3)=9次的乘法,每次做乘法是对应img中的数如下:
第一次:(绿色表示当前卷积时,每个滤波器参数对应的位置,红色表示a对应的位置)
使用python/numpy实现im2col的学习心得_第4张图片
第二次:

使用python/numpy实现im2col的学习心得_第5张图片
以此类推……
所以 ( 0 , 0 , : , : ) (0,0,:,:) ( 0 , 0 , : , : ) 中存的数如下:(黄色标注的位置),对应滤波器参数a所以进行运算的范围。
使用python/numpy实现im2col的学习心得_第6张图片
所以:

            
              img
              
                [
              
              
                :
              
              
                ,
              
              
                :
              
              
                ,
              
               y
              
                :
              
              y_max
              
                :
              
              stride
              
                ,
              
               x
              
                :
              
              x_max
              
                :
              
              stride
              
                ]
              
            
          

中,y和x分别表示filter中第几行,第几列。然后每次移动stride,直到走完img中所有的位置,抵达y_max和x_max。

这句话,以及这个for循环的作用就解释完了。

其次:

            
              col 
              
                =
              
               col
              
                .
              
              transpose
              
                (
              
              
                0
              
              
                ,
              
              
                4
              
              
                ,
              
              
                5
              
              
                ,
              
              
                1
              
              
                ,
              
              
                2
              
              
                ,
              
              
                3
              
              
                )
              
              
                .
              
              reshape
              
                (
              
              N
              
                *
              
              out_h
              
                *
              
              out_w
              
                ,
              
              
                -
              
              
                1
              
              
                )
              
            
          

这句话目的是把矩阵重新排列,最后呈现出适合进行矩阵运算来代替循环的形式。
所以,这个矩阵一定是 N ∗ o u t _ h ∗ o u t _ w N*out\_h*out\_w N o u t _ h o u t _ w 行。这里就是 3 ∗ 3 = 9 3*3=9 3 3 = 9 行。有多少列呢,肯定是滤波器系数的个数,即 2 ∗ 2 = 4 2*2=4 2 2 = 4 列。

至于transpose函数中的设置,主要是为了配合后面的reshape函数的参数。

多说一句,我觉得这里transpose不要老是想着转置,我开始也这么想,这么多维度,就转不过来弯了。
我觉得,其实transpose就是决定一个新的取数顺序,依次取出来就可以,然后能够和原来对应上,就没问题了。比如 a是一个三维的东西。然后b = a.transpose(1,2,0)。也就是说 a [ y ] [ z ] [ x ] = b [ x ] [ y ] [ z ] a[y][z][x] = b[x][y][z] a [ y ] [ z ] [ x ] = b [ x ] [ y ] [ z ]

transpose第一个参数,0,表示第0维,也就是transpose以后,第0维不变,说明即便展开,输入的img也是按顺序一个一个处理完的。

第2和3的参数,之所以放 o u t _ h 和 o u t _ w out\_h和out\_w o u t _ h o u t _ w 的大小,得明白reshape的操作方法。如果没有指定order参数,并且是默认按照C的存储格式(这里不理解可以看看reshape的参数有哪些),它是把矩阵按照从第0维开始,依次全部排列开,然后在按需求重组。所以这里,要按照 o u t _ h 和 o u t _ w out\_h和out\_w o u t _ h o u t _ w 优先顺序排列开,然后再使col总共就 N ∗ o u t _ h ∗ o u t _ w N*out\_h*out\_w N o u t _ h o u t _ w 行,那么reshpe函数会使每行中,就存储一次卷积所需要所有值,即 C ∗ f i l t e r h ∗ f i l t e r w C*filter_h*filter_w C f i l t e r h f i l t e r w 列。

后面三个参数保证顺序不变就行,方便和滤波器参数位置一一对应。

以上,总结成一句话:其实就是准确找到滤波器每个参数对应需要相乘的所有值,然后再变换一下矩阵的行状,就可以了。

写在最后

由于本人水平有限,这一点代码都想了一下午加一晚上才明白。还得继续努力了。加油!虽然整理出来了,感觉有些东西不太好表述清楚,大家有什么问题可以留言,多多交流,互相学习。


更多文章、技术交流、商务合作、联系博主

微信扫码或搜索:z360901061

微信扫一扫加我为好友

QQ号联系: 360901061

您的支持是博主写作最大的动力,如果您喜欢我的文章,感觉我的文章对您有帮助,请用微信扫描下面二维码支持博主2元、5元、10元、20元等您想捐的金额吧,狠狠点击下面给点支持吧,站长非常感激您!手机微信长按不能支付解决办法:请将微信支付二维码保存到相册,切换到微信,然后点击微信右上角扫一扫功能,选择支付二维码完成支付。

【本文对您有帮助就好】

您的支持是博主写作最大的动力,如果您喜欢我的文章,感觉我的文章对您有帮助,请用微信扫描上面二维码支持博主2元、5元、10元、自定义金额等您想捐的金额吧,站长会非常 感谢您的哦!!!

发表我的评论
最新评论 总共0条评论