详解矩阵乘法中的Strassen算法

2024-06-02 16:38

本文主要是介绍详解矩阵乘法中的Strassen算法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

 

机器学习中需要训练大量数据,涉及大量复杂运算,例如卷积、矩阵等。这些复杂运算不仅多,而且每次计算的数据量很大,如果能针对这些运算进行优化,可以大幅提高性能。

一、矩阵乘法

如下图所示:

Figure 1 Matrix Multiplication

二、Strassen算法

Figure 2 x^3 vs. x^2.807

三、Strassen原理详解

Strassen算法正是从这个角度出发,实现了降低算法复杂度!

实现步骤可以分为以下4步:

3.1 Strassen实现步骤

 

四、Strassen算法的代码实现

我们以MNN中关于Strassen算法源码实现来学习:https://github.com/alibaba/MNN/blob/master/source/backend/cpu/compute/StrassenMatmulComputor.cpp。

类StrassenMatrixComputor提供了3个API供调用:

_generateTrivalMatMul(const Tensor* AT, const Tensor* BT, const Tensor* CT);

普通矩阵乘法计算

_generateMatMul(const Tensor* AT, const Tensor* BT, const Tensor* CT, int currentDepth);

Strassen算法的矩阵乘法

_generateMatMulConstB(const Tensor* AT, const Tensor* BT, const Tensor* CT, int currentDepth);

Strassen算法的矩阵乘法(和MatMul的区别在于内存Buffer是否允许复用)

我们以_generateMatMul为例来学习下Strassen算法如何实现,可以分成如下几步:

第一步:使用Strassen算法收益判断

在矩阵操作中,因为需要对矩阵的维数进行扩展,涉及大量读写操作,这些读写操作都需要大量循环,如果读写次数超出使用Strassen乘法的收益的话,就得不偿失了,那么就使用普通的矩阵乘法

    /*Compute the memory read / write cost for expandMatrix Mul need eSub*lSub*hSub*(1+1.0/CONVOLUTION_TILED_NUMBWR), Matrix Add/Sub need x*y*UNIT*3 (2 read 1 write)*/float saveCost =(eSub * lSub * hSub) * (1.0f + 1.0f / CONVOLUTION_TILED_NUMBWR) - 4 * (eSub * lSub) * 3 - 7 * (eSub * hSub * 3);if (currentDepth >= mMaxDepth || e <= CONVOLUTION_TILED_NUMBWR || l % 2 != 0 || h % 2 != 0 || saveCost < 0.0f) {return _generateTrivialMatMul(AT, BT, CT);}

第二步:分块

    auto aStride = AT->stride(0);auto a11     = AT->host<float>() + 0 * aUnit * eSub + 0 * aStride * lSub;auto a12     = AT->host<float>() + 0 * aUnit * eSub + 1 * aStride * lSub;auto a21     = AT->host<float>() + 1 * aUnit * eSub + 0 * aStride * lSub;auto a22     = AT->host<float>() + 1 * aUnit * eSub + 1 * aStride * lSub;auto bStride = BT->stride(0);auto b11     = BT->host<float>() + 0 * bUnit * lSub + 0 * bStride * hSub;auto b12     = BT->host<float>() + 0 * bUnit * lSub + 1 * bStride * hSub;auto b21     = BT->host<float>() + 1 * bUnit * lSub + 0 * bStride * hSub;auto b22     = BT->host<float>() + 1 * bUnit * lSub + 1 * bStride * hSub;auto cStride = CT->stride(0);auto c11     = CT->host<float>() + 0 * aUnit * eSub + 0 * cStride * hSub;auto c12     = CT->host<float>() + 0 * aUnit * eSub + 1 * cStride * hSub;auto c21     = CT->host<float>() + 1 * aUnit * eSub + 0 * cStride * hSub;auto c22     = CT->host<float>() + 1 * aUnit * eSub + 1 * cStride * hSub;

第三步:分治和递归

Strassen算法核心就是分治思想。这一步可以写成下列所示伪代码:

1. If n = 1 Output A × B
2. Else
3. Compute A11,B11, . . . ,A22,B22 % by computing m = n/2
4. P1   Strassen(A11,B12 − B22)
5. P2   Strassen(A11 + A12,B22)
6. P3   Strassen(A21 + A22,B11)
7. P4   Strassen(A22,B21 − B11)
8. P5   Strassen(A11 + A22,B11 + B22)
9. P6   Strassen(A12 − A22,B21 + B22)
10. P7   Strassen(A11 − A21,B11 + B12)
11. C11   P5 + P4 − P2 + P6
12. C12   P1 + P2
13. C21   P3 + P4
14. C22   P1 + P5 − P3 − P7
15. Output C
16. End If

例如其中的一步代码如下所示:

   {// S1=A21+A22, T1=B12-B11, P5=S1T1auto f = [a22, a21, b11, b12, xAddr, yAddr, eSub, lSub, hSub, aStride, bStride]() {MNNMatrixAdd(xAddr, a21, a22, eSub * aUnit / 4, eSub * aUnit, aStride, aStride, lSub);MNNMatrixSub(yAddr, b12, b11, lSub * bUnit / 4, lSub * bUnit, bStride, bStride, hSub);};mFunctions.emplace_back(f);auto code = _generateMatMul(X.get(), Y.get(), C22.get(), currentDepth);if (code != NO_ERROR) {return code;}}

递归执行,得到最终结果!

这篇关于详解矩阵乘法中的Strassen算法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1024517

相关文章

详解如何通过Python批量转换图片为PDF

《详解如何通过Python批量转换图片为PDF》:本文主要介绍如何基于Python+Tkinter开发的图片批量转PDF工具,可以支持批量添加图片,拖拽等操作,感兴趣的小伙伴可以参考一下... 目录1. 概述2. 功能亮点2.1 主要功能2.2 界面设计3. 使用指南3.1 运行环境3.2 使用步骤4. 核

一文详解JavaScript中的fetch方法

《一文详解JavaScript中的fetch方法》fetch函数是一个用于在JavaScript中执行HTTP请求的现代API,它提供了一种更简洁、更强大的方式来处理网络请求,:本文主要介绍Jav... 目录前言什么是 fetch 方法基本语法简单的 GET 请求示例代码解释发送 POST 请求示例代码解释

springboot+dubbo实现时间轮算法

《springboot+dubbo实现时间轮算法》时间轮是一种高效利用线程资源进行批量化调度的算法,本文主要介绍了springboot+dubbo实现时间轮算法,文中通过示例代码介绍的非常详细,对大家... 目录前言一、参数说明二、具体实现1、HashedwheelTimer2、createWheel3、n

详解nginx 中location和 proxy_pass的匹配规则

《详解nginx中location和proxy_pass的匹配规则》location是Nginx中用来匹配客户端请求URI的指令,决定如何处理特定路径的请求,它定义了请求的路由规则,后续的配置(如... 目录location 的作用语法示例:location /www.chinasem.cntestproxy

CSS will-change 属性示例详解

《CSSwill-change属性示例详解》will-change是一个CSS属性,用于告诉浏览器某个元素在未来可能会发生哪些变化,本文给大家介绍CSSwill-change属性详解,感... will-change 是一个 css 属性,用于告诉浏览器某个元素在未来可能会发生哪些变化。这可以帮助浏览器优化

Python基础文件操作方法超详细讲解(详解版)

《Python基础文件操作方法超详细讲解(详解版)》文件就是操作系统为用户或应用程序提供的一个读写硬盘的虚拟单位,文件的核心操作就是读和写,:本文主要介绍Python基础文件操作方法超详细讲解的相... 目录一、文件操作1. 文件打开与关闭1.1 打开文件1.2 关闭文件2. 访问模式及说明二、文件读写1.

详解C++中类的大小决定因数

《详解C++中类的大小决定因数》类的大小受多个因素影响,主要包括成员变量、对齐方式、继承关系、虚函数表等,下面就来介绍一下,具有一定的参考价值,感兴趣的可以了解一下... 目录1. 非静态数据成员示例:2. 数据对齐(Padding)示例:3. 虚函数(vtable 指针)示例:4. 继承普通继承虚继承5.

前端高级CSS用法示例详解

《前端高级CSS用法示例详解》在前端开发中,CSS(层叠样式表)不仅是用来控制网页的外观和布局,更是实现复杂交互和动态效果的关键技术之一,随着前端技术的不断发展,CSS的用法也日益丰富和高级,本文将深... 前端高级css用法在前端开发中,CSS(层叠样式表)不仅是用来控制网页的外观和布局,更是实现复杂交

Linux换行符的使用方法详解

《Linux换行符的使用方法详解》本文介绍了Linux中常用的换行符LF及其在文件中的表示,展示了如何使用sed命令替换换行符,并列举了与换行符处理相关的Linux命令,通过代码讲解的非常详细,需要的... 目录简介检测文件中的换行符使用 cat -A 查看换行符使用 od -c 检查字符换行符格式转换将

详解C#如何提取PDF文档中的图片

《详解C#如何提取PDF文档中的图片》提取图片可以将这些图像资源进行单独保存,方便后续在不同的项目中使用,下面我们就来看看如何使用C#通过代码从PDF文档中提取图片吧... 当 PDF 文件中包含有价值的图片,如艺术画作、设计素材、报告图表等,提取图片可以将这些图像资源进行单独保存,方便后续在不同的项目中使