详解矩阵乘法中的Strassen算法

2024-06-02 16:38

本文主要是介绍详解矩阵乘法中的Strassen算法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

 

机器学习中需要训练大量数据,涉及大量复杂运算,例如卷积、矩阵等。这些复杂运算不仅多,而且每次计算的数据量很大,如果能针对这些运算进行优化,可以大幅提高性能。

一、矩阵乘法

如下图所示:

Figure 1 Matrix Multiplication

二、Strassen算法

Figure 2 x^3 vs. x^2.807

三、Strassen原理详解

Strassen算法正是从这个角度出发,实现了降低算法复杂度!

实现步骤可以分为以下4步:

3.1 Strassen实现步骤

 

四、Strassen算法的代码实现

我们以MNN中关于Strassen算法源码实现来学习:https://github.com/alibaba/MNN/blob/master/source/backend/cpu/compute/StrassenMatmulComputor.cpp。

类StrassenMatrixComputor提供了3个API供调用:

_generateTrivalMatMul(const Tensor* AT, const Tensor* BT, const Tensor* CT);

普通矩阵乘法计算

_generateMatMul(const Tensor* AT, const Tensor* BT, const Tensor* CT, int currentDepth);

Strassen算法的矩阵乘法

_generateMatMulConstB(const Tensor* AT, const Tensor* BT, const Tensor* CT, int currentDepth);

Strassen算法的矩阵乘法(和MatMul的区别在于内存Buffer是否允许复用)

我们以_generateMatMul为例来学习下Strassen算法如何实现,可以分成如下几步:

第一步:使用Strassen算法收益判断

在矩阵操作中,因为需要对矩阵的维数进行扩展,涉及大量读写操作,这些读写操作都需要大量循环,如果读写次数超出使用Strassen乘法的收益的话,就得不偿失了,那么就使用普通的矩阵乘法

    /*Compute the memory read / write cost for expandMatrix Mul need eSub*lSub*hSub*(1+1.0/CONVOLUTION_TILED_NUMBWR), Matrix Add/Sub need x*y*UNIT*3 (2 read 1 write)*/float saveCost =(eSub * lSub * hSub) * (1.0f + 1.0f / CONVOLUTION_TILED_NUMBWR) - 4 * (eSub * lSub) * 3 - 7 * (eSub * hSub * 3);if (currentDepth >= mMaxDepth || e <= CONVOLUTION_TILED_NUMBWR || l % 2 != 0 || h % 2 != 0 || saveCost < 0.0f) {return _generateTrivialMatMul(AT, BT, CT);}

第二步:分块

    auto aStride = AT->stride(0);auto a11     = AT->host<float>() + 0 * aUnit * eSub + 0 * aStride * lSub;auto a12     = AT->host<float>() + 0 * aUnit * eSub + 1 * aStride * lSub;auto a21     = AT->host<float>() + 1 * aUnit * eSub + 0 * aStride * lSub;auto a22     = AT->host<float>() + 1 * aUnit * eSub + 1 * aStride * lSub;auto bStride = BT->stride(0);auto b11     = BT->host<float>() + 0 * bUnit * lSub + 0 * bStride * hSub;auto b12     = BT->host<float>() + 0 * bUnit * lSub + 1 * bStride * hSub;auto b21     = BT->host<float>() + 1 * bUnit * lSub + 0 * bStride * hSub;auto b22     = BT->host<float>() + 1 * bUnit * lSub + 1 * bStride * hSub;auto cStride = CT->stride(0);auto c11     = CT->host<float>() + 0 * aUnit * eSub + 0 * cStride * hSub;auto c12     = CT->host<float>() + 0 * aUnit * eSub + 1 * cStride * hSub;auto c21     = CT->host<float>() + 1 * aUnit * eSub + 0 * cStride * hSub;auto c22     = CT->host<float>() + 1 * aUnit * eSub + 1 * cStride * hSub;

第三步:分治和递归

Strassen算法核心就是分治思想。这一步可以写成下列所示伪代码:

1. If n = 1 Output A × B
2. Else
3. Compute A11,B11, . . . ,A22,B22 % by computing m = n/2
4. P1   Strassen(A11,B12 − B22)
5. P2   Strassen(A11 + A12,B22)
6. P3   Strassen(A21 + A22,B11)
7. P4   Strassen(A22,B21 − B11)
8. P5   Strassen(A11 + A22,B11 + B22)
9. P6   Strassen(A12 − A22,B21 + B22)
10. P7   Strassen(A11 − A21,B11 + B12)
11. C11   P5 + P4 − P2 + P6
12. C12   P1 + P2
13. C21   P3 + P4
14. C22   P1 + P5 − P3 − P7
15. Output C
16. End If

例如其中的一步代码如下所示:

   {// S1=A21+A22, T1=B12-B11, P5=S1T1auto f = [a22, a21, b11, b12, xAddr, yAddr, eSub, lSub, hSub, aStride, bStride]() {MNNMatrixAdd(xAddr, a21, a22, eSub * aUnit / 4, eSub * aUnit, aStride, aStride, lSub);MNNMatrixSub(yAddr, b12, b11, lSub * bUnit / 4, lSub * bUnit, bStride, bStride, hSub);};mFunctions.emplace_back(f);auto code = _generateMatMul(X.get(), Y.get(), C22.get(), currentDepth);if (code != NO_ERROR) {return code;}}

递归执行,得到最终结果!

这篇关于详解矩阵乘法中的Strassen算法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1024517

相关文章

Mysql 中的多表连接和连接类型详解

《Mysql中的多表连接和连接类型详解》这篇文章详细介绍了MySQL中的多表连接及其各种类型,包括内连接、左连接、右连接、全外连接、自连接和交叉连接,通过这些连接方式,可以将分散在不同表中的相关数据... 目录什么是多表连接?1. 内连接(INNER JOIN)2. 左连接(LEFT JOIN 或 LEFT

Java中switch-case结构的使用方法举例详解

《Java中switch-case结构的使用方法举例详解》:本文主要介绍Java中switch-case结构使用的相关资料,switch-case结构是Java中处理多个分支条件的一种有效方式,它... 目录前言一、switch-case结构的基本语法二、使用示例三、注意事项四、总结前言对于Java初学者

Linux内核之内核裁剪详解

《Linux内核之内核裁剪详解》Linux内核裁剪是通过移除不必要的功能和模块,调整配置参数来优化内核,以满足特定需求,裁剪的方法包括使用配置选项、模块化设计和优化配置参数,图形裁剪工具如makeme... 目录简介一、 裁剪的原因二、裁剪的方法三、图形裁剪工具四、操作说明五、make menuconfig

详解Java中的敏感信息处理

《详解Java中的敏感信息处理》平时开发中常常会遇到像用户的手机号、姓名、身份证等敏感信息需要处理,这篇文章主要为大家整理了一些常用的方法,希望对大家有所帮助... 目录前后端传输AES 对称加密RSA 非对称加密混合加密数据库加密MD5 + Salt/SHA + SaltAES 加密平时开发中遇到像用户的

Springboot使用RabbitMQ实现关闭超时订单(示例详解)

《Springboot使用RabbitMQ实现关闭超时订单(示例详解)》介绍了如何在SpringBoot项目中使用RabbitMQ实现订单的延时处理和超时关闭,通过配置RabbitMQ的交换机、队列和... 目录1.maven中引入rabbitmq的依赖:2.application.yml中进行rabbit

C语言线程池的常见实现方式详解

《C语言线程池的常见实现方式详解》本文介绍了如何使用C语言实现一个基本的线程池,线程池的实现包括工作线程、任务队列、任务调度、线程池的初始化、任务添加、销毁等步骤,感兴趣的朋友跟随小编一起看看吧... 目录1. 线程池的基本结构2. 线程池的实现步骤3. 线程池的核心数据结构4. 线程池的详细实现4.1 初

Python绘制土地利用和土地覆盖类型图示例详解

《Python绘制土地利用和土地覆盖类型图示例详解》本文介绍了如何使用Python绘制土地利用和土地覆盖类型图,并提供了详细的代码示例,通过安装所需的库,准备地理数据,使用geopandas和matp... 目录一、所需库的安装二、数据准备三、绘制土地利用和土地覆盖类型图四、代码解释五、其他可视化形式1.

SpringBoot使用Apache POI库读取Excel文件的操作详解

《SpringBoot使用ApachePOI库读取Excel文件的操作详解》在日常开发中,我们经常需要处理Excel文件中的数据,无论是从数据库导入数据、处理数据报表,还是批量生成数据,都可能会遇到... 目录项目背景依赖导入读取Excel模板的实现代码实现代码解析ExcelDemoInfoDTO 数据传输

如何用Java结合经纬度位置计算目标点的日出日落时间详解

《如何用Java结合经纬度位置计算目标点的日出日落时间详解》这篇文章主详细讲解了如何基于目标点的经纬度计算日出日落时间,提供了在线API和Java库两种计算方法,并通过实际案例展示了其应用,需要的朋友... 目录前言一、应用示例1、天安门升旗时间2、湖南省日出日落信息二、Java日出日落计算1、在线API2

使用Spring Cache时设置缓存键的注意事项详解

《使用SpringCache时设置缓存键的注意事项详解》在现代的Web应用中,缓存是提高系统性能和响应速度的重要手段之一,Spring框架提供了强大的缓存支持,通过​​@Cacheable​​、​​... 目录引言1. 缓存键的基本概念2. 默认缓存键生成器3. 自定义缓存键3.1 使用​​@Cacheab