TensorFlow下反卷积(Deconvolution)的实现

2024-03-01 03:58

本文主要是介绍TensorFlow下反卷积(Deconvolution)的实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

简介反卷积的过程

如图所示是一个卷积的过程
那么反卷积就应该是一个逆向的过程

假设输入如下:

[[1,0,1],[0,2,1],[1,1,0]]

反卷积卷积核如下:

[[ 1, 0, 1],[-1, 1, 0],[ 0,-1, 0]]

现在通过stride=2来进行反卷积,使得尺寸由原来的3*3变为6*6.那么在Tensorflow框架中,反卷积的过程如下(不同框架在裁剪这步可能不一样):

反卷积实现例子

其实通过我绘制的这张图,就已经把原理讲的很清楚了。大致步奏就是,先填充0,然后进行卷积,卷积过程跟上一篇文章讲述的一致。最后一步还要进行裁剪。好了,原理讲完了,(#.#)....

代码

上一篇文章我们只针对了输出通道数为1进行代码实现,在这篇文章中,反卷积我们将输出通道设置为多个,这样更符合实际场景。

先定义输入和卷积核:

input_data=[
               [[1,0,1],
                [0,2,1],
                [1,1,0]],

               [[2,0,2],
                [0,1,0],
                [1,0,0]],

               [[1,1,1],
                [2,2,0],
                [1,1,1]],

               [[1,1,2],
                [1,0,1],
                [0,2,2]]

            ]
weights_data=[
                [[[ 1, 0, 1],
                  [-1, 1, 0],
                  [ 0,-1, 0]],

                 [[-1, 0, 1],
                  [ 0, 0, 1],
                  [ 1, 1, 1]],

                 [[ 0, 1, 1],
                  [ 2, 0, 1],
                  [ 1, 2, 1]],

                 [[ 1, 1, 1],
                  [ 0, 2, 1],
                  [ 1, 0, 1]]],

                [[[ 1, 0, 2],
                  [-2, 1, 1],
                  [ 1,-1, 0]],

                 [[-1, 0, 1],
                  [-1, 2, 1],
                  [ 1, 1, 1]],

                 [[ 0, 0, 0],
                  [ 2, 2, 1],
                  [ 1,-1, 1]],

                 [[ 2, 1, 1],
                  [ 0,-1, 1],
                  [ 1, 1, 1]]] ]

上面定义的输入和卷积核,在接下的运算过程如下图所示:

执行过程

可以看到实际上,反卷积和卷积基本一致,差别在于,反卷积需要填充过程,并在最后一步需要裁剪。具体实现代码如下:

def compute_conv(fm,kernel):
    [h,w]=fm.shape[k,_]=kernel.shaper=int(k/2)
    #定义边界填充0后的map
    padding_fm=np.zeros([h+2,w+2],np.float32)
    #保存计算结果
    rs=np.zeros([h,w],np.float32)
    #将输入在指定该区域赋值,即除了4个边界后,剩下的区域
    padding_fm[1:h+1,1:w+1]=fm#对每个点为中心的区域遍历
    for i in range(1,h+1):
        for j in range(1,w+1):
            #取出当前点为中心的k*k区域
            roi=padding_fm[i-r:i+r+1,j-r:j+r+1]
            #计算当前点的卷积,对k*k个点点乘后求和
            rs[i-1][j-1]=np.sum(roi*kernel)
    return rs #填充0
def fill_zeros(input):
    [c,h,w]=input.shapers=np.zeros([c,h*2+1,w*2+1],np.float32)
    for i in range(c):
        for j in range(h):
            for k in range(w):
                rs[i,2*j+1,2*k+1]=input[i,j,k]
    return rsdef my_deconv(input,weights): #weights shape=[out_c,in_c,h,w]
    [out_c,in_c,h,w]=weights.shapeout_h=h*2
    out_w=w*2
    rs=[]
    for i in range(out_c):
        w=weights[i]
        tmp=np.zeros([out_h,out_w],np.float32)
        for j in range(in_c):
            conv=compute_conv(input[j],w[j]) #注意裁剪,最后一行和最后一列去掉
            tmp=tmp+conv[0:out_h,0:out_w]
        rs.append(tmp)
    return rs
def main():
    input=np.asarray(input_data,np.float32)
    input= fill_zeros(input)
    weights=np.asarray(weights_data,np.float32)
    deconv=my_deconv(input,weights)
    print(np.asarray(deconv))
if __name__=='__main__':
    main()

计算卷积代码,跟上一篇文章一致。代码直接看注释,不再解释。运行结果如下:

为了验证实现的代码的正确性,我们使用tensorflow的conv2d_transpose函数执行相同的输入和卷积核,看看结果是否一致。验证代码如下:

def tf_conv2d_transpose(input, weights):
    input_shape = input.get_shape().as_list()
    weights_shape = weights.get_shape().as_list()
    output_shape = [input_shape[0], input_shape[1] * 2, input_shape[2] * 2, weights_shape[2]]
    print("output_shape:", output_shape)

    deconv = tf.nn.conv2d_transpose(input, weights, output_shape=output_shape,strides=[1, 2, 2, 1], padding='SAME')
    return deconvdef main():
    weights_np = np.asarray(weights_data, np.float32)
    # 将输入的每个卷积核旋转180°
    weights_np = np.rot90(weights_np, 2, (2, 3))

    const_input = tf.constant(input_data, tf.float32)
    const_weights = tf.constant(weights_np, tf.float32)

    input = tf.Variable(const_input, name="input")
    # [c,h,w]------>[h,w,c]
    input = tf.transpose(input, perm=(1, 2, 0))
    # [h,w,c]------>[n,h,w,c]
    input = tf.expand_dims(input, 0)

    # weights shape=[out_c,in_c,h,w]
    weights = tf.Variable(const_weights, name="weights")
    # [out_c,in_c,h,w]------>[h,w,out_c,in_c]
    weights = tf.transpose(weights, perm=(2, 3, 0, 1))

    # 执行tensorflow的反卷积
    deconv = tf_conv2d_transpose(input, weights)

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)

    deconv_val = sess.run(deconv)

    hwc = deconv_val[0]
    print(hwc)


if __name__ == '__main__':
    main()

上面代码中,有几点需要注意:

  1. 每个卷积核需要旋转180°后,再传入tf.nn.conv2d_transpose函数中,因为tf.nn.conv2d_transpose内部会旋转180°,所以提前旋转,再经过内部旋转后,能保证卷积核跟我们所使用的卷积核的数据排列一致。
  2. 我们定义的输入的shape为[c,h,w]需要转为tensorflow所使用的[n,h,w,c]。
  3. 我们定义的卷积核shape为[out_c,in_c,h,w],需要转为tensorflow反卷积中所使用的[h,w,out_c,in_c]

执行上面代码后,执行结果如下:

[[[[ 4.  4.][ 3.  1.][ 6.  7.][ 2.  0.][ 7.  7.][ 3.  2.]][[ 4.  5.][ 3.  6.][ 3.  0.][ 2.  1.][ 7.  8.][ 5.  5.]][[ 8.  8.][ 6.  0.][ 8.  8.][ 5. -2.][11. 14.][ 2.  2.]][[ 3.  3.][ 2.  3.][ 7.  9.][ 2.  8.][ 3.  1.][ 3.  0.]][[ 5.  3.][ 5.  0.][11. 13.][ 3.  0.][ 9. 11.][ 3.  2.]][[ 2.  3.][ 1.  5.][ 4.  3.][ 5.  1.][ 4.  3.][ 4.  0.]]]]

对比结果可以看到,数据是一致的,这个只是按照两个矩阵的对应位置的元素进行输出,(如[4,4]代表第一个矩阵的第一个元素和第二个矩阵的第一个元素)证明前面手写的python实现的反卷积代码是正确的

这篇关于TensorFlow下反卷积(Deconvolution)的实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/761141

相关文章

关于集合与数组转换实现方法

《关于集合与数组转换实现方法》:本文主要介绍关于集合与数组转换实现方法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、Arrays.asList()1.1、方法作用1.2、内部实现1.3、修改元素的影响1.4、注意事项2、list.toArray()2.1、方

使用Python实现可恢复式多线程下载器

《使用Python实现可恢复式多线程下载器》在数字时代,大文件下载已成为日常操作,本文将手把手教你用Python打造专业级下载器,实现断点续传,多线程加速,速度限制等功能,感兴趣的小伙伴可以了解下... 目录一、智能续传:从崩溃边缘抢救进度二、多线程加速:榨干网络带宽三、速度控制:做网络的好邻居四、终端交互

java实现docker镜像上传到harbor仓库的方式

《java实现docker镜像上传到harbor仓库的方式》:本文主要介绍java实现docker镜像上传到harbor仓库的方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录1. 前 言2. 编写工具类2.1 引入依赖包2.2 使用当前服务器的docker环境推送镜像2.2

C++20管道运算符的实现示例

《C++20管道运算符的实现示例》本文简要介绍C++20管道运算符的使用与实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录标准库的管道运算符使用自己实现类似的管道运算符我们不打算介绍太多,因为它实际属于c++20最为重要的

Java easyExcel实现导入多sheet的Excel

《JavaeasyExcel实现导入多sheet的Excel》这篇文章主要为大家详细介绍了如何使用JavaeasyExcel实现导入多sheet的Excel,文中的示例代码讲解详细,感兴趣的小伙伴可... 目录1.官网2.Excel样式3.代码1.官网easyExcel官网2.Excel样式3.代码

python实现对数据公钥加密与私钥解密

《python实现对数据公钥加密与私钥解密》这篇文章主要为大家详细介绍了如何使用python实现对数据公钥加密与私钥解密,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录公钥私钥的生成使用公钥加密使用私钥解密公钥私钥的生成这一部分,使用python生成公钥与私钥,然后保存在两个文

浏览器插件cursor实现自动注册、续杯的详细过程

《浏览器插件cursor实现自动注册、续杯的详细过程》Cursor简易注册助手脚本通过自动化邮箱填写和验证码获取流程,大大简化了Cursor的注册过程,它不仅提高了注册效率,还通过友好的用户界面和详细... 目录前言功能概述使用方法安装脚本使用流程邮箱输入页面验证码页面实战演示技术实现核心功能实现1. 随机

Golang如何对cron进行二次封装实现指定时间执行定时任务

《Golang如何对cron进行二次封装实现指定时间执行定时任务》:本文主要介绍Golang如何对cron进行二次封装实现指定时间执行定时任务问题,具有很好的参考价值,希望对大家有所帮助,如有错误... 目录背景cron库下载代码示例【1】结构体定义【2】定时任务开启【3】使用示例【4】控制台输出总结背景

Golang如何用gorm实现分页的功能

《Golang如何用gorm实现分页的功能》:本文主要介绍Golang如何用gorm实现分页的功能方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录背景go库下载初始化数据【1】建表【2】插入数据【3】查看数据4、代码示例【1】gorm结构体定义【2】分页结构体

在Golang中实现定时任务的几种高效方法

《在Golang中实现定时任务的几种高效方法》本文将详细介绍在Golang中实现定时任务的几种高效方法,包括time包中的Ticker和Timer、第三方库cron的使用,以及基于channel和go... 目录背景介绍目的和范围预期读者文档结构概述术语表核心概念与联系故事引入核心概念解释核心概念之间的关系