tvm学习笔记(八):卷积操作

2023-12-23 10:48
文章标签 学习 操作 笔记 卷积 tvm

本文主要是介绍tvm学习笔记(八):卷积操作,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

对于卷积神经网络中,卷积操作可能是最常见操作,具体原理可以去学习一下Andred NG的课程,建议搞计算机视觉方向的都去刷一波,具体过程如图1所示:

         

图1 'VALID'方式卷积操作过程

其实就是卷积核与图像待操作区域进行乘加操作,常见的卷积操作有两种形式,第一种是'VALID'的方式,如图1所示,第二种是'SAME'的方式,区别在于'SAME'方式会对输入进行填充,以保证卷积操作之后,输出的size和输入的size一致。

图2 'SAME'方式卷积操作过程

1、padding

先说一下填充padding,padding就是在原始图像四周填充0,对应于图2中虚线部分,使用tvm实现,代码如下:

def padding(X, ph, pw):assert len(X.shape) >= 2nh, nw = X.shape[-2], X.shape[-1]return tvm.compute((*X.shape[0:-2], nh + ph * 2, nw + pw * 2),lambda *i: tvm.if_then_else(tvm.any(i[-2] < ph, i[-2] >= nh + ph, i[-1] < pw, i[-1] >= nw + pw),0, X[i[:-2] + (i[-2] - ph, i[-1] - pw)]), name = 'PaddedX')

2、输出feature map尺寸计算

对于 输入size为n, 卷积核size为k, 填充size为p,卷积操作步长size为s,输出大小为:

o=floor(\frac{n-k+2*p}{s})+1

对应代码如下:

def conv_out_size(n, k, p, s):return (n - k + 2 * p) // s + 1

3、卷积操作

就是将卷积核与要操作的图像块进行乘加操作,对应于tvm代码为:

def conv(oc, ic, nh, nw, kh, kw, ph=0, pw=0, sh=1, sw=1):# reduction axesric = tvm.reduce_axis((0, ic), name='ric')rkh = tvm.reduce_axis((0, kh), name='rkh')rkw = tvm.reduce_axis((0, kw), name='rkw')# output height and widthoh = conv_out_size(nh, kh, ph, sh)ow = conv_out_size(nw, kw, pw, sw)# pad x and then conpute yX = tvm.placeholder((ic, nh, nw), name='x')K = tvm.placeholder((oc, ic, kh, kw), name='k')# 对输入填充PaddedX = padding(X, ph, pw) if ph * pw != 0 else XY = tvm.compute((oc, oh, ow),lambda c, i, j: tvm.sum(PaddedX[ric, i * sh + rkh, j * sw + rkw] * K[c, ric, rkh, rkw],axis=[ric, rkh, rkw]), name='Y')return X, K, Y, PaddedX

最后,看一下实际生成的伪代码:

import tvm
import numpy as np
import mxnet as mxdef padding(X, ph, pw):assert len(X.shape) >= 2nh, nw = X.shape[-2], X.shape[-1]return tvm.compute((*X.shape[0:-2], nh + ph * 2, nw + pw * 2),lambda *i: tvm.if_then_else(tvm.any(i[-2] < ph, i[-2] >= nh + ph, i[-1] < pw, i[-1] >= nw + pw),0, X[i[:-2] + (i[-2] - ph, i[-1] - pw)]), name = 'PaddedX')# 输入size:n
# 卷积核size:k
# 填充size:p
# 步长size:s
def conv_out_size(n, k, p, s):return (n - k + 2 * p) // s + 1def conv(oc, ic, nh, nw, kh, kw, ph=0, pw=0, sh=1, sw=1):# reduction axesric = tvm.reduce_axis((0, ic), name='ric')rkh = tvm.reduce_axis((0, kh), name='rkh')rkw = tvm.reduce_axis((0, kw), name='rkw')# output height and widthoh = conv_out_size(nh, kh, ph, sh)ow = conv_out_size(nw, kw, pw, sw)# pad x and then conpute yX = tvm.placeholder((ic, nh, nw), name='x')K = tvm.placeholder((oc, ic, kh, kw), name='k')# 对输入填充PaddedX = padding(X, ph, pw) if ph * pw != 0 else XY = tvm.compute((oc, oh, ow),lambda c, i, j: tvm.sum(PaddedX[ric, i * sh + rkh, j * sw + rkw] * K[c, ric, rkh, rkw],axis=[ric, rkh, rkw]), name='Y')return X, K, Y, PaddedXdef get_conv_data(oc, ic, n, k, p=0, s=1, constructor=None):np.random.seed(0)data = np.random.normal(size=(ic, n, n)).astype('float32')weight = np.random.normal(size=(oc, ic, k, k)).astype('float32')on = conv_out_size(n, k, p, s)out = np.empty((oc, on, on), dtype='float32')if constructor:data, weight, out = (constructor(x) for x in [data, weight, out])return data, weight, outoc, ic, n, k, p, s = 4, 6, 12, 3, 1, 1
X, K, Y, _ = conv(oc, ic, n, n, k, k, p, p, s, s)
sch = tvm.create_schedule(Y.op)
mod = tvm.build(sch, [X, K, Y])
print(tvm.lower(sch, [X, K, Y], simple_mode=True))data, weight, out = get_conv_data(oc, ic, n, k, p, s, tvm.nd.array)
mod(data, weight, out)def get_conv_data_mxnet(oc, ic, n, k, p, s, ctx='cpu'):ctx = getattr(mx, ctx)()data, weight, out = get_conv_data(oc, ic, n, k, p, s,lambda x: mx.nd.array(x, ctx=ctx))data, out = data.expand_dims(axis=0), out.expand_dims(axis=0)bias = mx.nd.zeros(out.shape[1], ctx=ctx)return data, weight, bias, outdef conv_mxnet(data, weight, bias, out, k, p, s):mx.nd.Convolution(data, weight, bias, kernel=(k, k), stride=(s, s),pad=(p, p), num_filter=out.shape[1], out=out)data, weight, bias, out_mx = get_conv_data_mxnet(oc, ic, n, k, p, s)
conv_mxnet(data, weight, bias, out_mx, k, p, s)
np.testing.assert_allclose(out_mx[0].asnumpy(), out.asnumpy(), atol=1e-5)

输出为:

// attr [PaddedX] storage_scope = "global"
allocate PaddedX[float32 * 1176]
produce PaddedX {for (i0, 0, 6) {for (i1, 0, 14) {for (i2, 0, 14) {PaddedX[(((i0*196) + (i1*14)) + i2)] = tvm_if_then_else(((((i1 < 1) |
| (13 <= i1)) || (i2 < 1)) || (13 <= i2)), 0f, x[((((i0*144) + (i1*12)) + i2) - 13)])      }}}
}
produce Y {for (c, 0, 4) {for (i, 0, 12) {for (j, 0, 12) {Y[(((c*144) + (i*12)) + j)] = 0ffor (ric, 0, 6) {for (rkh, 0, 3) {for (rkw, 0, 3) {Y[(((c*144) + (i*12)) + j)] = (Y[(((c*144) + (i*12)) + j)] + (P
addedX[(((((ric*196) + (i*14)) + (rkh*14)) + j) + rkw)]*k[((((c*54) + (ric*9)) + (rkh*3)) + rkw)]))            }}}}}}
}

 

参考资料:

[1] https://blog.csdn.net/kingroc/article/details/88192878

[2] http://tvm.d2l.ai.s3-website-us-west-2.amazonaws.com/

这篇关于tvm学习笔记(八):卷积操作的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/527758

相关文章

Python调用Orator ORM进行数据库操作

《Python调用OratorORM进行数据库操作》OratorORM是一个功能丰富且灵活的PythonORM库,旨在简化数据库操作,它支持多种数据库并提供了简洁且直观的API,下面我们就... 目录Orator ORM 主要特点安装使用示例总结Orator ORM 是一个功能丰富且灵活的 python O

python使用fastapi实现多语言国际化的操作指南

《python使用fastapi实现多语言国际化的操作指南》本文介绍了使用Python和FastAPI实现多语言国际化的操作指南,包括多语言架构技术栈、翻译管理、前端本地化、语言切换机制以及常见陷阱和... 目录多语言国际化实现指南项目多语言架构技术栈目录结构翻译工作流1. 翻译数据存储2. 翻译生成脚本

0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型的操作流程

《0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeekR1模型的操作流程》DeepSeekR1模型凭借其强大的自然语言处理能力,在未来具有广阔的应用前景,有望在多个领域发... 目录0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型,3步搞定一个应

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

轻松上手MYSQL之JSON函数实现高效数据查询与操作

《轻松上手MYSQL之JSON函数实现高效数据查询与操作》:本文主要介绍轻松上手MYSQL之JSON函数实现高效数据查询与操作的相关资料,MySQL提供了多个JSON函数,用于处理和查询JSON数... 目录一、jsON_EXTRACT 提取指定数据二、JSON_UNQUOTE 取消双引号三、JSON_KE

C++实现封装的顺序表的操作与实践

《C++实现封装的顺序表的操作与实践》在程序设计中,顺序表是一种常见的线性数据结构,通常用于存储具有固定顺序的元素,与链表不同,顺序表中的元素是连续存储的,因此访问速度较快,但插入和删除操作的效率可能... 目录一、顺序表的基本概念二、顺序表类的设计1. 顺序表类的成员变量2. 构造函数和析构函数三、顺序表

使用C++实现单链表的操作与实践

《使用C++实现单链表的操作与实践》在程序设计中,链表是一种常见的数据结构,特别是在动态数据管理、频繁插入和删除元素的场景中,链表相比于数组,具有更高的灵活性和高效性,尤其是在需要频繁修改数据结构的应... 目录一、单链表的基本概念二、单链表类的设计1. 节点的定义2. 链表的类定义三、单链表的操作实现四、

Python利用自带模块实现屏幕像素高效操作

《Python利用自带模块实现屏幕像素高效操作》这篇文章主要为大家详细介绍了Python如何利用自带模块实现屏幕像素高效操作,文中的示例代码讲解详,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1、获取屏幕放缩比例2、获取屏幕指定坐标处像素颜色3、一个简单的使用案例4、总结1、获取屏幕放缩比例from

通过prometheus监控Tomcat运行状态的操作流程

《通过prometheus监控Tomcat运行状态的操作流程》文章介绍了如何安装和配置Tomcat,并使用Prometheus和TomcatExporter来监控Tomcat的运行状态,文章详细讲解了... 目录Tomcat安装配置以及prometheus监控Tomcat一. 安装并配置tomcat1、安装

Python中操作Redis的常用方法小结

《Python中操作Redis的常用方法小结》这篇文章主要为大家详细介绍了Python中操作Redis的常用方法,文中的示例代码简洁易懂,具有一定的借鉴价值,有需要的小伙伴可以了解一下... 目录安装Redis开启、关闭Redisredis数据结构redis-cli操作安装redis-py数据库连接和释放增