Pytorch官方FlashAttention速度测试

2024-04-11 09:12

本文主要是介绍Pytorch官方FlashAttention速度测试,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在Pytorch的2.2版本更新文档中,官方重点强调了通过实现FlashAtteneion-v2实现了对scaled_dot_product_attention约2X左右的加速。
在这里插入图片描述
今天抽空亲自试了下,看看加速效果是否如官方所说。测试前需要将Pytorch的版本更新到2.2及以上,下面是测试代码,一个是原始手写的Self-Attention的实现,一个是使用Pytorch官方的scaled_dot_product_attention接口:

import time
import torch
import torch.nn.functional as Fdef main():repeat = 100device = torch.device("cuda:0")dtype = torch.float16query = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)key = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)value = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)scale_factor = 0.125ori_time_list = []for _ in range(repeat):torch.cuda.synchronize(device=device)time_start = time.perf_counter()# 原始Self-Attention实现res = torch.softmax(query @ key.transpose(-2, -1) * scale_factor, dim=-1) @ valuetorch.cuda.synchronize(device=device)time_end = time.perf_counter()ori_time_list.append(time_end - time_start)fa_time_list = []for _ in range(repeat):torch.cuda.synchronize(device=device)time_start = time.perf_counter()with torch.backends.cuda.sdp_kernel(enable_math=False):# 使用Pytorch官方提供的FA实现res_fa = F.scaled_dot_product_attention(query, key, value, scale=scale_factor)torch.cuda.synchronize(device=device)time_end = time.perf_counter()fa_time_list.append(time_end - time_start)diff = (res - res_fa).abs().max()ratio = [ori_time_list[i] / fa_time_list[i] for i in range(repeat)]avg_ratio = sum(ratio[1:]) / len(ratio[1:])print(f"max diff: {diff}")print(f"avg speed up ratio: {avg_ratio}")if __name__ == '__main__':main()

执行以上代码,终端输出如下:

max diff: 0.00048828125
avg speed up ratio: 2.2846881043417118

这里使用的设备是RTX4070,跑了很多次发现确实加速2X左右,看来以后训练或者推理时可以考虑直接使用官方的scaled_dot_product_attention接口了。但是这里也发现了两个问题,一个是原始手写的Self-Attention的计算结果和直接调用scaled_dot_product_attention接口得到的结果差异有点大(注意,这里计算的Tensor都是FP16精度的),如果我切换到FP32精度差异会再小两个数量级。第二个问题是如果使用FP32的话实测没有明显加速,这个就很奇怪了,官方文档里并没有说专门针对FP16精度优化的。关于这两个问题,暂时猜测是环境问题,或许换个GPU硬件设备或者更新下驱动啥的就可能没有这些问题了。

这篇关于Pytorch官方FlashAttention速度测试的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/893656

相关文章

Mybatis官方生成器的使用方式

《Mybatis官方生成器的使用方式》本文详细介绍了MyBatisGenerator(MBG)的使用方法,通过实际代码示例展示了如何配置Maven插件来自动化生成MyBatis项目所需的实体类、Map... 目录1. MyBATis Generator 简介2. MyBatis Generator 的功能3

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

如何测试计算机的内存是否存在问题? 判断电脑内存故障的多种方法

《如何测试计算机的内存是否存在问题?判断电脑内存故障的多种方法》内存是电脑中非常重要的组件之一,如果内存出现故障,可能会导致电脑出现各种问题,如蓝屏、死机、程序崩溃等,如何判断内存是否出现故障呢?下... 如果你的电脑是崩溃、冻结还是不稳定,那么它的内存可能有问题。要进行检查,你可以使用Windows 11

性能测试介绍

性能测试是一种测试方法,旨在评估系统、应用程序或组件在现实场景中的性能表现和可靠性。它通常用于衡量系统在不同负载条件下的响应时间、吞吐量、资源利用率、稳定性和可扩展性等关键指标。 为什么要进行性能测试 通过性能测试,可以确定系统是否能够满足预期的性能要求,找出性能瓶颈和潜在的问题,并进行优化和调整。 发现性能瓶颈:性能测试可以帮助发现系统的性能瓶颈,即系统在高负载或高并发情况下可能出现的问题

字节面试 | 如何测试RocketMQ、RocketMQ?

字节面试:RocketMQ是怎么测试的呢? 答: 首先保证消息的消费正确、设计逆向用例,在验证消息内容为空等情况时的消费正确性; 推送大批量MQ,通过Admin控制台查看MQ消费的情况,是否出现消费假死、TPS是否正常等等问题。(上述都是临场发挥,但是RocketMQ真正的测试点,还真的需要探讨) 01 先了解RocketMQ 作为测试也是要简单了解RocketMQ。简单来说,就是一个分

活用c4d官方开发文档查询代码

当你问AI助手比如豆包,如何用python禁止掉xpresso标签时候,它会提示到 这时候要用到两个东西。https://developers.maxon.net/论坛搜索和开发文档 比如这里我就在官方找到正确的id描述 然后我就把参数标签换过来

【测试】输入正确用户名和密码,点击登录没有响应的可能性原因

目录 一、前端问题 1. 界面交互问题 2. 输入数据校验问题 二、网络问题 1. 网络连接中断 2. 代理设置问题 三、后端问题 1. 服务器故障 2. 数据库问题 3. 权限问题: 四、其他问题 1. 缓存问题 2. 第三方服务问题 3. 配置问题 一、前端问题 1. 界面交互问题 登录按钮的点击事件未正确绑定,导致点击后无法触发登录操作。 页面可能存在

业务中14个需要进行A/B测试的时刻[信息图]

在本指南中,我们将全面了解有关 A/B测试 的所有内容。 我们将介绍不同类型的A/B测试,如何有效地规划和启动测试,如何评估测试是否成功,您应该关注哪些指标,多年来我们发现的常见错误等等。 什么是A/B测试? A/B测试(有时称为“分割测试”)是一种实验类型,其中您创建两种或多种内容变体——如登录页面、电子邮件或广告——并将它们显示给不同的受众群体,以查看哪一种效果最好。 本质上,A/B测

Verybot之OpenCV应用一:安装与图像采集测试

在Verybot上安装OpenCV是很简单的,只需要执行:         sudo apt-get update         sudo apt-get install libopencv-dev         sudo apt-get install python-opencv         下面就对安装好的OpenCV进行一下测试,编写一个通过USB摄像头采

Adblock Plus官方规则Easylist China说明与反馈贴(2015.12.15)

-------------------------------特别说明--------------------------------------- 视频广告问题:因Adblock Plus的局限,存在以下现象,优酷、搜狐、17173黑屏并倒数;乐视、爱奇艺播放广告。因为这些视频网站的Flash播放器被植入了检测代码,而Adblock Plus无法修改播放器。 如需同时使用ads