Pytorch官方FlashAttention速度测试

2024-04-11 09:12

本文主要是介绍Pytorch官方FlashAttention速度测试,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在Pytorch的2.2版本更新文档中,官方重点强调了通过实现FlashAtteneion-v2实现了对scaled_dot_product_attention约2X左右的加速。
在这里插入图片描述
今天抽空亲自试了下,看看加速效果是否如官方所说。测试前需要将Pytorch的版本更新到2.2及以上,下面是测试代码,一个是原始手写的Self-Attention的实现,一个是使用Pytorch官方的scaled_dot_product_attention接口:

import time
import torch
import torch.nn.functional as Fdef main():repeat = 100device = torch.device("cuda:0")dtype = torch.float16query = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)key = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)value = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)scale_factor = 0.125ori_time_list = []for _ in range(repeat):torch.cuda.synchronize(device=device)time_start = time.perf_counter()# 原始Self-Attention实现res = torch.softmax(query @ key.transpose(-2, -1) * scale_factor, dim=-1) @ valuetorch.cuda.synchronize(device=device)time_end = time.perf_counter()ori_time_list.append(time_end - time_start)fa_time_list = []for _ in range(repeat):torch.cuda.synchronize(device=device)time_start = time.perf_counter()with torch.backends.cuda.sdp_kernel(enable_math=False):# 使用Pytorch官方提供的FA实现res_fa = F.scaled_dot_product_attention(query, key, value, scale=scale_factor)torch.cuda.synchronize(device=device)time_end = time.perf_counter()fa_time_list.append(time_end - time_start)diff = (res - res_fa).abs().max()ratio = [ori_time_list[i] / fa_time_list[i] for i in range(repeat)]avg_ratio = sum(ratio[1:]) / len(ratio[1:])print(f"max diff: {diff}")print(f"avg speed up ratio: {avg_ratio}")if __name__ == '__main__':main()

执行以上代码,终端输出如下:

max diff: 0.00048828125
avg speed up ratio: 2.2846881043417118

这里使用的设备是RTX4070,跑了很多次发现确实加速2X左右,看来以后训练或者推理时可以考虑直接使用官方的scaled_dot_product_attention接口了。但是这里也发现了两个问题,一个是原始手写的Self-Attention的计算结果和直接调用scaled_dot_product_attention接口得到的结果差异有点大(注意,这里计算的Tensor都是FP16精度的),如果我切换到FP32精度差异会再小两个数量级。第二个问题是如果使用FP32的话实测没有明显加速,这个就很奇怪了,官方文档里并没有说专门针对FP16精度优化的。关于这两个问题,暂时猜测是环境问题,或许换个GPU硬件设备或者更新下驱动啥的就可能没有这些问题了。

这篇关于Pytorch官方FlashAttention速度测试的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/893656

相关文章

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

pytorch+torchvision+python版本对应及环境安装

《pytorch+torchvision+python版本对应及环境安装》本文主要介绍了pytorch+torchvision+python版本对应及环境安装,安装过程中需要注意Numpy版本的降级,... 目录一、版本对应二、安装命令(pip)1. 版本2. 安装全过程3. 命令相关解释参考文章一、版本对

从零教你安装pytorch并在pycharm中使用

《从零教你安装pytorch并在pycharm中使用》本文详细介绍了如何使用Anaconda包管理工具创建虚拟环境,并安装CUDA加速平台和PyTorch库,同时在PyCharm中配置和使用PyTor... 目录背景介绍安装Anaconda安装CUDA安装pytorch报错解决——fbgemm.dll连接p

pycharm远程连接服务器运行pytorch的过程详解

《pycharm远程连接服务器运行pytorch的过程详解》:本文主要介绍在Linux环境下使用Anaconda管理不同版本的Python环境,并通过PyCharm远程连接服务器来运行PyTorc... 目录linux部署pytorch背景介绍Anaconda安装Linux安装pytorch虚拟环境安装cu

SpringBoot中整合RabbitMQ(测试+部署上线最新完整)的过程

《SpringBoot中整合RabbitMQ(测试+部署上线最新完整)的过程》本文详细介绍了如何在虚拟机和宝塔面板中安装RabbitMQ,并使用Java代码实现消息的发送和接收,通过异步通讯,可以优化... 目录一、RabbitMQ安装二、启动RabbitMQ三、javascript编写Java代码1、引入

Nginx设置连接超时并进行测试的方法步骤

《Nginx设置连接超时并进行测试的方法步骤》在高并发场景下,如果客户端与服务器的连接长时间未响应,会占用大量的系统资源,影响其他正常请求的处理效率,为了解决这个问题,可以通过设置Nginx的连接... 目录设置连接超时目的操作步骤测试连接超时测试方法:总结:设置连接超时目的设置客户端与服务器之间的连接

Mybatis官方生成器的使用方式

《Mybatis官方生成器的使用方式》本文详细介绍了MyBatisGenerator(MBG)的使用方法,通过实际代码示例展示了如何配置Maven插件来自动化生成MyBatis项目所需的实体类、Map... 目录1. MyBATis Generator 简介2. MyBatis Generator 的功能3

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

如何测试计算机的内存是否存在问题? 判断电脑内存故障的多种方法

《如何测试计算机的内存是否存在问题?判断电脑内存故障的多种方法》内存是电脑中非常重要的组件之一,如果内存出现故障,可能会导致电脑出现各种问题,如蓝屏、死机、程序崩溃等,如何判断内存是否出现故障呢?下... 如果你的电脑是崩溃、冻结还是不稳定,那么它的内存可能有问题。要进行检查,你可以使用Windows 11