奇妙之旅:SIMD加速矩阵运算

2024-06-16 22:08

本文主要是介绍奇妙之旅:SIMD加速矩阵运算,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

奇妙之旅:SIMD加速矩阵运算

  • 1.前言
  • 2.预备知识
  • 3.计算逻辑
  • 4.代码实战

1.前言

  游戏会涉及到大量4x4的矩阵乘法运算,而乘法最简单直观的实现就是循环4×4×4次乘法,以及若干次加法,得到结果。
  在计算量较少时,cpu并不是很紧张。然而游戏通常每秒伴随着大量运算,此时,计算的效率就显得尤为重要。
  通过查阅文献,我发现了一种SIMD(Single Instruction Multiple Data,单指令多数据流)技术,可以同时操作多个数据,极大的提高了cpu的运算能力。

2.预备知识

原理:

  • 以加法指令为例,单指令单数据(SISD)的CPU对加法指令译码后,执行部件先访问内存,取得第一个操作数;之后再一次访问内存,取得第二个操作数;随后才能进行求和运算。而在SIMD型的CPU中,指令译码后几个执行部件同时访问内存,一次性获得所有操作数进行运算。这个特点使SIMD特别适合于多媒体应用等数据密集型运算。

下列是一些AVX,AVX2指令集上的函数原型:
手册地址:https://software.intel.com/sites/landingpage/IntrinsicsGuide/

  • __m256 _mm256_i32gather_ps (float const* base_addr, __m256i vindex, const int scale)

    • 概要
      __m256 _mm256_i32gather_ps (float const* base_addr, __m256i vindex, const int scale)
      #include <immintrin.h>
      Instruction: vgatherdps ymm, vm32x, ymm
      CPUID Flags: AVX2

    • 描述
      Gather single-precision (32-bit) floating-point elements from memory using 32-bit indices. 32-bit elements are loaded from addresses starting at base_addr and offset by each 32-bit element in vindex (each index is scaled by the factor in scale). Gathered elements are merged into dst. scale should be 1, 2, 4 or 8.

    • 伪代码

      FOR j := 0 to 7i := j*32m := j*32addr := base_addr + SignExtend64(vindex[m+31:m]) * ZeroExtend64(scale) * 8dst[i+31:i] := MEM[addr+31:addr]
      ENDFOR
      dst[MAX:256] := 0
      
    • 口胡 解释:用于将base_addr指向的内存中的数据按vindex中放置的索引序号并乘以scale表示的字节大小收集数据。

  • __m256 _mm256_dp_ps (__m256 a, __m256 b, const int imm8)

    • 概要
      __m256 _mm256_dp_ps (__m256 a, __m256 b, const int imm8)
      #include <immintrin.h>
      Instruction: vdpps ymm, ymm, ymm, imm
      CPUID Flags: AVX

    • 描述
      Conditionally multiply the packed single-precision (32-bit) floating-point elements in a and b using the high 4 bits in imm8, sum the four products, and conditionally store the sum in dst using the low 4 bits of imm8.

    • 伪代码

      DEFINE DP(a[127:0], b[127:0], imm8[7:0]) {FOR j := 0 to 3i := j*32IF imm8[(4+j)%8]temp[i+31:i] := a[i+31:i] * b[i+31:i]ELSEtemp[i+31:i] := 0FIENDFORsum[31:0] := (temp[127:96] + temp[95:64]) + (temp[63:32] + temp[31:0])FOR j := 0 to 3i := j*32IF imm8[j%8]tmpdst[i+31:i] := sum[31:0]ELSEtmpdst[i+31:i] := 0FIENDFORRETURN tmpdst[127:0]
      }
      dst[127:0] := DP(a[127:0], b[127:0], imm8[7:0])
      dst[255:128] := DP(a[255:128], b[255:128], imm8[7:0])
      dst[MAX:256] := 0
      
    • 口胡 解释:将a,b的8个float值,分别相乘(如果imm8第4-7位为1),然后前4个求和,后4个求和,如果imm8的低0-3位为1就分配这个和值到目标区域。

好了,有了以上两个函数,我们就可以将4x4矩阵乘法的乘法次数降到8次了。
什么?你问怎么做?请往下看。

3.计算逻辑

  因为笔者的CPU只支持到AVX2,所以不能用512位的寄存器(哭~),这里就用256位寄存器做示范。
  考虑原来的乘法需要64次的float乘法,如果用256位寄存器(8个float)打包就可以降到 64 ÷ 8 = 8 64\div8 =8 64÷8=8次,太惊人了。

下面就是计算思路:
1. 记 X a b 为 矩 阵 X 的 第 a 行 和 第 b 行 构 成 的 一 个 8 个 f l o a t 值 的 寄 存 器 2. 记 X u v 为 矩 阵 X 的 第 u 列 和 第 v 列 构 成 的 一 个 8 个 f l o a t 值 的 寄 存 器 1.记X_{ab}为矩阵X的第a行和第b行构成的一个8个float值的寄存器\\ 2.记X^{uv}为矩阵X的第u列和第v列构成的一个8个float值的寄存器\\ 1.XabXab8float2.XuvXuv8float
matrix4 c;
__mm256 temp;
int mask = 0b11110001;

一次性计算两次点积
temp =_mm256_dp_ps ( A 12 , B 11 , m a s k ) ( A_{12} ,B^{11},mask) (A12,B11,mask)
c.data[0] = temp[0];
c.data[1] = temp[4];

temp =_mm256_dp_ps ( A 34 , B 11 , m a s k ) ( A_{34} ,B^{11},mask) (A34,B11,mask)
c.data[2] = temp[0];
c.data[3] = temp[4];

剩下的类似…
temp =_mm256_dp_ps ( A 12 , B 22 , m a s k ) ( A_{12} ,B^{22},mask) (A12,B22,mask)
c.data[4] = temp[0];
c.data[5] = temp[4];

temp =_mm256_dp_ps ( A 34 , B 22 , m a s k ) ( A_{34} ,B^{22},mask) (A34,B22,mask)
c.data[6] = temp[0];
c.data[7] = temp[4];

temp =_mm256_dp_ps ( A 12 , B 33 , m a s k ) ( A_{12} ,B^{33},mask) (A12,B33,mask)
c.data[8] = temp[0];
c.data[9] = temp[4];

temp =_mm256_dp_ps ( A 34 , B 33 , m a s k ) ( A_{34} ,B^{33},mask) (A34,B33,mask)
c.data[10] = temp[0];
c.data[11] = temp[4];

temp =_mm256_dp_ps ( A 12 , B 44 , m a s k ) ( A_{12} ,B^{44},mask) (A12,B44,mask)
c.data[12] = temp[0];
c.data[13] = temp[4];

temp =_mm256_dp_ps ( A 34 , B 44 , m a s k ) ( A_{34} ,B^{44},mask) (A34,B44,mask)
c.data[14] = temp[0];
c.data[15] = temp[4];

至此,一次矩阵乘法结束。
关于如何取 A 12 , A 34 , B 11 . . . A_{12},A_{34},B_{11}... A12A34B11...等值:

__declspec(align(16)) __m256i gatherA12 = _mm256_set_epi32(13, 9, 5, 1, 12, 8, 4, 0);
__declspec(align(16)) __m256i gatherA34 = _mm256_set_epi32(15, 11, 7, 3, 14, 10, 6, 2);__declspec(align(16)) __m256i gatherB11 = _mm256_set_epi32(3, 2, 1, 0, 3, 2, 1, 0);
__declspec(align(16)) __m256i gatherB22 = _mm256_set_epi32(7, 6, 5, 4, 7, 6, 5, 4);
__declspec(align(16)) __m256i gatherB33 = _mm256_set_epi32(11, 10, 9, 8, 11, 10, 9, 8);
__declspec(align(16)) __m256i gatherB44 = _mm256_set_epi32(15, 14, 13, 12, 15, 14, 13, 12);a12 = _mm256_i32gather_ps(_Left.data, gatherA12, sizeof(float));a34 = _mm256_i32gather_ps(_Left.data, gatherA34, sizeof(float));b11 = _mm256_i32gather_ps(_Right.data, gatherB11, sizeof(float));b22 = _mm256_i32gather_ps(_Right.data, gatherB22, sizeof(float));b33 = _mm256_i32gather_ps(_Right.data, gatherB33, sizeof(float));b44 = _mm256_i32gather_ps(_Right.data, gatherB44, sizeof(float));

__declspec(align(16)) 意味着要求编译器按16字节对齐,注意!!!SIMD要求字节对齐(当然偏要不对齐也行。。。)

_mm256_set_epi32 语义很简单,就是按照参数来向256位寄存器中放入32位的整数,但是需要注意的是,第一个参数是放到寄存器最末尾的。

4.代码实战

先放出运行结果:
在这里插入图片描述
第一个计时是循环1000万次的SIMD,仅用了不到1ms。
第二个计时是普通的循环1000万次。

以下是全部代码:

#include <iostream>
#include <windows.h>
#include <xmmintrin.h>
#include <immintrin.h>// 计时结构
LARGE_INTEGER t1, t2, tc;void time_begin()
{QueryPerformanceFrequency(&tc);QueryPerformanceCounter(&t1);
}float time_end()
{QueryPerformanceCounter(&t2);return ((t2.QuadPart - t1.QuadPart)*1.0 / tc.QuadPart) * 1000;
}class matrix4
{
public:matrix4():data{ 0 }{}matrix4(float v):data{ v,0,0,0,0,v,0,0,0,0,v,0,0,0,0,v }{}public:public:union{float data[16];float ptr[4][4];};
};// 列主序按行计算
matrix4 mul_1(const matrix4& m1, const matrix4& m2)
{matrix4 ret;for (int row = 0; row < 4; ++row) {for (int col = 0; col < 4; ++col) {for (int a = 0; a < 4; ++a) {ret.ptr[col][row] += m1.ptr[a][row] * m2.ptr[col][a];}}}return ret;
}__declspec(align(16)) matrix4 A(1.f);
__declspec(align(16)) matrix4 B(2.f);
__declspec(align(16)) float v[16]
= {0, 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15};__declspec(align(16)) __m256i gatherA12 = _mm256_set_epi32(13, 9, 5, 1, 12, 8, 4, 0);
__declspec(align(16)) __m256i gatherA34 = _mm256_set_epi32(15, 11, 7, 3, 14, 10, 6, 2);__declspec(align(16)) __m256i gatherB11 = _mm256_set_epi32(3, 2, 1, 0, 3, 2, 1, 0);
__declspec(align(16)) __m256i gatherB22 = _mm256_set_epi32(7, 6, 5, 4, 7, 6, 5, 4);
__declspec(align(16)) __m256i gatherB33 = _mm256_set_epi32(11, 10, 9, 8, 11, 10, 9, 8);
__declspec(align(16)) __m256i gatherB44 = _mm256_set_epi32(15, 14, 13, 12, 15, 14, 13, 12);auto mm_mul_mat(matrix4 const& _Left, matrix4 const& _Right)
{matrix4 ret;__declspec(align(16)) __m256 temp;__declspec(align(16)) __m256 a12, a34;__declspec(align(16)) __m256 b11, b22, b33, b44;a12 = _mm256_i32gather_ps(_Left.data, gatherA12, sizeof(float));a34 = _mm256_i32gather_ps(_Left.data, gatherA34, sizeof(float));b11 = _mm256_i32gather_ps(_Right.data, gatherB11, sizeof(float));b22 = _mm256_i32gather_ps(_Right.data, gatherB22, sizeof(float));b33 = _mm256_i32gather_ps(_Right.data, gatherB33, sizeof(float));b44 = _mm256_i32gather_ps(_Right.data, gatherB44, sizeof(float));temp = _mm256_dp_ps(a12, b11, 0b11110001);ret.data[0] = temp.m256_f32[0];ret.data[1] = temp.m256_f32[4];temp = _mm256_dp_ps(a34, b11, 0b11110001);ret.data[2] = temp.m256_f32[0];ret.data[3] = temp.m256_f32[4];temp = _mm256_dp_ps(a12, b22, 0b11110001);ret.data[4] = temp.m256_f32[0];ret.data[5] = temp.m256_f32[4];temp = _mm256_dp_ps(a34, b22, 0b11110001);ret.data[6] = temp.m256_f32[0];ret.data[7] = temp.m256_f32[4];temp = _mm256_dp_ps(a12, b33, 0b11110001);ret.data[8] = temp.m256_f32[0];ret.data[9] = temp.m256_f32[4];temp = _mm256_dp_ps(a34, b33, 0b11110001);ret.data[10] = temp.m256_f32[0];ret.data[11] = temp.m256_f32[4];temp = _mm256_dp_ps(a12, b44, 0b11110001);ret.data[12] = temp.m256_f32[0];ret.data[13] = temp.m256_f32[4];temp = _mm256_dp_ps(a34, b44, 0b11110001);ret.data[14] = temp.m256_f32[0];ret.data[15] = temp.m256_f32[4];return ret;
}int main()
{time_begin();for (int i = 0; i < 10000000; ++i){volatile auto res = mm_mul_mat(A, B);}::std::cout << "SIMD-AVX \ttime: " << time_end() << ::std::endl;time_begin();for (int i = 0; i < 10000000; ++i){volatile auto m = mul_1(A, B);}::std::cout << "SISD TRIVIAL \ttime: " << time_end() << ::std::endl;return 0;
}

以上。

PS:真是一次奇妙的体验呢~

这篇关于奇妙之旅:SIMD加速矩阵运算的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1067679

相关文章

时间服务器中,适用于国内的 NTP 服务器地址,可用于时间同步或 Android 加速 GPS 定位

NTP 是什么?   NTP 是网络时间协议(Network Time Protocol),它用来同步网络设备【如计算机、手机】的时间的协议。 NTP 实现什么目的?   目的很简单,就是为了提供准确时间。因为我们的手表、设备等,经常会时间跑着跑着就有误差,或快或慢的少几秒,时间长了甚至误差过分钟。 NTP 服务器列表 最常见、熟知的就是 www.pool.ntp.org/zo

C语言入门系列:探秘二级指针与多级指针的奇妙世界

文章目录 一,指针的回忆杀1,指针的概念2,指针的声明和赋值3,指针的使用3.1 直接给指针变量赋值3.2 通过*运算符读写指针指向的内存3.2.1 读3.2.2 写 二,二级指针详解1,定义2,示例说明3,二级指针与一级指针、普通变量的关系3.1,与一级指针的关系3.2,与普通变量的关系,示例说明 4,二级指针的常见用途5,二级指针扩展到多级指针 小结 C语言的学习之旅中,二级

大型网站架构演化(六)——使用反向代理和CDN加速网站响应

随着网站业务不断发展,用户规模越来越大,由于中国复杂的网络环境,不同地区的用户访问网站时,速度差别也极大。有研究表明,网站访问延迟和用户流失率正相关,网站访问越慢,用户越容易失去耐心而离开。为了提供更好的用户体验,留住用户,网站需要加速网站访问速度。      主要手段:使用CDN和反向代理。如图。     使用CDN和反向代理的目的都是尽早返回数据给用户,一方面加快用户访问速

Android热修复学习之旅——Andfix框架完全解析

Android热修复学习之旅开篇——热修复概述 Android热修复学习之旅——HotFix完全解析 Android热修复学习之旅——Tinker接入全攻略 在之前的博客《Android热修复学习之旅——HotFix完全解析》中,我们学习了热修复的实现方式之一,通过dex分包方案的原理还有HotFix框架的源码分析,本次我将讲解热修复的另外一种思路,那就是通过native方法,使用这种思路

Android热修复学习之旅——HotFix完全解析

在上一篇博客 Android热修复学习之旅开篇——热修复概述中,简单介绍了各个热修复框架的原理,本篇博客我将详细分析QQ空间热修复方案。 Android dex分包原理介绍 QQ空间热修复方案基于Android dex分包基础之上,简单概述android dex分包的原理就是:就是把多个dex文件塞入到app的classloader之中,但是android dex拆包方案中的类是没有重复的,如

Android热修复学习之旅开篇——热修复概述

Android热修复技术无疑是Android领域近年来最火热的技术之一,同时也涌现了各种层出不穷的实现方案,如QQ空间补丁方案、阿里AndFix以及微信Tinker等等,从本篇博客开始,计划写一个系列博客专门介绍热修复的相关内容,本系列博客将一一介绍这些框架的原理和源码分析,作为本系列的开篇,本篇博客将对热修复技术进行一个概述,并对以上几种方案进行对比。 为什么会出现热修复? 简单来说,以前出

Java日常探秘-从小疑问到实践智慧的编程之旅(1)

文章目录 前言一、Git中回滚操作的方式二、加密为第三方服务,需要rpc,怎么提高效率三、加解密需求,逻辑能够尽量收敛四、加解密优化五、加解密的rpc失败了处理机制六、优化MySQL查询总结 前言 所有分享的内容源于日常思考和实践,探讨Java编程中的小知识点和实用场景,加深自己对编程技巧和理解Java深层次的原理,期待发现妙招和解决实际问题的新思路。 一、Gi

运算放大器(运放)低通滤波反相放大器电路和积分器电路

低通滤波反相放大器电路 运放积分器电路请访问下行链接 运算放大器(运放)积分器电路 设计目标 输入ViMin输入ViMax输出VoMin输出VoMaxBW:fp电源Vee电源Vcc–0.1V0.1V–2V2V2kHz–2.5V2.5V 设计说明 这款可调式低通反相放大器电路可将信号电平放大 26dB 或 20V/V。R2 和 C1 可设置此电路的截止频率。此电路的频率响应与无源 RC 滤

无法解决 equal to 运算中 Chinese_PRC_90_CI_AS 和 Chinese_PRC_BIN 之间的排序规则冲突

这是因为数据库 oa 和 hh 的编码格式不一样导致的 select  groupname as oper_id,name as oper_name from security_users where name collate Chinese_PRC_CI_AS not in (select oper_name from PDA_UsersAndPWD )

轻松上手MYSQL:MYSQL事务隔离级别的奇幻之旅

​🌈 个人主页:danci_ 🔥 系列专栏:《设计模式》《MYSQL》 💪🏻 制定明确可量化的目标,坚持默默的做事。 ✨欢迎加入探索MYSQL索引数据结构之旅✨     👋 大家好!文本学习研究事务隔离级别。👋 无论您是刚接触MySQL的初学者,还是希望深入优化性能的资深开发者,这篇文章都将为您揭开MySQL事务隔离级别的神秘面纱,让您掌握其中的奥秘,进而