b站小土堆pytorch学习记录—— P25-P26 网络模型的使用和修改、保存和读取

本文主要是介绍b站小土堆pytorch学习记录—— P25-P26 网络模型的使用和修改、保存和读取,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 一、修改
    • 1.方法
    • 2.代码
  • 二、保存和读取
    • 1.方法
    • 2.代码
      • (1)保存
      • (2)加载
    • 3.陷阱

一、修改

1.方法

add_module(name: str, module: Module) -> None

name 是要添加的子模块的名称。
module 是要添加的子模块。
调用 add_module 方法会向当前模块中添加一个子模块,并使用指定的名称进行标识。

2.代码

import torchvision
from torch import nn# 实例化一个未经过预训练的 VGG16 模型
vgg16_false = torchvision.models.vgg16(pretrained=False)# 实例化一个经过预训练的 VGG16 模型
vgg16_true = torchvision.models.vgg16(pretrained=True)print("ok")# 输出经过预训练的 VGG16 模型及修改后的模型
print(vgg16_true)
vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))
print(vgg16_true)# 输出未经过预训练的 VGG16 模型及修改后的模型
print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)

修改前的vgg16_true:

在这里插入图片描述
修改后的vgg16_true:

在这里插入图片描述

修改前的vgg16_true:

在这里插入图片描述

修改后的vgg16_true:

在这里插入图片描述

二、保存和读取

1.方法

保存: torch.save(要保存的模型,“文件路径”)

加载: torch.load(“文件路径”)

2.代码

(1)保存

import torch
import torchvisionvgg16 = torchvision.models.vgg16(pretrained=False)# 保存方式1:模型结构+模型参数
torch.save(vgg16, "vgg16_module1.pth")# 保存方式2:模型参数(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_module2.pth")

(2)加载

import torch
import torchvision# 方式1 加载模型
module1 = torch.load("vgg16_module1.pth")
print(module1)#
module2 = torch.load("vgg16_module2.pth")
print(module2)# 方式2 加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_module2.pth"))
print(vgg16)

运行加载的代码后,打印结果如下

module1:

在这里插入图片描述
module2:

在这里插入图片描述

vgg16:

在这里插入图片描述

可以看到,第二种方式保存的数据,加载后是向量形式,需要通过别的方法加载为模型

3.陷阱

第一种方式加载,在某些条件下可能会报错

例如:

假设自定义一个神经网络,保存:

import torch
import torchvision
from torch import nn# 陷阱
class Guodong(nn.Module):def __init__(self):super(Guodong,self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3)def forward(self,x):x = self.conv1(x)return xguodong = Guodong()
torch.save(guodong,"guodong_method1.pth")

在另一个文件中加载:

import torch# 陷阱
module = torch.load("guodong_method1.pth")
print(module)

就会报错:

AttributeError: Can’t get attribute ‘Guodong’ on <module ‘main’ from ‘E:\deepLearning\Pycharm\pytroch_project\theFirstFile\module_load.py’>

解决办法:

(1)把Guodong类放在这个文件里

import torch
from torch import nn
import torchvisionclass Guodong(nn.Module):def __init__(self):super(Guodong,self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3)def forward(self,x):x = self.conv1(x)return x# 陷阱
module = torch.load("guodong_method1.pth")
print(module)

(2)from module_save import *

(module_save)是保存自定义模型的文件

from module_save import *# 陷阱
module = torch.load("guodong_method1.pth")
print(module)

这篇关于b站小土堆pytorch学习记录—— P25-P26 网络模型的使用和修改、保存和读取的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/785211

相关文章

Python使用PIL库将PNG图片转换为ICO图标的示例代码

《Python使用PIL库将PNG图片转换为ICO图标的示例代码》在软件开发和网站设计中,ICO图标是一种常用的图像格式,特别适用于应用程序图标、网页收藏夹图标等场景,本文将介绍如何使用Python的... 目录引言准备工作代码解析实践操作结果展示结语引言在软件开发和网站设计中,ICO图标是一种常用的图像

使用Java发送邮件到QQ邮箱的完整指南

《使用Java发送邮件到QQ邮箱的完整指南》在现代软件开发中,邮件发送功能是一个常见的需求,无论是用户注册验证、密码重置,还是系统通知,邮件都是一种重要的通信方式,本文将详细介绍如何使用Java编写程... 目录引言1. 准备工作1.1 获取QQ邮箱的SMTP授权码1.2 添加JavaMail依赖2. 实现

MyBatis与其使用方法示例详解

《MyBatis与其使用方法示例详解》MyBatis是一个支持自定义SQL的持久层框架,通过XML文件实现SQL配置和数据映射,简化了JDBC代码的编写,本文给大家介绍MyBatis与其使用方法讲解,... 目录ORM缺优分析MyBATisMyBatis的工作流程MyBatis的基本使用环境准备MyBati

使用Python开发一个图像标注与OCR识别工具

《使用Python开发一个图像标注与OCR识别工具》:本文主要介绍一个使用Python开发的工具,允许用户在图像上进行矩形标注,使用OCR对标注区域进行文本识别,并将结果保存为Excel文件,感兴... 目录项目简介1. 图像加载与显示2. 矩形标注3. OCR识别4. 标注的保存与加载5. 裁剪与重置图像

使用Python实现表格字段智能去重

《使用Python实现表格字段智能去重》在数据分析和处理过程中,数据清洗是一个至关重要的步骤,其中字段去重是一个常见且关键的任务,下面我们看看如何使用Python进行表格字段智能去重吧... 目录一、引言二、数据重复问题的常见场景与影响三、python在数据清洗中的优势四、基于Python的表格字段智能去重

使用Apache POI在Java中实现Excel单元格的合并

《使用ApachePOI在Java中实现Excel单元格的合并》在日常工作中,Excel是一个不可或缺的工具,尤其是在处理大量数据时,本文将介绍如何使用ApachePOI库在Java中实现Excel... 目录工具类介绍工具类代码调用示例依赖配置总结在日常工作中,Excel 是一个不可或缺的工http://

Java之并行流(Parallel Stream)使用详解

《Java之并行流(ParallelStream)使用详解》Java并行流(ParallelStream)通过多线程并行处理集合数据,利用Fork/Join框架加速计算,适用于大规模数据集和计算密集... 目录Java并行流(Parallel Stream)1. 核心概念与原理2. 创建并行流的方式3. 适

Python如何实现读取csv文件时忽略文件的编码格式

《Python如何实现读取csv文件时忽略文件的编码格式》我们再日常读取csv文件的时候经常会发现csv文件的格式有多种,所以这篇文章为大家介绍了Python如何实现读取csv文件时忽略文件的编码格式... 目录1、背景介绍2、库的安装3、核心代码4、完整代码1、背景介绍我们再日常读取csv文件的时候经常

如何使用Docker部署FTP和Nginx并通过HTTP访问FTP里的文件

《如何使用Docker部署FTP和Nginx并通过HTTP访问FTP里的文件》本文介绍了如何使用Docker部署FTP服务器和Nginx,并通过HTTP访问FTP中的文件,通过将FTP数据目录挂载到N... 目录docker部署FTP和Nginx并通过HTTP访问FTP里的文件1. 部署 FTP 服务器 (

MySQL 日期时间格式化函数 DATE_FORMAT() 的使用示例详解

《MySQL日期时间格式化函数DATE_FORMAT()的使用示例详解》`DATE_FORMAT()`是MySQL中用于格式化日期时间的函数,本文详细介绍了其语法、格式化字符串的含义以及常见日期... 目录一、DATE_FORMAT()语法二、格式化字符串详解三、常见日期时间格式组合四、业务场景五、总结一、