1 AVX Introduction
1.1 SIMD
SIMD: Single Instruction Multiple Data. One instruction operates on multiple data elements at once.
The simplest example is vector addition, where an add has to be performed for every component:
// a=[a1, a2, a3, a4], b=[b1, b2, b3, b4]
sum[0]=a[0]+b[0];
sum[1]=a[1]+b[1];
sum[2]=a[2]+b[2];
sum[3]=a[3]+b[3];
Here the processing is single instruction, single data (SISD): four additions really are carried out as four separate additions, using four add instructions.
The same work can instead be done in single instruction, multiple data (SIMD) fashion:
sum_vector4 = a_vector4 + b_vector4;
The four additions become one length-4 vector addition, i.e. a single vector-add instruction.
A SIMD instruction set exposes a number of wide registers and applies the same operation to all the data packed into them, so one instruction performs several repeated computations and speeds up the work. This article does not use assembly directly; instead it relies on so-called intrinsics (built-in functions).
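As a minimal sketch (not part of the article's wrapper code; the function and array names are made up for illustration), the same idea with AVX intrinsics, where one instruction adds 8 floats at a time:

#include <immintrin.h>
// add two blocks of 8 floats with a single vector add
void add8(const float *a, const float *b, float *sum)
{
    __m256 va = _mm256_loadu_ps(a);    // load 8 floats (unaligned)
    __m256 vb = _mm256_loadu_ps(b);
    __m256 vs = _mm256_add_ps(va, vb); // one instruction, 8 additions
    _mm256_storeu_ps(sum, vs);         // store 8 results
}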
1.2 Intrinsics: functions that map directly to assembly instructions
[Intel® Intrinsics Guide](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html)
1.3 AVX overview
The CPU provides registers such as xmm, ymm and zmm, which are the base registers for SIMD instructions. To make full use of them, they can hold either integers or floating-point values (both 32-bit and 64-bit); to distinguish how a register is actually being used, the intrinsics follow a naming format such as:
_mm256_add_ps
// prefix: mm marks a SIMD intrinsic, 256 is the register width in bits
// operation: add means element-wise addition
// operand type: ps reads as packed single (float), i.e. a vector of single-precision floats;
//               pd reads as packed double; there are also various integer types
To keep the types from being mixed up, three types are predefined:
__m256
// the most common type: 256 bits wide, interpreted as 8 floats
__m256d
// 256 bits wide, interpreted as 4 doubles
__m256i
// 256 bits wide, interpreted as some number of integers; how exactly depends on the instruction
// (integers are split and recombined bit-wise much more often)
Instructions and types must match. If necessary, the cast intrinsics can reinterpret one type as another; a cast produces no machine code and merely tells the compiler that the register's existing bits may be used as the other type (the convert intrinsics, by contrast, actually convert the values).
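A minimal sketch of the three types and a cast between them (an assumed example, not from the original text):

__m256  f  = _mm256_set1_ps(1.0f);    // 8 x float
__m256d d  = _mm256_set1_pd(1.0);     // 4 x double
__m256i i  = _mm256_castps_si256(f);  // reinterpret the same bits as integers; emits no machine code
__m256  f2 = _mm256_castsi256_ps(i);  // and back; still the bit pattern of 1.0f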
2 Overview of the AVX floating-point instructions
Like an ordinary instruction set, the AVX instructions can be grouped into several categories.
2.1 Memory access instructions
f32x8_p source code (excerpt)
static f32x8_p load_8floats(const void *arr) { return f32x8_p(_mm256_loadu_ps((float *)arr)); }
static f32x8_p load_1float_broadcast(float *arr) { return f32x8_p(_mm256_broadcast_ss(arr)); }
static f32x8_p load_4floats_broadcast(const __m128 *arr) { return f32x8_p(_mm256_broadcast_ps(arr)); }
static f32x8_p load_mask(float const *arr, __m256i mask) { return f32x8_p(_mm256_maskload_ps(arr, mask)); }
void store(float *a) { _mm256_storeu_ps(a, data); }
void store(f32x8_p *a) { _mm256_store_ps((float *)a, data); }
void load(float *a) { data = _mm256_loadu_ps(a); }
void load(f32x8_p *a) { data = _mm256_load_ps((float *)a); }
There are several store/load instructions. In addition, because SIMD operates on whole vectors, a broadcast load is provided that initializes every lane from a single value, which is useful in certain situations.
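A small hedged example of these loads and stores (the function and array names are illustrative): scale a block of 8 floats by one constant using a broadcast load.

void scale8(const float *src, float *dst, const float *k)
{
    __m256 vk = _mm256_broadcast_ss(k);          // replicate one float into all 8 lanes
    __m256 v  = _mm256_loadu_ps(src);            // unaligned load of 8 floats
    _mm256_storeu_ps(dst, _mm256_mul_ps(v, vk)); // unaligned store of the 8 products
}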
2.2 Arithmetic instructions
f32x8_p source code (excerpt)
f32x8_p operator+(f32x8_p a) { return f32x8_p(_mm256_add_ps(data, a.data)); }
f32x8_p operator-(f32x8_p a) { return f32x8_p(_mm256_sub_ps(data, a.data)); }
f32x8_p operator*(f32x8_p a) { return f32x8_p(_mm256_mul_ps(data, a.data)); }
f32x8_p operator/(f32x8_p a) { return f32x8_p(_mm256_div_ps(data, a.data)); }
void operator+=(f32x8_p a) { data = _mm256_add_ps(data, a.data); }
void operator-=(f32x8_p a) { data = _mm256_sub_ps(data, a.data); }
void operator*=(f32x8_p a) { data = _mm256_mul_ps(data, a.data); }
void operator/=(f32x8_p a) { data = _mm256_div_ps(data, a.data); }
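A usage sketch of these operators (assuming the f32x8_p wrapper above is included; loop bounds and names are illustrative): compute c += a * b over arrays whose length is a multiple of 8.

void mul_add_arrays(float *a, float *b, float *c, size_t n)
{
    for (size_t i = 0; i < n; i += 8) // assumes n is a multiple of 8
    {
        f32x8_p va, vb, vc;
        va.load(a + i);  // unaligned load (vmovups)
        vb.load(b + i);
        vc.load(c + i);
        vc += va * vb;   // operator* / operator+= map to vmulps / vaddps
        vc.store(c + i); // unaligned store
    }
}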
2.3 Logical (bitwise) instructions
Bitwise logical operations give the same result whichever data type they are applied to, even though the instructions may look different.
f32x8_p operator&(f32x8_p a) { return f32x8_p(_mm256_and_ps(data, a.data)); }
void operator&=(f32x8_p a) { data = _mm256_and_ps(data, a.data); }
f32x8_p operator|(f32x8_p a) { return f32x8_p(_mm256_or_ps(data, a.data)); }
void operator|=(f32x8_p a) { data = _mm256_or_ps(data, a.data); }
f32x8_p operator^(f32x8_p a) { return f32x8_p(_mm256_xor_ps(data, a.data)); }
void operator^=(f32x8_p a) { data = _mm256_xor_ps(data, a.data); }
static f32x8_p andnot(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_andnot_ps(a.data, b.data)); }
void do_andnot(f32x8_p a) { data = _mm256_andnot_ps(data, a.data); }
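A classic use of the bitwise instructions, as a hedged sketch (the helper name is made up): clear the sign bit of every lane to get the absolute value; 0x80000000 is the IEEE-754 sign-bit mask of a float.

__m256 abs8(__m256 x)
{
    __m256 sign_mask = _mm256_set1_ps(-0.0f); // bit pattern 0x80000000 in every lane
    return _mm256_andnot_ps(sign_mask, x);    // (~sign_mask) & x -> sign bit cleared
}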
2.4 Advanced arithmetic instructions
These perform arithmetic that is not uniform across lanes but still follows a regular pattern:
// {a[0]-b[0], a[1]+b[1], a[2]-b[2], a[3]+b[3]...}
void addsub(f32x8_p a) { data = _mm256_addsub_ps(data, a.data); }
static f32x8_p addsub(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_addsub_ps(a.data, b.data)); }
// {a[0]+a[1], a[2]+a[3], b[0]+b[1], b[2]+b[3]...}
void hadd(f32x8_p a) { data = _mm256_hadd_ps(data, a.data); }
static f32x8_p hadd(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_hadd_ps(a.data, b.data)); }
// {a[0]-a[1], a[2]-a[3], b[0]-b[1], b[2]-b[3]...}
void hsub(f32x8_p a) { data = _mm256_hsub_ps(data, a.data); }
static f32x8_p hsub(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_hsub_ps(a.data, b.data)); }
// {a[0]*mul[0]+add[0], a[1]*mul[1]+add[1]...}
void mul_then_add(f32x8_p mul, f32x8_p add) { data = _mm256_fmadd_ps(data, mul.data, add.data); }
void fmadd(f32x8_p mul, f32x8_p add) { data = _mm256_fmadd_ps(data, mul.data, add.data); }
static f32x8_p mul_then_add(f32x8_p a, f32x8_p mul, f32x8_p add) { return f32x8_p(_mm256_fmadd_ps(a.data, mul.data, add.data)); }
// {a[0]*mul[0]-add[0], a[1]*mul[1]-add[1]...}
void mul_then_sub(f32x8_p mul, f32x8_p sub) { data = _mm256_fmsub_ps(data, mul.data, sub.data); }
void fmsub(f32x8_p mul, f32x8_p sub) { data = _mm256_fmsub_ps(data, mul.data, sub.data); }
static f32x8_p mul_then_sub(f32x8_p a, f32x8_p mul, f32x8_p sub) { return f32x8_p(_mm256_fmsub_ps(a.data, mul.data, sub.data)); }
// {-a[0]*mul[0]+add[0], -a[1]*mul[1]+add[1]...}
void neg_mul_then_add(f32x8_p mul, f32x8_p add) { data = _mm256_fnmadd_ps(data, mul.data, add.data); }
void fnmadd(f32x8_p mul, f32x8_p add) { data = _mm256_fnmadd_ps(data, mul.data, add.data); }
static f32x8_p neg_mul_then_add(f32x8_p a, f32x8_p mul, f32x8_p add) { return f32x8_p(_mm256_fnmadd_ps(a.data, mul.data, add.data)); }
// {-a[0]*mul[0]-add[0], -a[1]*mul[1]-add[1]...}
void neg_mul_then_sub(f32x8_p mul, f32x8_p sub) { data = _mm256_fnmsub_ps(data, mul.data, sub.data); }
void fnmsub(f32x8_p mul, f32x8_p sub) { data = _mm256_fnmsub_ps(data, mul.data, sub.data); }
static f32x8_p neg_mul_then_sub(f32x8_p a, f32x8_p mul, f32x8_p sub) { return f32x8_p(_mm256_fnmsub_ps(a.data, mul.data, sub.data)); }
// {a[0] * mul[0] - addsub[0], a[1] * mul[1] + addsub[1]...}
void mul_addsub(f32x8_p mul, f32x8_p addsub) { data = _mm256_fmaddsub_ps(data, mul.data, addsub.data); }
void fmaddsub(f32x8_p mul, f32x8_p addsub) { data = _mm256_fmaddsub_ps(data, mul.data, addsub.data); }
static f32x8_p mul_addsub(f32x8_p a, f32x8_p mul, f32x8_p addsub) { return f32x8_p(_mm256_fmaddsub_ps(a.data, mul.data, addsub.data)); }
// {a[0] * mul[0] + subadd[0], a[1] * mul[1] - subadd[1]...}
void mul_subadd(f32x8_p mul, f32x8_p subadd) { data = _mm256_fmsubadd_ps(data, mul.data, subadd.data); }
void fmsubadd(f32x8_p mul, f32x8_p subadd) { data = _mm256_fmsubadd_ps(data, mul.data, subadd.data); }
static f32x8_p mul_subadd(f32x8_p a, f32x8_p mul, f32x8_p subadd) { return f32x8_p(_mm256_fmsubadd_ps(a.data, mul.data, subadd.data)); }
They cover interleaved add/subtract, horizontal add/subtract (adjacent elements within a vector are added or subtracted), and fused three-operand multiply-add/subtract.
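A sketch of the fused multiply-add in use (requires FMA support; the name and the assumption that n is a multiple of 8 are illustrative): accumulate 8 partial dot-product sums, then reduce them in scalar code.

float dot(const float *a, const float *b, size_t n)
{
    __m256 acc = _mm256_setzero_ps();
    for (size_t i = 0; i < n; i += 8)
        acc = _mm256_fmadd_ps(_mm256_loadu_ps(a + i), _mm256_loadu_ps(b + i), acc);
    float tmp[8];
    _mm256_storeu_ps(tmp, acc);
    float s = 0.0f;
    for (int j = 0; j < 8; ++j) s += tmp[j]; // scalar reduction of the 8 lanes
    return s;
}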
2.5 Data rearrangement instructions
These rearrange the data in one register according to a rule, or combine data from two registers into one; sometimes the data to be operated on are not lined up and need to be shuffled into place first.
Note that the integer rearrangement rule must be a compile-time constant, because it is encoded into the machine code itself. Hence a template parameter is used instead of an ordinary function parameter.
In the comments below, the [] operator on imm8 denotes bit access: imm[1:0] is the integer value 0-3 formed by the lowest two bits of imm.
// {a[2], b[2], a[3], b[3], a[6], b[6], a[7], b[7]}
static f32x8_p unpack_high(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_unpackhi_ps(a.data, b.data)); }
// {a[0], b[0], a[1], b[1], a[4], b[4], a[5], b[5]}
static f32x8_p unpack_low(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_unpacklo_ps(a.data, b.data)); }
// for j = 0 to 7, ret[j] = imm8[j]?a[j]:b[j]
template <uint8_t _imm8>
void blend(f32x8_p a) { data = _mm256_blend_ps(data, a.data, _imm8); }
template <uint8_t _imm8>
static f32x8_p blend(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_blend_ps(a.data, b.data, _imm8)); }
// for j = 0 to 7, ret[j] = mask[j].signbit?a[j]:b[j]
void blend(f32x8_p a, f32x8_p mask) { data = _mm256_blendv_ps(data, a.data, mask.data); }
static f32x8_p blend(f32x8_p a, f32x8_p b, f32x8_p mask) { return f32x8_p(_mm256_blendv_ps(a.data, b.data, mask.data)); }
// see f32x4_p::shuffle. do the same shuffle for both high and low 128bit
// {a[imm[1:0]], a[imm[3:2]], b[imm[5:4]], b[imm[7:6]]}
template <uint8_t _imm8>
void shuffle(f32x8_p a) { data = _mm256_shuffle_ps(data, a.data, _imm8); }
// {a[imm[1:0]], a[imm[3:2]], a[imm[5:4]], a[imm[7:6]]}
template <uint8_t _imm8>
void permute() { data = _mm256_permute_ps(data, _imm8); }
// ret.low = imm8[3]? 0 : switch imm8[1:0] {case 0: a.low; case 1: a.high; case 2: b.low; case 3: b.high}
// ret.high = imm8[7]? 0 : switch imm8[5:4] {case 0: a.low; case 1: a.high; case 2: b.low; case 3: b.high}
template <uint8_t _imm8>
void permute2f128(f32x8_p a) { data = _mm256_permute2f128_ps(data, a.data, _imm8); }
void move_high_dup() { data = _mm256_movehdup_ps(data); }
void move_odd2even() { data = _mm256_movehdup_ps(data); }
f32x8_p copy_odd2even() { return f32x8_p(_mm256_movehdup_ps(data)); }
void move_low_dup() { data = _mm256_moveldup_ps(data); }
void move_even2odd() { data = _mm256_moveldup_ps(data); }
f32x8_p copy_even2odd() { return f32x8_p(_mm256_moveldup_ps(data)); }
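A small hedged sketch of a rearrangement (the helper name is made up; the control byte must be a compile-time constant, as noted above): swap adjacent lanes, which is exactly the real/imaginary swap used for complex multiplication in section 3.

__m256 swap_pairs(__m256 v)
{
    // per 128-bit half: dst[0]=src[1], dst[1]=src[0], dst[2]=src[3], dst[3]=src[2]
    return _mm256_permute_ps(v, 0b10'11'00'01);
}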
2.6 Common math function instructions
Reciprocal, square root, and reciprocal square root:
void rcp() { data = _mm256_rcp_ps(data); }
void sqrt() { data = _mm256_sqrt_ps(data); }
void rsqrt() { data = _mm256_rsqrt_ps(data); }
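Note that rcp and rsqrt are fast approximations (roughly 12 bits of precision). As a hedged sketch (the helper name is made up), one Newton-Raphson step after rsqrt restores nearly full single precision:

__m256 inv_sqrt(__m256 x)
{
    __m256 y     = _mm256_rsqrt_ps(x);                   // y ~ 1/sqrt(x)
    __m256 half  = _mm256_set1_ps(0.5f);
    __m256 three = _mm256_set1_ps(3.0f);
    __m256 xyy   = _mm256_mul_ps(x, _mm256_mul_ps(y, y));
    return _mm256_mul_ps(_mm256_mul_ps(half, y),         // y = 0.5 * y * (3 - x*y*y)
                         _mm256_sub_ps(three, xyy));
}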
2.7 Mask instructions
A mask specifies, per lane, which operands actually take part in the operation.
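A typical use, sketched under assumptions (the helper name and the way the mask is built are illustrative): load only the first count (< 8) floats of a block with maskload, so the tail of an array can be handled without reading past its end. A lane participates when the sign bit of its 32-bit mask element is set.

__m256 load_tail(const float *p, int count)
{
    __m256i mask = _mm256_setr_epi32(
        count > 0 ? -1 : 0, count > 1 ? -1 : 0, count > 2 ? -1 : 0, count > 3 ? -1 : 0,
        count > 4 ? -1 : 0, count > 5 ? -1 : 0, count > 6 ? -1 : 0, count > 7 ? -1 : 0);
    return _mm256_maskload_ps(p, mask); // inactive lanes are returned as 0
}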
3 Complex multiplication
We consider the classic complex storage format, with real and imaginary parts interleaved:
{c[0].re, c[0].im, c[1].re, c[1].im, c[2].re, c[2].im, c[3].re, c[3].im}
For simplicity, let us look only at the first complex number a+bi for now:
var a_bi={a, b, ...}
Multiply it by another complex number c+di:
var c_di={c, d, ...}
The desired result is (ac-bd) + (ad+bc)i:
{ac-bd, ad+bc, ...}
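A scalar reference for this interleaved layout (a hedged sketch, useful for checking the SIMD versions below; names are illustrative):

void cmul_scalar(const float *a, const float *b, float *out, int ncomplex)
{
    for (int k = 0; k < ncomplex; ++k)
    {
        float ar = a[2 * k], ai = a[2 * k + 1];
        float br = b[2 * k], bi = b[2 * k + 1];
        out[2 * k]     = ar * br - ai * bi; // real part
        out[2 * k + 1] = ar * bi + ai * br; // imaginary part
    }
}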
3.1 Method 1
static fc32x4_p multiply_complex_v0(f32x8_p a_bI, f32x8_p c_dI)
{   // real: a*c - b*d, imag: a*d + b*c
    f32x8_p ac_bdI = a_bI * c_dI;
    f32x8_p ad_bcI = a_bI * c_dI.reordered<0b10'11'00'01>();
    ac_bdI.hsub(ac_bdI);
    ad_bcI.hadd(ad_bcI);
    return fc32x4_p(ac_bdI.remixed<0b11'10'01'00>(ad_bcI) // {r0,r1,i0,i1}
                        .reordered<0b11'01'10'00>());     // {r0,i0,r1,i1}
}
Explanation:
- ac_bdI = a_bI * c_dI: its value is {a*c, b*d, ...}, i.e. exactly the two products needed for the real part.
- Reorder c_dI with the rule 0b10'11'00'01, which means:
  dst[0]=src[0b01] dst[1]=src[0b00] dst[2]=src[0b11] dst[3]=src[0b10]
  In effect the real and imaginary parts are swapped, giving d+ci.
- a_bI * c_dI.reordered<0b10'11'00'01>() = ad_bcI = a_bI * d_cI: its value is {ad, bc, ...}, i.e. the two products needed for the imaginary part.
- A horizontal subtraction yields the real parts, with the vector itself as the second operand:
  ac_bdI.hsub(ac_bdI);
  For each 128-bit half its effect is:
  dst[0]=a[0]-a[1] dst[1]=a[2]-a[3] dst[2]=b[0]-b[1] dst[3]=b[2]-b[3]
  so the result looks like:
  c[0].re, c[1].re, c[0].re, c[1].re, ...
- A horizontal addition yields the imaginary parts:
  ad_bcI.hadd(ad_bcI);
  Imaginary parts:
  c[0].im, c[1].im, c[0].im, c[1].im, ...
- Reordering the real and imaginary parts appropriately gives the answer.
  Merge and drop the duplicates: the first two lanes of the result come from the first operand, the last two from the second operand:
  return fc32x4_p(ac_bdI.remixed<0b11'10'01'00>(ad_bcI) // {r0,r1,i0,i1}
  Now the two real parts sit next to each other (and likewise the imaginary parts):
  c[0].re, c[1].re, c[0].im, c[1].im, ...
  Swapping element [1] and element [2] gives the final result:
  .reordered<0b11'01'10'00>()); // {r0,i0,r1,i1}
3.2 Method 2
// complex mul complex, use addsub, cost about 90% time of _v0
static fc32x4_p multiply_complex_v1(f32x8_p a_bI, f32x8_p c_dI)
{   // real: a*c - b*d, imag: a*d + b*c
    f32x8_p a_aI = a_bI.copy_even2odd();
    f32x8_p b_bI = a_bI.copy_odd2even();
    f32x8_p ac_adI = a_aI * c_dI;
    f32x8_p bd_bcI = b_bI * c_dI.reordered<0b10'11'00'01>();
    return fc32x4_p(fc32x4_p::addsub(ac_adI, bd_bcI));
}
This time we can work backwards. The interleaved add/subtract (addsub) instruction introduced earlier is a natural fit for complex multiplication, so we ask how to obtain the product
{ac-bd, ad+bc, ...}
with a single interleaved add/subtract. It can be obtained from either of these computations:
{ac, bc}(+/-){ad, bd}
// 或者
{ac, ad}(+/-){bd, bc}
Looking at the second form, it is not hard to see that:
{ac, ad} = {a, a} * {c, d}
{bd, bc} = {b, b} * {d, c}
// -->
{a, a}, {c, d}, {b, b}, {d, c} // 这个可以通过重排函数得到
3.3 Method 3
Method 2 can be optimized a little further with the fused multiply-then-addsub instruction:
// complex mul complex, use mul_addsub, cost about 98% time of _v1
static fc32x4_p multiply_complex_v2(f32x8_p a_bI, f32x8_p c_dI)
{   // real: a*c - b*d, imag: a*d + b*c
    f32x8_p a_aI = a_bI.copy_even2odd();
    f32x8_p b_bI = a_bI.copy_odd2even();
    //f32x8_p ac_adI = a_aI * c_dI;
    f32x8_p bd_bcI = b_bI * c_dI.reordered<0b10'11'00'01>();
    return fc32x4_p::mul_addsub(a_aI, c_dI, bd_bcI);
}
3.4 Further optimization
Because the wrapper class is implemented entirely in source form, the compiler can optimize and inline it directly. If you need a function boundary around SIMD code, consider __vectorcall, which passes arguments and return values in the xmm/ymm registers. For example:
_declspec(dllexport) auto __vectorcall multiply_complex_v0(__m256 a, __m256 b)
{
    return fc32x4_p::multiply_complex_v0(a, b).data;
}
//_declspec(dllexport) auto __vectorcall multiply_complex_v2(__m256 a, __m256 b)
//{
C5 FC 10 E0 vmovups ymm4,ymm0 //return fc32x4_p::multiply_complex_v2(a, b).data;
C5 FE 16 D0 vmovshdup ymm2,ymm0
C4 E3 7D 04 D9 B1 vpermilps ymm3,ymm1,0B1h //note: the immediate 0B1h is baked into the machine code
C5 E4 59 C2 vmulps ymm0,ymm3,ymm2
C5 FE 12 E4 vmovsldup ymm4,ymm4
C4 E2 5D B6 C1 vfmaddsub231ps ymm0,ymm4,ymm1 //}
C3 ret
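A hedged sketch of calling such a wrapper (assuming the exported multiply_complex_v2 from the source below is declared and linked): with __vectorcall both __m256 arguments and the result stay in ymm registers across the call.

void demo()
{
    __m256 a = _mm256_setr_ps(1, 2, 3, 4, 5, 6, 7, 8); // (1+2i), (3+4i), (5+6i), (7+8i)
    __m256 b = _mm256_setr_ps(8, 7, 6, 5, 4, 3, 2, 1); // (8+7i), (6+5i), (4+3i), (2+1i)
    __m256 r = multiply_complex_v2(a, b);              // first pair: (1+2i)*(8+7i) = -6+23i
    (void)r;                                           // keep the result alive in this sketch
}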
AVX256 source code
#pragma once
#ifndef bionukg_SIMD_h
#define bionukg_SIMD_h
// sse / avx
#include <stdint.h>
#include <xmmintrin.h> //__m128, f32x4
#include <emmintrin.h> //__m128i,__m128d
#include <immintrin.h> //__m256 series
#ifdef namespace_bionukg
namespace bionukg
{
#endifstruct f32x4_b; // basic float32x4struct f32x4_p; // packed float32x4struct f32x4_s; // single float32x4struct f32x4_b{public:union{__m128 data;float f32x4[4];};f32x4_b() : data(_mm_setzero_ps()) {}f32x4_b(__m128 data) : data(data) {}f32x4_b(float a, float b, float c, float d) : data(_mm_setr_ps(a, b, c, d)) {}f32x4_b(float a) : data(_mm_set1_ps(a)) {}f32x4_b(float *a) : data(_mm_loadu_ps(a)) {}f32x4_b(f32x4_b *a) : data(_mm_load_ps((float *)a)) {}void store(float *a) { _mm_storeu_ps(a, data); }void store(f32x4_b *a) { _mm_store_ps((float *)a, data); }void load(float *a) { data = _mm_loadu_ps(a); }void load(f32x4_b *a) { data = _mm_load_ps((float *)a); }float operator[](uint8_t idx) const { return data.m128_f32[idx]; }float &operator[](uint8_t idx) { return data.m128_f32[idx]; }};struct f32x4_s : public f32x4_b{public:f32x4_s() : f32x4_b() {}f32x4_s(__m128 data) : f32x4_b(data) {}f32x4_s(f32x4_b a) : f32x4_b(a) {}f32x4_s(float a, float b, float c, float d) : f32x4_b(a, b, c, d) {}f32x4_s(float a) : f32x4_b(a) {}f32x4_s(float *a) : f32x4_b(a) {}f32x4_s(f32x4_b *a) : f32x4_b(a) {}int32_t get_int32() const { return _mm_cvtss_si32(data); }int32_t get_int32_trunc() const { return _mm_cvttss_si32(data); }int64_t get_int64() const { return _mm_cvtss_si64(data); }int64_t get_int64_trunc() const { return _mm_cvttss_si64(data); }void put_int32(int32_t a) { data = _mm_cvtsi32_ss(data, a); }void put_int64(int64_t a) { data = _mm_cvtsi64_ss(data, a); }float get_float() const { return _mm_cvtss_f32(data); }f32x4_s sqrt() { return f32x4_s(_mm_sqrt_ss(data)); }void do_sqrt() { data = _mm_sqrt_ss(data); }};struct f32x4_p : public f32x4_b{public:f32x4_p() : f32x4_b() {}f32x4_p(__m128 data) : f32x4_b(data) {}f32x4_p(f32x4_b a) : f32x4_b(a) {}f32x4_p(float a, float b, float c, float d) : f32x4_b(a, b, c, d) {}f32x4_p(float a) : f32x4_b(a) {}f32x4_p(float *a) : f32x4_b(a) {}f32x4_p(f32x4_b *a) : f32x4_b(a) {}f32x4_p copy() const { return f32x4_p(data); }f32x4_s as_single() const { return f32x4_s(data); }f32x4_s &as_single_do() { return reinterpret_cast<f32x4_s &>(data); }f32x4_p operator+(f32x4_p a) { return f32x4_p(_mm_add_ps(data, a.data)); }f32x4_p operator+(f32x4_s a) { return f32x4_p(_mm_add_ss(data, a.data)); }f32x4_p operator-(f32x4_p a) { return f32x4_p(_mm_sub_ps(data, a.data)); }f32x4_p operator-(f32x4_s a) { return f32x4_p(_mm_sub_ss(data, a.data)); }f32x4_p operator*(f32x4_p a) { return f32x4_p(_mm_mul_ps(data, a.data)); }f32x4_p operator*(f32x4_s a) { return f32x4_p(_mm_mul_ss(data, a.data)); }f32x4_p operator/(f32x4_p a) { return f32x4_p(_mm_div_ps(data, a.data)); }f32x4_p operator/(f32x4_s a) { return f32x4_p(_mm_div_ss(data, a.data)); }void operator+=(f32x4_p a) { data = _mm_add_ps(data, a.data); }void operator+=(f32x4_s a) { data = _mm_add_ss(data, a.data); }void operator-=(f32x4_p a) { data = _mm_sub_ps(data, a.data); }void operator-=(f32x4_s a) { data = _mm_sub_ss(data, a.data); }void operator*=(f32x4_p a) { data = _mm_mul_ps(data, a.data); }void operator*=(f32x4_s a) { data = _mm_mul_ss(data, a.data); }void operator/=(f32x4_p a) { data = _mm_div_ps(data, a.data); }void operator/=(f32x4_s a) { data = _mm_div_ss(data, a.data); }f32x4_p sqrt() const { return f32x4_p(_mm_sqrt_ps(data)); }void do_sqrt() { data = _mm_sqrt_ps(data); }void do_sqrt_single() { data = _mm_sqrt_ss(data); }f32x4_p rcp() const { return f32x4_p(_mm_rcp_ps(data)); }void do_rcp() { data = _mm_rcp_ps(data); }void do_rcp_single() { data = _mm_rcp_ss(data); }f32x4_p rsqrt() const { return f32x4_p(_mm_rsqrt_ps(data)); 
}void do_rsqrt() { data = _mm_rsqrt_ps(data); }void do_rsqrt_single() { data = _mm_rsqrt_ss(data); }static f32x4_p minimum(f32x4_p a, f32x4_p b) { return f32x4_p(_mm_min_ps(a.data, b.data)); }static f32x4_p minimum(f32x4_p a, f32x4_s b) { return f32x4_p(_mm_min_ss(a.data, b.data)); }static f32x4_p maximum(f32x4_p a, f32x4_p b) { return f32x4_p(_mm_max_ps(a.data, b.data)); }static f32x4_p maximum(f32x4_p a, f32x4_s b) { return f32x4_p(_mm_max_ss(a.data, b.data)); }f32x4_p operator&(f32x4_p a) { return f32x4_p(_mm_and_ps(data, a.data)); }void operator&=(f32x4_p a) { data = _mm_and_ps(data, a.data); }f32x4_p operator|(f32x4_p a) { return f32x4_p(_mm_or_ps(data, a.data)); }void operator|=(f32x4_p a) { data = _mm_or_ps(data, a.data); }f32x4_p operator^(f32x4_p a) { return f32x4_p(_mm_xor_ps(data, a.data)); }void operator^=(f32x4_p a) { data = _mm_xor_ps(data, a.data); }static f32x4_p andnot(f32x4_p a, f32x4_p b) { return f32x4_p(_mm_andnot_ps(a.data, b.data)); }void do_andnot(f32x4_p a) { data = _mm_andnot_ps(data, a.data); }f32x4_p operator==(f32x4_p a) { return f32x4_p(_mm_cmpeq_ps(data, a.data)); }int operator==(f32x4_s a) { return _mm_comieq_ss(data, a.data); }f32x4_p operator!=(f32x4_p a) { return f32x4_p(_mm_cmpneq_ps(data, a.data)); }int operator!=(f32x4_s a) { return _mm_comineq_ss(data, a.data); }f32x4_p operator<(f32x4_p a) { return f32x4_p(_mm_cmplt_ps(data, a.data)); }int operator<(f32x4_s a) { return _mm_comilt_ss(data, a.data); }f32x4_p operator<=(f32x4_p a) { return f32x4_p(_mm_cmple_ps(data, a.data)); }int operator<=(f32x4_s a) { return _mm_comile_ss(data, a.data); }f32x4_p operator>(f32x4_p a) { return f32x4_p(_mm_cmpgt_ps(data, a.data)); }int operator>(f32x4_s a) { return _mm_comigt_ss(data, a.data); }f32x4_p operator>=(f32x4_p a) { return f32x4_p(_mm_cmpge_ps(data, a.data)); }int operator>=(f32x4_s a) { return _mm_comige_ss(data, a.data); }f32x4_p has_NAN() { return f32x4_p(_mm_cmpunord_ps(data, data)); }f32x4_p has_NAN(f32x4_p a) { return f32x4_p(_mm_cmpunord_ps(data, a.data)); }f32x4_p not_NAN() { return f32x4_p(_mm_cmpord_ps(data, data)); }f32x4_p not_NAN(f32x4_p a) { return f32x4_p(_mm_cmpord_ps(data, a.data)); }// {a[imm[1:0]], a[imm[3:2]], b[imm[5:4]], b[imm[7:6]]}template <uint8_t _imm8>void shuffle(f32x4_p a) { data = _mm_shuffle_ps(data, a.data, _imm8); }/*DEFINE SELECT4(src, control) {CASE(control[1:0]) OF0: tmp[31:0] := src[31:0]1: tmp[31:0] := src[63:32]2: tmp[31:0] := src[95:64]3: tmp[31:0] := src[127:96]ESACRETURN tmp[31:0]}dst[31:0] := SELECT4(a[127:0], imm8[1:0])dst[63:32] := SELECT4(a[127:0], imm8[3:2])dst[95:64] := SELECT4(b[127:0], imm8[5:4])dst[127:96] := SELECT4(b[127:0], imm8[7:6])*/// {a[imm[1:0]], a[imm[3:2]], a[imm[5:4]], a[imm[7:6]]}template <uint8_t _imm8>void permute() { data = _mm_permute_ps(data, _imm8); }/*DEFINE SELECT4(src, control) {CASE(control[1:0]) OF0: tmp[31:0] := src[31:0]1: tmp[31:0] := src[63:32]2: tmp[31:0] := src[95:64]3: tmp[31:0] := src[127:96]ESACRETURN tmp[31:0]}dst[31:0] := SELECT4(a[127:0], imm8[1:0])dst[63:32] := SELECT4(a[127:0], imm8[3:2])dst[95:64] := SELECT4(a[127:0], imm8[5:4])dst[127:96] := SELECT4(a[127:0], imm8[7:6])dst[MAX:128] := 0*/// {a[3], b[3], a[4], b[4]}static f32x4_p unpack_high(f32x4_p a, f32x4_p b) { return f32x4_p(_mm_unpackhi_ps(a.data, b.data)); }void unpack_high(f32x4_p a) { data = _mm_unpackhi_ps(data, a.data); }// {a[0], b[0], a[1], b[1]}static f32x4_p unpack_low(f32x4_p a, f32x4_p b) { return f32x4_p(_mm_unpacklo_ps(a.data, b.data)); }void unpack_low(f32x4_p a) { data = _mm_unpacklo_ps(data, a.data); 
}// {a[0], a[1], mem[0], mem[1]}static f32x4_p load_high(f32x4_p a, float *mem_addr) { return f32x4_p(_mm_loadh_pi(a.data, (__m64 *)mem_addr)); }void load_high(float *mem_addr) { data = _mm_loadh_pi(data, (__m64 *)mem_addr); }void store_high(float *mem_addr) { _mm_storeh_pi((__m64 *)mem_addr, data); }// {mem[0], mem[1], a[0], a[1]}static f32x4_p load_low(f32x4_p a, float *mem_addr) { return f32x4_p(_mm_loadl_pi(a.data, (__m64 *)mem_addr)); }void load_low(float *mem_addr) { data = _mm_loadl_pi(data, (__m64 *)mem_addr); }void store_low(float *mem_addr) { _mm_storel_pi((__m64 *)mem_addr, data); }// {b[2], b[3], a[2], a[3]}static f32x4_p move_high2low(f32x4_p a, f32x4_p b) { return f32x4_p(_mm_movehl_ps(a.data, b.data)); }void move_high2low(f32x4_p a) { data = _mm_movehl_ps(data, a.data); }void move_h2l(f32x4_p a) { data = _mm_movehl_ps(data, a.data); }// {a[0], a[1], b[0], b[1]}static f32x4_p move_low2high(f32x4_p a, f32x4_p b) { return f32x4_p(_mm_movelh_ps(a.data, b.data)); }void move_low2high(f32x4_p a) { data = _mm_movelh_ps(data, a.data); }void move_l2h(f32x4_p a) { data = _mm_movelh_ps(data, a.data); }int get_mask() const { return _mm_movemask_ps(data); }int get_signs() const { return _mm_movemask_ps(data); }};struct f32x8_p{public:union{__m256 data;float f32x8[8];};f32x8_p() : data(_mm256_setzero_ps()) {}f32x8_p(__m256 data) : data(data) {}f32x8_p(f32x4_p high, f32x4_p low) : data(_mm256_set_m128(high.data, low.data)) {}f32x8_p(float a, float b, float c, float d, float e = 0.0f, float f = 0.0f, float g = 0.0f, float h = 0.0f) : data(_mm256_setr_ps(a, b, c, d, e, f, g, h)) {}f32x8_p(float a) : data(_mm256_set1_ps(a)) {}f32x8_p(float *a) : data(_mm256_load_ps((float *)a)) {}__m256d cast_packed_double() const { return _mm256_castps_pd(data); }__m256d cast_packed_double() { return _mm256_castps_pd(data); }__m256d cast_packed_f64() const { return _mm256_castps_pd(data); }__m256d cast_packed_f64() { return _mm256_castps_pd(data); }__m256d cast_pd() const { return _mm256_castps_pd(data); }__m256d cast_pd() { return _mm256_castps_pd(data); }__m256i cast_packed_int() const { return _mm256_castps_si256(data); }__m256i cast_packed_int() { return _mm256_castps_si256(data); }__m256i cast_si() const { return _mm256_castps_si256(data); }__m256i cast_si() { return _mm256_castps_si256(data); }__m128 cast_down() const { return _mm256_castps256_ps128(data); }__m128 cast_down() { return _mm256_castps256_ps128(data); }template <uint8_t _imm3 = 0b000>__m128i convert_packed_half_float() const { return _mm256_cvtps_ph(data, _imm3); }template <uint8_t _imm3 = 0b000>__m128i convert_packed_f16() const { return _mm256_cvtps_ph(data, _imm3); }template <uint8_t _imm3 = 0b000>__m128i convert_packed_half_float() { return _mm256_cvtps_ph(data, _imm3); }template <uint8_t _imm3 = 0b000>__m128i convert_packed_f16() { return _mm256_cvtps_ph(data, _imm3); }static f32x8_p convert_from_packed_f16(__m128i a) { return f32x8_p(_mm256_cvtph_ps(a)); }static f32x8_p load_8floats(const void *arr) { return f32x8_p(_mm256_loadu_ps((float *)arr)); }static f32x8_p load_1float_broadcast(float *arr) { return f32x8_p(_mm256_broadcast_ss(arr)); }static f32x8_p load_4floats_broadcast(const __m128 *arr) { return f32x8_p(_mm256_broadcast_ps(arr)); }static f32x8_p load_mask(float const *arr, __m256i mask) { return f32x8_p(_mm256_maskload_ps(arr, mask)); }// {a[2], b[2], a[3], b[3], a[6], b[6], a[7], b[7]}static f32x8_p unpack_high(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_unpackhi_ps(a.data, b.data)); }// {a[0], b[0], a[1], b[1], a[4], 
b[4], a[5], b[5]}static f32x8_p unpack_low(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_unpacklo_ps(a.data, b.data)); }void store(float *a) { _mm256_storeu_ps(a, data); }void store(f32x8_p *a) { _mm256_store_ps((float *)a, data); }void load(float *a) { data = _mm256_loadu_ps(a); }void load(f32x8_p *a) { data = _mm256_load_ps((float *)a); }float operator[](uint8_t idx) const { return data.m256_f32[idx]; }float &operator[](uint8_t idx) { return data.m256_f32[idx]; }f32x8_p copy() const { return f32x8_p(data); }f32x8_p operator+(f32x8_p a) { return f32x8_p(_mm256_add_ps(data, a.data)); }f32x8_p operator-(f32x8_p a) { return f32x8_p(_mm256_sub_ps(data, a.data)); }f32x8_p operator*(f32x8_p a) { return f32x8_p(_mm256_mul_ps(data, a.data)); }f32x8_p operator/(f32x8_p a) { return f32x8_p(_mm256_div_ps(data, a.data)); }void operator+=(f32x8_p a) { data = _mm256_add_ps(data, a.data); }void operator-=(f32x8_p a) { data = _mm256_sub_ps(data, a.data); }void operator*=(f32x8_p a) { data = _mm256_mul_ps(data, a.data); }void operator/=(f32x8_p a) { data = _mm256_div_ps(data, a.data); }f32x8_p operator&(f32x8_p a) { return f32x8_p(_mm256_and_ps(data, a.data)); }void operator&=(f32x8_p a) { data = _mm256_and_ps(data, a.data); }f32x8_p operator|(f32x8_p a) { return f32x8_p(_mm256_or_ps(data, a.data)); }void operator|=(f32x8_p a) { data = _mm256_or_ps(data, a.data); }f32x8_p operator^(f32x8_p a) { return f32x8_p(_mm256_xor_ps(data, a.data)); }void operator^=(f32x8_p a) { data = _mm256_xor_ps(data, a.data); }static f32x8_p andnot(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_andnot_ps(a.data, b.data)); }void do_andnot(f32x8_p a) { data = _mm256_andnot_ps(data, a.data); }// xor for equal makes all bits 0bool operator==(f32x8_p a){auto int_result = _mm256_castps_si256(_mm256_xor_ps(data, a.data));// check if all bits are 0return _mm256_testz_si256(int_result, int_result) != 0;}// for j = 0 to 7, ret[j] = imm8[j]?a[j]:b[j]template <uint8_t _imm8>void blend(f32x8_p a) { data = _mm256_blend_ps(data, a.data, _imm8); }template <uint8_t _imm8>static f32x8_p blend(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_blend_ps(a.data, b.data, _imm8)); }// for j = 0 to 7, ret[j] = mask[j].signbit?a[j]:b[j]void blend(f32x8_p a, f32x8_p mask) { data = _mm256_blendv_ps(data, a.data, mask.data); }static f32x8_p blend(f32x8_p a, f32x8_p b, f32x8_p mask) { return f32x8_p(_mm256_blendv_ps(a.data, b.data, mask.data)); }/// advanced calculation// {a[0]-b[0], a[1]+b[1], a[2]-b[2], a[3]+b[3]...}void addsub(f32x8_p a) { data = _mm256_addsub_ps(data, a.data); }static f32x8_p addsub(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_addsub_ps(a.data, b.data)); }// {a[0]+a[1], a[2]+a[3], b[0]+b[1], b[2]+b[3]...}void hadd(f32x8_p a) { data = _mm256_hadd_ps(data, a.data); }static f32x8_p hadd(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_hadd_ps(a.data, b.data)); }// {a[0]-a[1], a[2]-a[3], b[0]-b[1], b[2]-b[3]...}void hsub(f32x8_p a) { data = _mm256_hsub_ps(data, a.data); }static f32x8_p hsub(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_hsub_ps(a.data, b.data)); }// {a[0]*mul[0]+add[0], a[1]*mul[1]+add[1]...}void mul_then_add(f32x8_p mul, f32x8_p add) { data = _mm256_fmadd_ps(data, mul.data, add.data); }void fmadd(f32x8_p mul, f32x8_p add) { data = _mm256_fmadd_ps(data, mul.data, add.data); }static f32x8_p mul_then_add(f32x8_p a, f32x8_p mul, f32x8_p add) { return f32x8_p(_mm256_fmadd_ps(a.data, mul.data, add.data)); }// {a[0]*mul[0]-add[0], a[1]*mul[1]-add[1]...}void mul_then_sub(f32x8_p mul, f32x8_p sub) { data = _mm256_fmsub_ps(data, mul.data, 
sub.data); }void fmsub(f32x8_p mul, f32x8_p sub) { data = _mm256_fmsub_ps(data, mul.data, sub.data); }static f32x8_p mul_then_sub(f32x8_p a, f32x8_p mul, f32x8_p sub) { return f32x8_p(_mm256_fmsub_ps(a.data, mul.data, sub.data)); }// {-a[0]*mul[0]+add[0], -a[1]*mul[1]+add[1]...}void neg_mul_then_add(f32x8_p mul, f32x8_p add) { data = _mm256_fnmadd_ps(data, mul.data, add.data); }void fnmadd(f32x8_p mul, f32x8_p add) { data = _mm256_fnmadd_ps(data, mul.data, add.data); }static f32x8_p neg_mul_then_add(f32x8_p a, f32x8_p mul, f32x8_p add) { return f32x8_p(_mm256_fnmadd_ps(a.data, mul.data, add.data)); }// {-a[0]*mul[0]-add[0], -a[1]*mul[1]-add[1]...}void neg_mul_then_sub(f32x8_p mul, f32x8_p sub) { data = _mm256_fnmsub_ps(data, mul.data, sub.data); }void fnmsub(f32x8_p mul, f32x8_p sub) { data = _mm256_fnmsub_ps(data, mul.data, sub.data); }static f32x8_p neg_mul_then_sub(f32x8_p a, f32x8_p mul, f32x8_p sub) { return f32x8_p(_mm256_fnmsub_ps(a.data, mul.data, sub.data)); }// {a[0] * mul[0] + addsub[0], a[1] * mul[1] - addsub[1]...}void mul_addsub(f32x8_p mul, f32x8_p addsub) { data = _mm256_fmaddsub_ps(data, mul.data, addsub.data); }void fmaddsub(f32x8_p mul, f32x8_p addsub) { data = _mm256_fmaddsub_ps(data, mul.data, addsub.data); }static f32x8_p mul_addsub(f32x8_p a, f32x8_p mul, f32x8_p addsub) { return f32x8_p(_mm256_fmaddsub_ps(a.data, mul.data, addsub.data)); }// {a[0] * mul[0] - subadd[0], a[1] * mul[1] + subadd[1]...}void mul_subadd(f32x8_p mul, f32x8_p subadd) { data = _mm256_fmsubadd_ps(data, mul.data, subadd.data); }void fmsubadd(f32x8_p mul, f32x8_p subadd) { data = _mm256_fmsubadd_ps(data, mul.data, subadd.data); }static f32x8_p mul_subadd(f32x8_p a, f32x8_p mul, f32x8_p subadd) { return f32x8_p(_mm256_fmsubadd_ps(a.data, mul.data, subadd.data)); }// {a[0]>b[0]?a[0]:b[0], a[1]>b[1]?a[1]:b[1]...}void maximum(f32x8_p a) { data = _mm256_max_ps(data, a.data); }static f32x8_p maximum(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_max_ps(a.data, b.data)); }// {a[0]<b[0]?a[0]:b[0], a[1]<b[1]?a[1]:b[1]...}void minimum(f32x8_p a) { data = _mm256_min_ps(data, a.data); }static f32x8_p minimum(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_min_ps(a.data, b.data)); }// see f32x4_p::shuffle. do the same shuffle for both high and low 128bit// {a[imm[1:0]], a[imm[3:2]], b[imm[5:4]], b[imm[7:6]]}template <uint8_t _imm8>void shuffle(f32x8_p a) { data = _mm256_shuffle_ps(data, a.data, _imm8); }template <uint8_t _imm8>void remix(f32x8_p a) { data = _mm256_shuffle_ps(data, a.data, _imm8); }// {a[imm[1:0]], a[imm[3:2]], b[imm[5:4]], b[imm[7:6]]}template <uint8_t _imm8>f32x8_p remixed(f32x8_p a) const { return f32x8_p(_mm256_shuffle_ps(data, a.data, _imm8)); }// {a[imm[1:0]], a[imm[3:2]], b[imm[5:4]], b[imm[7:6]]}template <uint8_t _imm8>f32x8_p remixed(f32x8_p a) { return f32x8_p(_mm256_shuffle_ps(data, a.data, _imm8)); }template <uint8_t _imm8>static f32x8_p shuffle(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_shuffle_ps(a.data, b.data, _imm8)); }// see f32x4_p::permute. 
do the same permute for both high and low 128bitvoid permute(f32x8_p a, __m256i b) { data = _mm256_permutevar_ps(a.data, b); }static f32x8_p permute(f32x8_p a, f32x8_p b, __m256i c) { return f32x8_p(_mm256_permutevar_ps(a.data, c)); }// {a[imm[1:0]], a[imm[3:2]], a[imm[5:4]], a[imm[7:6]]}template <uint8_t _imm8>void permute() { data = _mm256_permute_ps(data, _imm8); }template <uint8_t _imm8>void reorder() { data = _mm256_permute_ps(data, _imm8); }// {a[imm[1:0]], a[imm[3:2]], a[imm[5:4]], a[imm[7:6]]}template <uint8_t _imm8>f32x8_p reordered() const { return f32x8_p(_mm256_permute_ps(data, _imm8)); }// {a[imm[1:0]], a[imm[3:2]], a[imm[5:4]], a[imm[7:6]]}template <uint8_t _imm8>f32x8_p reordered() { return f32x8_p(_mm256_permute_ps(data, _imm8)); }template <uint8_t _imm8>f32x8_p permuted() { return f32x8_p(_mm256_permute_ps(data, _imm8)); }template <uint8_t _imm8>static f32x8_p permute(f32x8_p a) { return f32x8_p(_mm256_permute_ps(a.data, _imm8)); }// ret.low = imm8[2]? 0 : switch imm8[1:0] {case 0: a.low; case 1: a.high; case 2: b.low; case 3: b.high}// ret.high = imm8[6]? 0 : switch imm8[5:4] {case 0: a.low; case 1: a.high; case 2: b.low; case 3: b.high}template <uint8_t _imm8>void permute2f128(f32x8_p a) { data = _mm256_permute2f128_ps(data, a.data, _imm8); }template <uint8_t _imm8>void shuffle128(f32x8_p a) { data = permute2f128<_imm8>(data, a.data); }template <uint8_t _imm8>static f32x8_p permute2f128(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_permute2f128_ps(a.data, b.data, _imm8)); }template <uint8_t _imm8>static f32x8_p shuffle128(f32x8_p a, f32x8_p b) { return permute2f128<_imm8>(a.data, b.data); }enum compare_operation_imm8 : uint8_t{OP_CMP_EQ_OQ = 0,OP_CMP_LT_OS = 1,OP_CMP_LE_OS = 2,OP_CMP_UNORD_Q = 3,OP_CMP_NEQ_UQ = 4,OP_CMP_NLT_US = 5,OP_CMP_NLE_US = 6,OP_CMP_ORD_Q = 7,OP_CMP_EQ_UQ = 8,OP_CMP_NGE_US = 9,OP_CMP_NGT_US = 10,OP_CMP_FALSE_OQ = 11,OP_CMP_NEQ_OQ = 12,OP_CMP_GE_OS = 13,OP_CMP_GT_OS = 14,OP_CMP_TRUE_UQ = 15,OP_CMP_EQ_OS = 16,OP_CMP_LT_OQ = 17,OP_CMP_LE_OQ = 18,OP_CMP_UNORD_S = 19,OP_CMP_NEQ_US = 20,OP_CMP_NLT_UQ = 21,OP_CMP_NLE_UQ = 22,OP_CMP_ORD_S = 23,OP_CMP_EQ_US = 24,OP_CMP_NGE_UQ = 25,OP_CMP_NGT_UQ = 26,OP_CMP_FALSE_OS = 27,OP_CMP_NEQ_OS = 28,OP_CMP_GE_OQ = 29,OP_CMP_GT_OQ = 30,OP_CMP_TRUE_US = 31};template <compare_operation_imm8 _imm8>static f32x8_p compare(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_cmp_ps(a.data, b.data, _imm8)); }__m256i convert_int32p(f32x8_p a) { return _mm256_cvtps_epi32(a.data); }__m256i convert_int32p_trunc(f32x8_p a) { return _mm256_cvttps_epi32(a.data); }float get_float() const { return _mm256_cvtss_f32(data); }f32x4_p get_low() const { return f32x4_p(_mm256_extractf128_ps(data, 0)); }f32x4_p get_high() const { return f32x4_p(_mm256_extractf128_ps(data, 1)); }static void zeroall() { _mm256_zeroall(); }static void zeroupper() { _mm256_zeroupper(); }void load_broadcast(float *mem_addr) { data = _mm256_broadcast_ss(mem_addr); }void move_high_dup() { data = _mm256_movehdup_ps(data); }void move_odd2even() { data = _mm256_movehdup_ps(data); }f32x8_p copy_odd2even() { return f32x8_p(_mm256_movehdup_ps(data)); }void move_low_dup() { data = _mm256_moveldup_ps(data); }void move_even2odd() { data = _mm256_moveldup_ps(data); }f32x8_p copy_even2odd() { return f32x8_p(_mm256_moveldup_ps(data)); }void rcp() { data = _mm256_rcp_ps(data); }void rsqrt() { data = _mm256_rsqrt_ps(data); }void sqrt() { data = _mm256_sqrt_ps(data); }template <uint8_t _rounding_imm>void round() { data = _mm256_round_ps(data, _rounding_imm); }// zf: if the (a & b)sign 
are all zero, return 1, else return 0// means no sign bit pair is both 1int test_zf() { return _mm256_testz_ps(data, data); } // means all 0int test_zf(f32x8_p a) { return _mm256_testz_ps(data, a.data); }int test_if_not_both_bit1(f32x8_p a) { return _mm256_testz_ps(data, a.data); }// cf: if the (~a & b)sign are all zero, return 1, else return 0// means no sign bit pair is a 0 and b 1 (each b1 in after a1)int test_cf() { return _mm256_testc_ps(data, data); } // always 1int test_cf(f32x8_p a) { return _mm256_testc_ps(data, a.data); }int test_if_bit_contained(f32x8_p a) { return _mm256_testc_ps(data, a.data); }// (!cf) && (!zf) , only when both cf and zf are 0, return 1// means some sign bit pair is both 1 and some sign bit pair is a 0 and b 1// also means there are: >=1x a1b1 and >=1x a0b1int test_ncz() { return _mm256_testnzc_ps(data, data); } // always 0int test_ncz(f32x8_p a) { return _mm256_testnzc_ps(data, a.data); }int get_mask() const { return _mm256_movemask_ps(data); }int get_signs() const { return _mm256_movemask_ps(data); }};// 4 complex float32x2 :{c[0].re,c[0].im,c[1].re,c[1].im,c[2].re,c[2].im,c[3].re,c[3].im}struct fc32x4_p : public f32x8_p{public:fc32x4_p() : f32x8_p() {}fc32x4_p(__m256 data) : f32x8_p(data) {}fc32x4_p(f32x8_p a) : f32x8_p(a) {}fc32x4_p(float a, float b, float c, float d, float e = 0.0f, float f = 0.0f, float g = 0.0f, float h = 0.0f) : f32x8_p(a, b, c, d, e, f, g, h) {}fc32x4_p(float a) : f32x8_p(a) {}fc32x4_p(float *a) : f32x8_p(a) {}// operator +, - is the same as f32x8_p// complex mul complex, use hadd and hsubstatic fc32x4_p multiply_complex_v0(f32x8_p a_bI, f32x8_p c_dI){ // real: a*c - b*d, imag: a*d + b*cf32x8_p ac_bdI = a_bI * c_dI;f32x8_p ad_bcI = a_bI * c_dI.reordered<0b10'11'00'01>();ac_bdI.hsub(ac_bdI);ad_bcI.hadd(ad_bcI);return fc32x4_p(ac_bdI.remixed<0b11'10'01'00>(ad_bcI) // {r0,r1,i0,i1}.reordered<0b11'01'10'00>()); // {r0,i0,r1,i1}}fc32x4_p cmul_v0(fc32x4_p a) { return multiply_complex_v0(data, a.data); }// complex mul complex, use addsub, cost about 90% time of _v0static fc32x4_p multiply_complex_v1(f32x8_p a_bI, f32x8_p c_dI){ // real: a*c - b*d, imag: a*d + b*cf32x8_p a_aI = a_bI.copy_even2odd();f32x8_p b_bI = a_bI.copy_odd2even();f32x8_p ac_adI = a_aI * c_dI;f32x8_p bd_bcI = b_bI * c_dI.reordered<0b10'11'00'01>();return fc32x4_p(fc32x4_p::addsub(ac_adI, bd_bcI));}fc32x4_p cmul_v1(fc32x4_p a) { return multiply_complex_v1(data, a.data); }// complex mul complex, use mul_addsub, cost about 98% time of _v1static fc32x4_p multiply_complex_v2(f32x8_p a_bI, f32x8_p c_dI){ // real: a*c - b*d, imag: a*d + b*cf32x8_p a_aI = a_bI.copy_even2odd();f32x8_p b_bI = a_bI.copy_odd2even();//f32x8_p ac_adI = a_aI * c_dI;f32x8_p bd_bcI = b_bI * c_dI.reordered<0b10'11'00'01>();return fc32x4_p::mul_addsub(a_aI, c_dI, bd_bcI);}fc32x4_p cmul_v2(fc32x4_p a) { return multiply_complex_v2(data, a.data); }// complex mul complex, use mul_addsub, static fc32x4_p multiply_complex_v3(f32x8_p a_bI_, f32x8_p c_dI){ // real: a*c - b*d, imag: a*d + b*cf32x8_p a_bI = a_bI_.data;//f32x8_p a_aI = a_bI.copy_even2odd();//f32x8_p b_bI = a_bI.copy_odd2even();//f32x8_p ac_adI = a_aI * c_dI;//f32x8_p bd_bcI = a_bI.copy_odd2even() * c_dI.reordered<0b10'11'00'01>();return fc32x4_p::mul_addsub(a_bI.copy_even2odd(), c_dI, a_bI.copy_odd2even() * c_dI.reordered<0b10'11'00'01>());}fc32x4_p cmul_v3(fc32x4_p a) { return multiply_complex_v2(data, a.data); }};EXTERN_C{_declspec(dllexport) auto __vectorcall multiply_complex_v0(__m256 a, __m256 b){return fc32x4_p::multiply_complex_v0(a, 
b).data;}_declspec(dllexport) auto __vectorcall multiply_complex_v1(__m256 a, __m256 b){return fc32x4_p::multiply_complex_v1(a, b).data;}_declspec(dllexport) auto __vectorcall multiply_complex_v2(__m256 a, __m256 b){return fc32x4_p::multiply_complex_v2(a, b).data;}_declspec(dllexport) auto __vectorcall multiply_complex_v3(__m256 a, __m256 b){return fc32x4_p::multiply_complex_v3(a, b).data;};};#ifdef namespace_bionukg
}
#endif
#endif