本文主要是介绍triton之语法学习,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
一 基本语法
1 torch中tensor的声明
x = torch.tensor([[1,2, 1, 1, 1, 1, 1, 1],[2,2,2,2,2,2,2,2]],device='cuda')
二 triton中函数
1 sum
output = tl.sum(x,axis = 0)
如果输入是torch中的声明的话,则输出为
可以看出这个加,并未是一个reduce的过程
三 使用流程
1 数据划分
主要分为三个步骤
1.1 获取pid
row = tl.program_id(axis=0)
1.2 按照pid得到每个pid的数据起点
start = row * stride
stride对应的是输入数据的每一行的数据量
1.3 得到每个pid处理的数据范围
cols = start + tl.arange(0,8)
需要注意的是arange里面的一定要是常数,不能是函数传进来的参数 ,实用的传入方式如下所示
grid = lambda meta: (triton.cdiv(n_el
这篇关于triton之语法学习的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!