【CenterFusion】模型的创建、导入、保存CenterFusion/src/lib/model/model.py

2024-03-17 23:52

本文主要是介绍【CenterFusion】模型的创建、导入、保存CenterFusion/src/lib/model/model.py,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文件内容:CenterFusion/src/lib/model/model.py
文件作用:模型的创建、导入、保存

model.py 具体内容如下:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport torchvision.models as models
import torch
import torch.nn as nn
import osfrom .networks.dla import DLASeg
from .networks.resdcn import PoseResDCN
from .networks.resnet import PoseResNet
from .networks.dlav0 import DLASegv0
from .networks.generic_network import GenericNetwork_network_factory = {'resdcn': PoseResDCN,'dla': DLASeg,'res': PoseResNet,'dlav0': DLASegv0,'generic': GenericNetwork
}def create_model(arch, head, head_conv, opt=None):num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0'''处理字符串 arch = dla_34 ,将下划线后半部分取出最后 num_layers = 34'''arch = arch[:arch.find('_')] if '_' in arch else arch'''将 arch = dla_34 中下划线前半部分取出最后 arch = 'dla''''model_class = _network_factory[arch]'''根据 arch = 'dla' 获取 _network_factory 中的值最后 model_class = DLASegDLASeg 类定义在 CenterFusion/src/lib/model/networks/dla.py 第 594 行'''model = model_class(num_layers, heads=head, head_convs=head_conv, opt=opt)'''配置模型'''return modeldef load_model(model, model_path, opt, optimizer=None):start_epoch = 0'''设定初始轮次 = 0'''checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))'''torch.load() 函数:用来加载 torch.save() 保存的模型文件'''state_dict_ = checkpoint['state_dict']'''获取 checkpoint 模型文件中的 state_dict 属性这个属性存放训练过程中需要学习的权重和偏执系数state_dict 作为 python 的字典对象将每一层的参数映射成 tensor 张量需要注意的是 torch.nn.Module 模块中的 state_dict 只包含卷积层和全连接层的参数'''state_dict = {}for k in state_dict_:if k.startswith('module') and not k.startswith('module_list'):state_dict[k[7:]] = state_dict_[k]else:state_dict[k] = state_dict_[k]'''startswith(str) 函数:检测字符串 str,检测到返回 True,否则返回 False这里只执行了 else 语句,相当于保存导入模型的网络参数'''model_state_dict = model.state_dict()'''浅拷贝 main.py 中创建的新模型 DLA 的网络参数'''for k in state_dict:'''遍历导入的模型中的每层网络参数'''if k in model_state_dict:'''判断新模型的网络参数中是否有导入的模型的参数是有的,因为导入的模型也是 DLA 模型'''if (state_dict[k].shape != model_state_dict[k].shape) or \(opt.reset_hm and k.startswith('hm') and (state_dict[k].shape[0] in [80, 1])):'''第一个条件为 True其余条件全部为 False'''if opt.reuse_hm:'''不执行'''print('Reusing parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))# todo: bug in next line: both sides of < are the sameif state_dict[k].shape[0] < state_dict[k].shape[0]:model_state_dict[k][:state_dict[k].shape[0]] = state_dict[k]else:model_state_dict[k] = state_dict[k][:model_state_dict[k].shape[0]]state_dict[k] = model_state_dict[k]elif opt.warm_start_weights:'''不执行'''try:print('Partially loading parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))if state_dict[k].shape[1] < model_state_dict[k].shape[1]:model_state_dict[k][:,:state_dict[k].shape[1]] = state_dict[k]else:model_state_dict[k] = state_dict[k][:,:model_state_dict[k].shape[1]]state_dict[k] = model_state_dict[k]except:print('Skip loading parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))state_dict[k] = model_state_dict[k]else:'''执行该 else 中的语句'''print('Skip loading parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))state_dict[k] = model_state_dict[k]'''将新模型的网络参数赋值给导入的模型中'''else:print('Drop parameter {}.'.format(k))for k in model_state_dict:if not (k in state_dict):print('No param {}.'.format(k))state_dict[k] = model_state_dict[k]'''给导入的模型添加没有的参数'''model.load_state_dict(state_dict, strict=False)'''使用 state_dict 反序列化模型参数字字典,用来加载模型参数将 state_dict 中的 parameters 和 buffers 复制到此 module 及其子节点中简述:给模型对象加载训练好的模型参数,即加载模型参数'''#冻结骨干网,没有执行if opt.freeze_backbone:for (name, module) in model.named_children():if name in opt.layers_to_freeze:for (name, layer) in module.named_children():for param in layer.parameters():param.requires_grad = False# 恢复优化器参数,没有执行if optimizer is not None and opt.resume:if 'optimizer' in checkpoint:start_epoch = checkpoint['epoch']start_lr = opt.lrfor step in opt.lr_step:if start_epoch >= step:start_lr *= 0.1for param_group in optimizer.param_groups:param_group['lr'] = start_lrprint('Resumed optimizer with start lr', start_lr)else:print('No optimizer parameters in checkpoint.')if optimizer is not None:'''执行该 if 语句'''return model, optimizer, start_epochelse:return modeldef save_model(path, epoch, model, optimizer=None):if isinstance(model, torch.nn.DataParallel):'''isinstance(object, classinfo) 判断一个函数 object 是否是一个已知的类型 classinfo是则返回 True,反之返回 False'''state_dict = model.module.state_dict()else:state_dict = model.state_dict()'''获取模型的参数矩阵'''data = {'epoch': epoch,'state_dict': state_dict}if not (optimizer is None):data['optimizer'] = optimizer.state_dict()'''获取模型的优化器'''torch.save(data, path)'''保存模型'''

这篇关于【CenterFusion】模型的创建、导入、保存CenterFusion/src/lib/model/model.py的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/820587

相关文章

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

idea中创建新类时自动添加注释的实现

《idea中创建新类时自动添加注释的实现》在每次使用idea创建一个新类时,过了一段时间发现看不懂这个类是用来干嘛的,为了解决这个问题,我们可以设置在创建一个新类时自动添加注释,帮助我们理解这个类的用... 目录前言:详细操作:步骤一:点击上方的 文件(File),点击&nbmyHIgsp;设置(Setti

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

GORM中Model和Table的区别及使用

《GORM中Model和Table的区别及使用》Model和Table是两种与数据库表交互的核心方法,但它们的用途和行为存在著差异,本文主要介绍了GORM中Model和Table的区别及使用,具有一... 目录1. Model 的作用与特点1.1 核心用途1.2 行为特点1.3 示例China编程代码2. Tab

一文教你Python引入其他文件夹下的.py文件

《一文教你Python引入其他文件夹下的.py文件》这篇文章主要为大家详细介绍了如何在Python中引入其他文件夹里的.py文件,并探讨几种常见的实现方式,有需要的小伙伴可以根据需求进行选择... 目录1. 使用sys.path动态添加路径2. 使用相对导入(适用于包结构)3. 使用pythonPATH环境

Spring 中使用反射创建 Bean 实例的几种方式

《Spring中使用反射创建Bean实例的几种方式》文章介绍了在Spring框架中如何使用反射来创建Bean实例,包括使用Class.newInstance()、Constructor.newI... 目录1. 使用 Class.newInstance() (仅限无参构造函数):2. 使用 Construc

Java导入、导出excel用法步骤保姆级教程(附封装好的工具类)

《Java导入、导出excel用法步骤保姆级教程(附封装好的工具类)》:本文主要介绍Java导入、导出excel的相关资料,讲解了使用Java和ApachePOI库将数据导出为Excel文件,包括... 目录前言一、引入Apache POI依赖二、用法&步骤2.1 创建Excel的元素2.3 样式和字体2.

C#原型模式之如何通过克隆对象来优化创建过程

《C#原型模式之如何通过克隆对象来优化创建过程》原型模式是一种创建型设计模式,通过克隆现有对象来创建新对象,避免重复的创建成本和复杂的初始化过程,它适用于对象创建过程复杂、需要大量相似对象或避免重复初... 目录什么是原型模式?原型模式的工作原理C#中如何实现原型模式?1. 定义原型接口2. 实现原型接口3

浅析Python中的绝对导入与相对导入

《浅析Python中的绝对导入与相对导入》这篇文章主要为大家详细介绍了Python中的绝对导入与相对导入的相关知识,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1 Imports快速介绍2 import语句的语法2.1 基本使用2.2 导入声明的样式3 绝对import和相对i

C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)

《C#集成DeepSeek模型实现AI私有化的流程步骤(本地部署与API调用教程)》本文主要介绍了C#集成DeepSeek模型实现AI私有化的方法,包括搭建基础环境,如安装Ollama和下载DeepS... 目录前言搭建基础环境1、安装 Ollama2、下载 DeepSeek R1 模型客户端 ChatBo