【CenterFusion】模型的创建、导入、保存CenterFusion/src/lib/model/model.py

2024-03-17 23:52

本文主要是介绍【CenterFusion】模型的创建、导入、保存CenterFusion/src/lib/model/model.py,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文件内容:CenterFusion/src/lib/model/model.py
文件作用:模型的创建、导入、保存

model.py 具体内容如下:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport torchvision.models as models
import torch
import torch.nn as nn
import osfrom .networks.dla import DLASeg
from .networks.resdcn import PoseResDCN
from .networks.resnet import PoseResNet
from .networks.dlav0 import DLASegv0
from .networks.generic_network import GenericNetwork_network_factory = {'resdcn': PoseResDCN,'dla': DLASeg,'res': PoseResNet,'dlav0': DLASegv0,'generic': GenericNetwork
}def create_model(arch, head, head_conv, opt=None):num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0'''处理字符串 arch = dla_34 ,将下划线后半部分取出最后 num_layers = 34'''arch = arch[:arch.find('_')] if '_' in arch else arch'''将 arch = dla_34 中下划线前半部分取出最后 arch = 'dla''''model_class = _network_factory[arch]'''根据 arch = 'dla' 获取 _network_factory 中的值最后 model_class = DLASegDLASeg 类定义在 CenterFusion/src/lib/model/networks/dla.py 第 594 行'''model = model_class(num_layers, heads=head, head_convs=head_conv, opt=opt)'''配置模型'''return modeldef load_model(model, model_path, opt, optimizer=None):start_epoch = 0'''设定初始轮次 = 0'''checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))'''torch.load() 函数:用来加载 torch.save() 保存的模型文件'''state_dict_ = checkpoint['state_dict']'''获取 checkpoint 模型文件中的 state_dict 属性这个属性存放训练过程中需要学习的权重和偏执系数state_dict 作为 python 的字典对象将每一层的参数映射成 tensor 张量需要注意的是 torch.nn.Module 模块中的 state_dict 只包含卷积层和全连接层的参数'''state_dict = {}for k in state_dict_:if k.startswith('module') and not k.startswith('module_list'):state_dict[k[7:]] = state_dict_[k]else:state_dict[k] = state_dict_[k]'''startswith(str) 函数:检测字符串 str,检测到返回 True,否则返回 False这里只执行了 else 语句,相当于保存导入模型的网络参数'''model_state_dict = model.state_dict()'''浅拷贝 main.py 中创建的新模型 DLA 的网络参数'''for k in state_dict:'''遍历导入的模型中的每层网络参数'''if k in model_state_dict:'''判断新模型的网络参数中是否有导入的模型的参数是有的,因为导入的模型也是 DLA 模型'''if (state_dict[k].shape != model_state_dict[k].shape) or \(opt.reset_hm and k.startswith('hm') and (state_dict[k].shape[0] in [80, 1])):'''第一个条件为 True其余条件全部为 False'''if opt.reuse_hm:'''不执行'''print('Reusing parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))# todo: bug in next line: both sides of < are the sameif state_dict[k].shape[0] < state_dict[k].shape[0]:model_state_dict[k][:state_dict[k].shape[0]] = state_dict[k]else:model_state_dict[k] = state_dict[k][:model_state_dict[k].shape[0]]state_dict[k] = model_state_dict[k]elif opt.warm_start_weights:'''不执行'''try:print('Partially loading parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))if state_dict[k].shape[1] < model_state_dict[k].shape[1]:model_state_dict[k][:,:state_dict[k].shape[1]] = state_dict[k]else:model_state_dict[k] = state_dict[k][:,:model_state_dict[k].shape[1]]state_dict[k] = model_state_dict[k]except:print('Skip loading parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))state_dict[k] = model_state_dict[k]else:'''执行该 else 中的语句'''print('Skip loading parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))state_dict[k] = model_state_dict[k]'''将新模型的网络参数赋值给导入的模型中'''else:print('Drop parameter {}.'.format(k))for k in model_state_dict:if not (k in state_dict):print('No param {}.'.format(k))state_dict[k] = model_state_dict[k]'''给导入的模型添加没有的参数'''model.load_state_dict(state_dict, strict=False)'''使用 state_dict 反序列化模型参数字字典,用来加载模型参数将 state_dict 中的 parameters 和 buffers 复制到此 module 及其子节点中简述:给模型对象加载训练好的模型参数,即加载模型参数'''#冻结骨干网,没有执行if opt.freeze_backbone:for (name, module) in model.named_children():if name in opt.layers_to_freeze:for (name, layer) in module.named_children():for param in layer.parameters():param.requires_grad = False# 恢复优化器参数,没有执行if optimizer is not None and opt.resume:if 'optimizer' in checkpoint:start_epoch = checkpoint['epoch']start_lr = opt.lrfor step in opt.lr_step:if start_epoch >= step:start_lr *= 0.1for param_group in optimizer.param_groups:param_group['lr'] = start_lrprint('Resumed optimizer with start lr', start_lr)else:print('No optimizer parameters in checkpoint.')if optimizer is not None:'''执行该 if 语句'''return model, optimizer, start_epochelse:return modeldef save_model(path, epoch, model, optimizer=None):if isinstance(model, torch.nn.DataParallel):'''isinstance(object, classinfo) 判断一个函数 object 是否是一个已知的类型 classinfo是则返回 True,反之返回 False'''state_dict = model.module.state_dict()else:state_dict = model.state_dict()'''获取模型的参数矩阵'''data = {'epoch': epoch,'state_dict': state_dict}if not (optimizer is None):data['optimizer'] = optimizer.state_dict()'''获取模型的优化器'''torch.save(data, path)'''保存模型'''

这篇关于【CenterFusion】模型的创建、导入、保存CenterFusion/src/lib/model/model.py的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/820587

相关文章

DeepSeek模型本地部署的详细教程

《DeepSeek模型本地部署的详细教程》DeepSeek作为一款开源且性能强大的大语言模型,提供了灵活的本地部署方案,让用户能够在本地环境中高效运行模型,同时保护数据隐私,在本地成功部署DeepSe... 目录一、环境准备(一)硬件需求(二)软件依赖二、安装Ollama三、下载并部署DeepSeek模型选

解决IDEA使用springBoot创建项目,lombok标注实体类后编译无报错,但是运行时报错问题

《解决IDEA使用springBoot创建项目,lombok标注实体类后编译无报错,但是运行时报错问题》文章详细描述了在使用lombok的@Data注解标注实体类时遇到编译无误但运行时报错的问题,分析... 目录问题分析问题解决方案步骤一步骤二步骤三总结问题使用lombok注解@Data标注实体类,编译时

vscode保存代码时自动eslint格式化图文教程

《vscode保存代码时自动eslint格式化图文教程》:本文主要介绍vscode保存代码时自动eslint格式化的相关资料,包括打开设置文件并复制特定内容,文中通过代码介绍的非常详细,需要的朋友... 目录1、点击设置2、选择远程--->点击右上角打开设置3、会弹出settings.json文件,将以下内

MySQL分表自动化创建的实现方案

《MySQL分表自动化创建的实现方案》在数据库应用场景中,随着数据量的不断增长,单表存储数据可能会面临性能瓶颈,例如查询、插入、更新等操作的效率会逐渐降低,分表是一种有效的优化策略,它将数据分散存储在... 目录一、项目目的二、实现过程(一)mysql 事件调度器结合存储过程方式1. 开启事件调度器2. 创

mysql外键创建不成功/失效如何处理

《mysql外键创建不成功/失效如何处理》文章介绍了在MySQL5.5.40版本中,创建带有外键约束的`stu`和`grade`表时遇到的问题,发现`grade`表的`id`字段没有随着`studen... 当前mysql版本:SELECT VERSION();结果为:5.5.40。在复习mysql外键约

Python调用另一个py文件并传递参数常见的方法及其应用场景

《Python调用另一个py文件并传递参数常见的方法及其应用场景》:本文主要介绍在Python中调用另一个py文件并传递参数的几种常见方法,包括使用import语句、exec函数、subproce... 目录前言1. 使用import语句1.1 基本用法1.2 导入特定函数1.3 处理文件路径2. 使用ex

Window Server创建2台服务器的故障转移群集的图文教程

《WindowServer创建2台服务器的故障转移群集的图文教程》本文主要介绍了在WindowsServer系统上创建一个包含两台成员服务器的故障转移群集,文中通过图文示例介绍的非常详细,对大家的... 目录一、 准备条件二、在ServerB安装故障转移群集三、在ServerC安装故障转移群集,操作与Ser

Window Server2016 AD域的创建的方法步骤

《WindowServer2016AD域的创建的方法步骤》本文主要介绍了WindowServer2016AD域的创建的方法步骤,文中通过图文介绍的非常详细,对大家的学习或者工作具有一定的参考学习价... 目录一、准备条件二、在ServerA服务器中常见AD域管理器:三、创建AD域,域地址为“test.ly”

Golang的CSP模型简介(最新推荐)

《Golang的CSP模型简介(最新推荐)》Golang采用了CSP(CommunicatingSequentialProcesses,通信顺序进程)并发模型,通过goroutine和channe... 目录前言一、介绍1. 什么是 CSP 模型2. Goroutine3. Channel4. Channe

Python数据处理之导入导出Excel数据方式

《Python数据处理之导入导出Excel数据方式》Python是Excel数据处理的绝佳工具,通过Pandas和Openpyxl等库可以实现数据的导入、导出和自动化处理,从基础的数据读取和清洗到复杂... 目录python导入导出Excel数据开启数据之旅:为什么Python是Excel数据处理的最佳拍档