本文主要是介绍TensorFlow下反卷积(Deconvolution)的实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
简介反卷积的过程
如图所示是一个卷积的过程 假设输入如下:
[[1,0,1],[0,2,1],[1,1,0]]
反卷积卷积核如下:
[[ 1, 0, 1],[-1, 1, 0],[ 0,-1, 0]]
现在通过stride=2
来进行反卷积,使得尺寸由原来的3*3
变为6*6
.那么在Tensorflow框架中,反卷积的过程如下(不同框架在裁剪这步可能不一样):
其实通过我绘制的这张图,就已经把原理讲的很清楚了。大致步奏就是,先填充0,然后进行卷积,卷积过程跟上一篇文章讲述的一致。最后一步还要进行裁剪。好了,原理讲完了,(#.#)....
代码
上一篇文章我们只针对了输出通道数为1进行代码实现,在这篇文章中,反卷积我们将输出通道设置为多个,这样更符合实际场景。
先定义输入和卷积核:
input_data=[ [[1,0,1], [0,2,1], [1,1,0]], [[2,0,2], [0,1,0], [1,0,0]], [[1,1,1], [2,2,0], [1,1,1]], [[1,1,2], [1,0,1], [0,2,2]] ] weights_data=[ [[[ 1, 0, 1], [-1, 1, 0], [ 0,-1, 0]], [[-1, 0, 1], [ 0, 0, 1], [ 1, 1, 1]], [[ 0, 1, 1], [ 2, 0, 1], [ 1, 2, 1]], [[ 1, 1, 1], [ 0, 2, 1], [ 1, 0, 1]]], [[[ 1, 0, 2], [-2, 1, 1], [ 1,-1, 0]], [[-1, 0, 1], [-1, 2, 1], [ 1, 1, 1]], [[ 0, 0, 0], [ 2, 2, 1], [ 1,-1, 1]], [[ 2, 1, 1], [ 0,-1, 1], [ 1, 1, 1]]] ]
上面定义的输入和卷积核,在接下的运算过程如下图所示:
可以看到实际上,反卷积和卷积基本一致,差别在于,反卷积需要填充过程,并在最后一步需要裁剪。具体实现代码如下:
def compute_conv(fm,kernel): [h,w]=fm.shape[k,_]=kernel.shaper=int(k/2) #定义边界填充0后的map padding_fm=np.zeros([h+2,w+2],np.float32) #保存计算结果 rs=np.zeros([h,w],np.float32) #将输入在指定该区域赋值,即除了4个边界后,剩下的区域 padding_fm[1:h+1,1:w+1]=fm#对每个点为中心的区域遍历 for i in range(1,h+1): for j in range(1,w+1): #取出当前点为中心的k*k区域 roi=padding_fm[i-r:i+r+1,j-r:j+r+1] #计算当前点的卷积,对k*k个点点乘后求和 rs[i-1][j-1]=np.sum(roi*kernel) return rs #填充0 def fill_zeros(input): [c,h,w]=input.shapers=np.zeros([c,h*2+1,w*2+1],np.float32) for i in range(c): for j in range(h): for k in range(w): rs[i,2*j+1,2*k+1]=input[i,j,k] return rsdef my_deconv(input,weights): #weights shape=[out_c,in_c,h,w] [out_c,in_c,h,w]=weights.shapeout_h=h*2 out_w=w*2 rs=[] for i in range(out_c): w=weights[i] tmp=np.zeros([out_h,out_w],np.float32) for j in range(in_c): conv=compute_conv(input[j],w[j]) #注意裁剪,最后一行和最后一列去掉 tmp=tmp+conv[0:out_h,0:out_w] rs.append(tmp) return rs def main(): input=np.asarray(input_data,np.float32) input= fill_zeros(input) weights=np.asarray(weights_data,np.float32) deconv=my_deconv(input,weights) print(np.asarray(deconv)) if __name__=='__main__': main()
计算卷积代码,跟上一篇文章一致。代码直接看注释,不再解释。运行结果如下:
为了验证实现的代码的正确性,我们使用tensorflow的conv2d_transpose函数执行相同的输入和卷积核,看看结果是否一致。验证代码如下:
def tf_conv2d_transpose(input, weights): input_shape = input.get_shape().as_list() weights_shape = weights.get_shape().as_list() output_shape = [input_shape[0], input_shape[1] * 2, input_shape[2] * 2, weights_shape[2]] print("output_shape:", output_shape) deconv = tf.nn.conv2d_transpose(input, weights, output_shape=output_shape,strides=[1, 2, 2, 1], padding='SAME') return deconvdef main(): weights_np = np.asarray(weights_data, np.float32) # 将输入的每个卷积核旋转180° weights_np = np.rot90(weights_np, 2, (2, 3)) const_input = tf.constant(input_data, tf.float32) const_weights = tf.constant(weights_np, tf.float32) input = tf.Variable(const_input, name="input") # [c,h,w]------>[h,w,c] input = tf.transpose(input, perm=(1, 2, 0)) # [h,w,c]------>[n,h,w,c] input = tf.expand_dims(input, 0) # weights shape=[out_c,in_c,h,w] weights = tf.Variable(const_weights, name="weights") # [out_c,in_c,h,w]------>[h,w,out_c,in_c] weights = tf.transpose(weights, perm=(2, 3, 0, 1)) # 执行tensorflow的反卷积 deconv = tf_conv2d_transpose(input, weights) init = tf.global_variables_initializer() sess = tf.Session() sess.run(init) deconv_val = sess.run(deconv) hwc = deconv_val[0] print(hwc) if __name__ == '__main__': main()
上面代码中,有几点需要注意:
- 每个卷积核需要旋转180°后,再传入tf.nn.conv2d_transpose函数中,因为tf.nn.conv2d_transpose内部会旋转180°,所以提前旋转,再经过内部旋转后,能保证卷积核跟我们所使用的卷积核的数据排列一致。
- 我们定义的输入的shape为[c,h,w]需要转为tensorflow所使用的[n,h,w,c]。
- 我们定义的卷积核shape为[out_c,in_c,h,w],需要转为tensorflow反卷积中所使用的[h,w,out_c,in_c]
执行上面代码后,执行结果如下:
[[[[ 4. 4.][ 3. 1.][ 6. 7.][ 2. 0.][ 7. 7.][ 3. 2.]][[ 4. 5.][ 3. 6.][ 3. 0.][ 2. 1.][ 7. 8.][ 5. 5.]][[ 8. 8.][ 6. 0.][ 8. 8.][ 5. -2.][11. 14.][ 2. 2.]][[ 3. 3.][ 2. 3.][ 7. 9.][ 2. 8.][ 3. 1.][ 3. 0.]][[ 5. 3.][ 5. 0.][11. 13.][ 3. 0.][ 9. 11.][ 3. 2.]][[ 2. 3.][ 1. 5.][ 4. 3.][ 5. 1.][ 4. 3.][ 4. 0.]]]]
对比结果可以看到,数据是一致的,这个只是按照两个矩阵的对应位置的元素进行输出,(如[4,4]代表第一个矩阵的第一个元素和第二个矩阵的第一个元素)证明前面手写的python实现的反卷积代码是正确的
这篇关于TensorFlow下反卷积(Deconvolution)的实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!