本文主要是介绍如何计算lstm网络的复杂度 乘法次数 flops(未完成),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
计算算法复杂度的一种方法是计算前向算法的乘法次数,因此在这篇文章中计算复杂度的方法是计算lstm网络的乘法运算次数
首先要弄清楚lstm的cell中的乘法次数
cell有4个输入,一个输出
有三个门,input gate控制数据输不输进来,不输入就输入0
forget gate控制保存单元memory cell更不更新,不更新就维持原状
output gate控制计算的值输不输出,不输出就输出0
举例
x是input y是output
lstm的cell有4个input,就是矢量x+bias,乘上weight,求和
以信号处理为例,batch为1,全部省略,此时的x大小[1,Nv]
接下来是用别人写好的库计算,用的是pytorch OpCounter
GitHub - Lyken17/pytorch-OpCounter: Count the MACs / FLOPs of your PyTorch model.
按照github上的提示
pip install thop
然后就可以运行下面的例子了
这篇关于如何计算lstm网络的复杂度 乘法次数 flops(未完成)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!