本文主要是介绍详解矩阵乘法中的Strassen算法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
机器学习中需要训练大量数据,涉及大量复杂运算,例如卷积、矩阵等。这些复杂运算不仅多,而且每次计算的数据量很大,如果能针对这些运算进行优化,可以大幅提高性能。
一、矩阵乘法
如下图所示:
Figure 1 Matrix Multiplication
二、Strassen算法
Figure 2 x^3 vs. x^2.807
三、Strassen原理详解
Strassen算法正是从这个角度出发,实现了降低算法复杂度!
实现步骤可以分为以下4步:
3.1 Strassen实现步骤
四、Strassen算法的代码实现
我们以MNN中关于Strassen算法源码实现来学习:https://github.com/alibaba/MNN/blob/master/source/backend/cpu/compute/StrassenMatmulComputor.cpp。
类StrassenMatrixComputor提供了3个API供调用:
_generateTrivalMatMul(const Tensor* AT, const Tensor* BT, const Tensor* CT);
普通矩阵乘法计算
_generateMatMul(const Tensor* AT, const Tensor* BT, const Tensor* CT, int currentDepth);
Strassen算法的矩阵乘法
_generateMatMulConstB(const Tensor* AT, const Tensor* BT, const Tensor* CT, int currentDepth);
Strassen算法的矩阵乘法(和MatMul的区别在于内存Buffer是否允许复用)
我们以_generateMatMul为例来学习下Strassen算法如何实现,可以分成如下几步:
第一步:使用Strassen算法收益判断
在矩阵操作中,因为需要对矩阵的维数进行扩展,涉及大量读写操作,这些读写操作都需要大量循环,如果读写次数超出使用Strassen乘法的收益的话,就得不偿失了,那么就使用普通的矩阵乘法。
/*Compute the memory read / write cost for expandMatrix Mul need eSub*lSub*hSub*(1+1.0/CONVOLUTION_TILED_NUMBWR), Matrix Add/Sub need x*y*UNIT*3 (2 read 1 write)*/float saveCost =(eSub * lSub * hSub) * (1.0f + 1.0f / CONVOLUTION_TILED_NUMBWR) - 4 * (eSub * lSub) * 3 - 7 * (eSub * hSub * 3);if (currentDepth >= mMaxDepth || e <= CONVOLUTION_TILED_NUMBWR || l % 2 != 0 || h % 2 != 0 || saveCost < 0.0f) {return _generateTrivialMatMul(AT, BT, CT);}
第二步:分块
auto aStride = AT->stride(0);auto a11 = AT->host<float>() + 0 * aUnit * eSub + 0 * aStride * lSub;auto a12 = AT->host<float>() + 0 * aUnit * eSub + 1 * aStride * lSub;auto a21 = AT->host<float>() + 1 * aUnit * eSub + 0 * aStride * lSub;auto a22 = AT->host<float>() + 1 * aUnit * eSub + 1 * aStride * lSub;auto bStride = BT->stride(0);auto b11 = BT->host<float>() + 0 * bUnit * lSub + 0 * bStride * hSub;auto b12 = BT->host<float>() + 0 * bUnit * lSub + 1 * bStride * hSub;auto b21 = BT->host<float>() + 1 * bUnit * lSub + 0 * bStride * hSub;auto b22 = BT->host<float>() + 1 * bUnit * lSub + 1 * bStride * hSub;auto cStride = CT->stride(0);auto c11 = CT->host<float>() + 0 * aUnit * eSub + 0 * cStride * hSub;auto c12 = CT->host<float>() + 0 * aUnit * eSub + 1 * cStride * hSub;auto c21 = CT->host<float>() + 1 * aUnit * eSub + 0 * cStride * hSub;auto c22 = CT->host<float>() + 1 * aUnit * eSub + 1 * cStride * hSub;
第三步:分治和递归
Strassen算法核心就是分治思想。这一步可以写成下列所示伪代码:
1. If n = 1 Output A × B
2. Else
3. Compute A11,B11, . . . ,A22,B22 % by computing m = n/2
4. P1 Strassen(A11,B12 − B22)
5. P2 Strassen(A11 + A12,B22)
6. P3 Strassen(A21 + A22,B11)
7. P4 Strassen(A22,B21 − B11)
8. P5 Strassen(A11 + A22,B11 + B22)
9. P6 Strassen(A12 − A22,B21 + B22)
10. P7 Strassen(A11 − A21,B11 + B12)
11. C11 P5 + P4 − P2 + P6
12. C12 P1 + P2
13. C21 P3 + P4
14. C22 P1 + P5 − P3 − P7
15. Output C
16. End If
例如其中的一步代码如下所示:
{// S1=A21+A22, T1=B12-B11, P5=S1T1auto f = [a22, a21, b11, b12, xAddr, yAddr, eSub, lSub, hSub, aStride, bStride]() {MNNMatrixAdd(xAddr, a21, a22, eSub * aUnit / 4, eSub * aUnit, aStride, aStride, lSub);MNNMatrixSub(yAddr, b12, b11, lSub * bUnit / 4, lSub * bUnit, bStride, bStride, hSub);};mFunctions.emplace_back(f);auto code = _generateMatMul(X.get(), Y.get(), C22.get(), currentDepth);if (code != NO_ERROR) {return code;}}
递归执行,得到最终结果!
这篇关于详解矩阵乘法中的Strassen算法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!