From self-attention 2 flash-attention 数学原理与 cuda 实现优化

2024-06-09 07:44

本文主要是介绍From self-attention 2 flash-attention 数学原理与 cuda 实现优化,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

self attension 是transformer 编码器和解码器中共同的一个计算环节,在整个transformer 网络体系中耗费的算力比例占主导。所以节省self attention 的正向和反向的计算时间,就可以加速 transormer 的训练和推理过程。

1,self attention 的数学提炼

两个矩阵乘法,加入一个列向的softmax

input   矩阵: \mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbf{R}^{N \times d}

output 矩阵:\mathbf{O} \in \mathbf{R}^{N \times d}

 

\mathbf{self\ attention\ algorithm:}

        step1:        \mathbf{S} = \mathbf{Q}*\mathbf{K}^t

        step2:        \mathbf{P} = \mathbf{softmax_{column}(S)}

        step3:        \mathbf{O} = \mathbf{P}*\mathbf{V}

2,cpu 实现self attention

这里的数据类型使用了 float,实际网络中一般采用 fp16,数学过程是相同的;

cpu_self_attention.cpp

#include <stdio.h>
#include <string.h>#include "cpu_gemm.h"
#include "utils.h"
#include "soft_max.h"
//all matrices are row major.void cpu_self_attention(float* Q, int ldq,float* K, int ldk,float* V, int ldv,float* S, int lds,float* P, int ldp,float* O, int ldo,int N, int d)
{gemm_nt(Q, ldq, K, ldk, S, lds, N, N, d);// S = Q*K^t     (NxN) = (Nxd) * (dxN)printf("\nS =\n");	print_matrix(S, N, N, lds);soft_max_column(P, ldp, S, lds, N, N);// P(NxN) = softmax(S(NxN))printf("\nP =\n");	print_matrix(S, N, N, lds);gemm_nn(P, ldp, V, ldv, O, ldo, N, d, N);// O = P*V     (Nxd) = (NxN) * (Nxd)
}

cpu_gemm.cpp

#include "cpu_gemm.h"void gemm_nn(float *A, int lda,		//A(M x K) rowMjfloat *B, int ldb,		//B(K x N) rowMjfloat *C, int ldc,		//C(M x N) rowMjint M,int N,int K)
{for(int i=0; i<M; i++){for(int j=0; j<N; j++){float sigma = 0.0;for(int k=0; k<K; k++){sigma += A[i*lda + k] * B[k*ldb + j];}C[i*ldc + j] = sigma;}}
}void gemm_nt(float *A, int lda,		//A(M x K) rowMjfloat *B, int ldb,		//B(N x K) rowMjfloat *C, int ldc,		//C(M x N) rowMjint M,int N,int K)
{for(int i=0; i<M; i++){for(int j=0; j<N; j++){float sigma = 0.0;for(int k=0; k<K; k++){sigma += A[i*lda + k] * B[k + j*ldb];}C[i*ldc + j] = sigma;}}
}

cpu_softmax_column.cpp

这里使用的是未数值优化的方式,直接按照原始公式计算:

#include "soft_max.h"
void soft_max_column(float *P, int ldp, float* S, int lds, int M, int N)//P = softmax(S)  P(i,j) = exp(S(i,j))/sigma(exp(S(r,j)));  r=0,1,..,n-1 ;
{for(int j=0; j<N; j++){float sigma = 0.0f;for(int i=0; i<M; i++){sigma += exp(S[i*lds + j])}for(int i=0; i<M; i++){P[i*ldp + j] = S[i*lds + j]/sigma;}}
}

3, gpu 实现 self attention 正向

cuda 实现上述过程:

gpu_self_attention.cu

gpu_gemm.cu

gpu_softmax_column.cu

4,为什么不需要gpu 实现self attention 反向

融合上述过程

5, gpu 实现 flash attention 反向

融合算子

数学原理

cuda 实现

挖坑,未完待续 。。。

这篇关于From self-attention 2 flash-attention 数学原理与 cuda 实现优化的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1044578

相关文章

python生成随机唯一id的几种实现方法

《python生成随机唯一id的几种实现方法》在Python中生成随机唯一ID有多种方法,根据不同的需求场景可以选择最适合的方案,文中通过示例代码介绍的非常详细,需要的朋友们下面随着小编来一起学习学习... 目录方法 1:使用 UUID 模块(推荐)方法 2:使用 Secrets 模块(安全敏感场景)方法

MySQL深分页进行性能优化的常见方法

《MySQL深分页进行性能优化的常见方法》在Web应用中,分页查询是数据库操作中的常见需求,然而,在面对大型数据集时,深分页(deeppagination)却成为了性能优化的一个挑战,在本文中,我们将... 目录引言:深分页,真的只是“翻页慢”那么简单吗?一、背景介绍二、深分页的性能问题三、业务场景分析四、

Spring StateMachine实现状态机使用示例详解

《SpringStateMachine实现状态机使用示例详解》本文介绍SpringStateMachine实现状态机的步骤,包括依赖导入、枚举定义、状态转移规则配置、上下文管理及服务调用示例,重点解... 目录什么是状态机使用示例什么是状态机状态机是计算机科学中的​​核心建模工具​​,用于描述对象在其生命

Spring Boot 结合 WxJava 实现文章上传微信公众号草稿箱与群发

《SpringBoot结合WxJava实现文章上传微信公众号草稿箱与群发》本文将详细介绍如何使用SpringBoot框架结合WxJava开发工具包,实现文章上传到微信公众号草稿箱以及群发功能,... 目录一、项目环境准备1.1 开发环境1.2 微信公众号准备二、Spring Boot 项目搭建2.1 创建

Linux进程CPU绑定优化与实践过程

《Linux进程CPU绑定优化与实践过程》Linux支持进程绑定至特定CPU核心,通过sched_setaffinity系统调用和taskset工具实现,优化缓存效率与上下文切换,提升多核计算性能,适... 目录1. 多核处理器及并行计算概念1.1 多核处理器架构概述1.2 并行计算的含义及重要性1.3 并

IntelliJ IDEA2025创建SpringBoot项目的实现步骤

《IntelliJIDEA2025创建SpringBoot项目的实现步骤》本文主要介绍了IntelliJIDEA2025创建SpringBoot项目的实现步骤,文中通过示例代码介绍的非常详细,对大家... 目录一、创建 Spring Boot 项目1. 新建项目2. 基础配置3. 选择依赖4. 生成项目5.

Linux下删除乱码文件和目录的实现方式

《Linux下删除乱码文件和目录的实现方式》:本文主要介绍Linux下删除乱码文件和目录的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux下删除乱码文件和目录方法1方法2总结Linux下删除乱码文件和目录方法1使用ls -i命令找到文件或目录

SpringBoot+EasyExcel实现自定义复杂样式导入导出

《SpringBoot+EasyExcel实现自定义复杂样式导入导出》这篇文章主要为大家详细介绍了SpringBoot如何结果EasyExcel实现自定义复杂样式导入导出功能,文中的示例代码讲解详细,... 目录安装处理自定义导出复杂场景1、列不固定,动态列2、动态下拉3、自定义锁定行/列,添加密码4、合并

mybatis执行insert返回id实现详解

《mybatis执行insert返回id实现详解》MyBatis插入操作默认返回受影响行数,需通过useGeneratedKeys+keyProperty或selectKey获取主键ID,确保主键为自... 目录 两种方式获取自增 ID:1. ​​useGeneratedKeys+keyProperty(推

Spring Boot集成Druid实现数据源管理与监控的详细步骤

《SpringBoot集成Druid实现数据源管理与监控的详细步骤》本文介绍如何在SpringBoot项目中集成Druid数据库连接池,包括环境搭建、Maven依赖配置、SpringBoot配置文件... 目录1. 引言1.1 环境准备1.2 Druid介绍2. 配置Druid连接池3. 查看Druid监控