本文主要是介绍torchsummary 简单使用,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
文章目录
- 1. torchsummary
- 2. 代码
- 3. 结果
1. torchsummary
听说是一个大神写的开源插件,实在是太牛逼了,以前只有TensorFlow有,现在pytorch也有了。
torchsummay_github链接
2. 代码
# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: device_test
# @Create time: 2022/3/9 8:45
import torch
from torch import nn
from torchsummary import summaryclass My_Model(nn.Module):def __init__(self):super(My_Model, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, X):X = self.flatten(X)logist = self.linear_relu_stack(X)return logist# 实例化神经网络
new_model = My_Model()
# 定义输入矩阵x
x = torch.ones(2, 28, 28)
# 将输入矩阵进入网络后输出
y = new_model(x)
# 得到summary
summary(new_model, input_data=x, device='cpu')
3. 结果
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
├─Flatten: 1-1 [-1, 784] --
├─Sequential: 1-2 [-1, 10] --
| └─Linear: 2-1 [-1, 512] 401,920
| └─ReLU: 2-2 [-1, 512] --
| └─Linear: 2-3 [-1, 512] 262,656
| └─ReLU: 2-4 [-1, 512] --
| └─Linear: 2-5 [-1, 10] 5,130
==========================================================================================
Total params: 669,706
Trainable params: 669,706
Non-trainable params: 0
Total mult-adds (M): 1.34
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.01
Params size (MB): 2.55
Estimated Total Size (MB): 2.57
==========================================================================================
这篇关于torchsummary 简单使用的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!