本文主要是介绍pytorch计算网络参数量和Flops,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
from torchsummary import summary
summary(net, input_size=(3, 256, 256), batch_size=-1)
输出的参数是除以一百万(/1000000)M,
from fvcore.nn import FlopCountAnalysis
inputs = torch.randn(1, 3, 256, 256).cuda()
flop_counter = FlopCountAnalysis(net, inputs)
print(f"FLOPs: {flop_counter.total()}")
输出的参数是B,(/1024/1024/1024)G,(/1024/1024/1024/1024)T
这篇关于pytorch计算网络参数量和Flops的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!