PyTorch模型参数量计算【使用torchsummary库与自定义 两种方法!附完整代码!!】

本文主要是介绍PyTorch模型参数量计算【使用torchsummary库与自定义 两种方法!附完整代码!!】,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

深度学习PyTorch模型参数量计算示例

在深度学习中,有时候处于分析模型的需要或者写文章的需要,得到模型的参数量对于理解模型复杂度、进行内存管理以及模型优化都至关重要。PyTorch作为当前流行的深度学习框架,为我们提供了计算模型参数量的工具和方法。下面将通过两个示例,详细展示如何在PyTorch中计算模型的参数量。

示例一:基础模型参数量计算

首先,我们创建一个简单的PyTorch模型,该模型包含一个卷积层、一个ReLU激活函数层和一个全连接层。然后,我们将使用PyTorch的torchsummary库来计算模型的参数量。

import torch
import torch.nn as nn
from torchsummary import summary# 定义一个示例模型
class ExampleModel(nn.Module):def __init__(self):super(ExampleModel, self).__init__()self.conv1 = nn.Conv2d(3, 16, 3, padding=1)self.conv2 = nn.Conv2d(16, 32, 3, padding=1)self.fc = nn.Linear(32 * 8 * 8, 10)def forward(self, x1, x2, x3):# 处理第一个图像x = torch.relu(self.conv1(x1))x = torch.relu(self.conv2(x))x = x.view(-1, 32 * 8 * 8)x = self.fc(x)# 处理第二个图像y = torch.relu(self.conv1(x2))y = torch.relu(self.conv2(y))y = y.view(-1, 32 * 8 * 8)y = self.fc(y)# 处理第三个图像z = torch.relu(self.conv1(x3))z = torch.relu(self.conv2(z))z = z.view(-1, 32 * 8 * 8)z = self.fc(z)return x, y, z# 创建一个示例模型实例
model = ExampleModel()# 将模型移动到 CUDA 设备上
device = torch.device('cuda')
model.to(device)# 模拟输入,假设每张图像大小为 3x32x32
image1 = torch.randn(1, 3, 32, 32).to(device)
image2 = torch.randn(1, 3, 32, 32).to(device)
image3 = torch.randn(1, 3, 32, 32).to(device)# 打印模型摘要
summary(model, [(3, 32, 32), (3, 32, 32), (3, 32, 32)])  # 传递每个图像的输入大小

在上面的代码中,我们首先定义了一个简单的示例模型ExampleModel,然后使用torchsummary库的summary函数来计算模型的参数量。summary函数需要两个参数:模型实例和输入数据的形状。执行这段代码后,将会输出模型的每一层的详细信息,包括输出大小、参数量等。
在这里插入图片描述

示例二:自定义函数计算参数量

除了使用torchsummary库,我们还可以自定义一个函数来计算模型的参数量。这样做的好处是更加灵活,可以根据需要定制输出信息。

import torch
import torch.nn as nn# 定义一个示例模型
class ExampleModel(nn.Module):def __init__(self):super(ExampleModel, self).__init__()self.conv1 = nn.Conv2d(3, 16, 3, padding=1)self.conv2 = nn.Conv2d(16, 32, 3, padding=1)self.fc = nn.Linear(32 * 8 * 8, 10)def forward(self, x):x = torch.relu(self.conv1(x))x = torch.relu(self.conv2(x))x = x.view(-1, 32 * 8 * 8)x = self.fc(x)return x# 创建一个示例模型实例
model = ExampleModel()# 计算模型参数大小
total_params = sum(p.numel() for p in model.parameters())
print("Total parameters:", total_params)

在这个示例中,我们定义了一个sum(p.numel() for p in model.parameters())操作,它遍历模型的所有参数,并计算需要梯度的参数的总数。numel()函数返回张量中的元素总数。最后,我们打印出模型的参数量。
在这里插入图片描述

总结

通过以上两个示例,我们展示了如何在PyTorch中计算模型的参数量。第一个示例使用了torchsummary库,它提供了详细的模型层信息以及参数量的统计;第二个示例则通过自定义函数来实现参数量的计算,更加灵活可控。在实际项目中,大家根据具体需求选择合适的方法来计算模型参数量,有助于更好地理解和优化模型。

版权声明

本博客内容仅供学习交流,转载请注明出处。

这篇关于PyTorch模型参数量计算【使用torchsummary库与自定义 两种方法!附完整代码!!】的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/945453

相关文章

Python通用唯一标识符模块uuid使用案例详解

《Python通用唯一标识符模块uuid使用案例详解》Pythonuuid模块用于生成128位全局唯一标识符,支持UUID1-5版本,适用于分布式系统、数据库主键等场景,需注意隐私、碰撞概率及存储优... 目录简介核心功能1. UUID版本2. UUID属性3. 命名空间使用场景1. 生成唯一标识符2. 数

Java中读取YAML文件配置信息常见问题及解决方法

《Java中读取YAML文件配置信息常见问题及解决方法》:本文主要介绍Java中读取YAML文件配置信息常见问题及解决方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要... 目录1 使用Spring Boot的@ConfigurationProperties2. 使用@Valu

创建Java keystore文件的完整指南及详细步骤

《创建Javakeystore文件的完整指南及详细步骤》本文详解Java中keystore的创建与配置,涵盖私钥管理、自签名与CA证书生成、SSL/TLS应用,强调安全存储及验证机制,确保通信加密和... 目录1. 秘密键(私钥)的理解与管理私钥的定义与重要性私钥的管理策略私钥的生成与存储2. 证书的创建与

SpringBoot中如何使用Assert进行断言校验

《SpringBoot中如何使用Assert进行断言校验》Java提供了内置的assert机制,而Spring框架也提供了更强大的Assert工具类来帮助开发者进行参数校验和状态检查,下... 目录前言一、Java 原生assert简介1.1 使用方式1.2 示例代码1.3 优缺点分析二、Spring Fr

Android kotlin中 Channel 和 Flow 的区别和选择使用场景分析

《Androidkotlin中Channel和Flow的区别和选择使用场景分析》Kotlin协程中,Flow是冷数据流,按需触发,适合响应式数据处理;Channel是热数据流,持续发送,支持... 目录一、基本概念界定FlowChannel二、核心特性对比数据生产触发条件生产与消费的关系背压处理机制生命周期

java使用protobuf-maven-plugin的插件编译proto文件详解

《java使用protobuf-maven-plugin的插件编译proto文件详解》:本文主要介绍java使用protobuf-maven-plugin的插件编译proto文件,具有很好的参考价... 目录protobuf文件作为数据传输和存储的协议主要介绍在Java使用maven编译proto文件的插件

Java 方法重载Overload常见误区及注意事项

《Java方法重载Overload常见误区及注意事项》Java方法重载允许同一类中同名方法通过参数类型、数量、顺序差异实现功能扩展,提升代码灵活性,核心条件为参数列表不同,不涉及返回类型、访问修饰符... 目录Java 方法重载(Overload)详解一、方法重载的核心条件二、构成方法重载的具体情况三、不构

SpringBoot线程池配置使用示例详解

《SpringBoot线程池配置使用示例详解》SpringBoot集成@Async注解,支持线程池参数配置(核心数、队列容量、拒绝策略等)及生命周期管理,结合监控与任务装饰器,提升异步处理效率与系统... 目录一、核心特性二、添加依赖三、参数详解四、配置线程池五、应用实践代码说明拒绝策略(Rejected

C++ Log4cpp跨平台日志库的使用小结

《C++Log4cpp跨平台日志库的使用小结》Log4cpp是c++类库,本文详细介绍了C++日志库log4cpp的使用方法,及设置日志输出格式和优先级,具有一定的参考价值,感兴趣的可以了解一下... 目录一、介绍1. log4cpp的日志方式2.设置日志输出的格式3. 设置日志的输出优先级二、Window

SQL中如何添加数据(常见方法及示例)

《SQL中如何添加数据(常见方法及示例)》SQL全称为StructuredQueryLanguage,是一种用于管理关系数据库的标准编程语言,下面给大家介绍SQL中如何添加数据,感兴趣的朋友一起看看吧... 目录在mysql中,有多种方法可以添加数据。以下是一些常见的方法及其示例。1. 使用INSERT I