本文主要是介绍【CenterFusion】模型的创建、导入、保存CenterFusion/src/lib/model/model.py,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
文件内容:CenterFusion/src/lib/model/model.py
文件作用:模型的创建、导入、保存
model.py 具体内容如下:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport torchvision.models as models
import torch
import torch.nn as nn
import osfrom .networks.dla import DLASeg
from .networks.resdcn import PoseResDCN
from .networks.resnet import PoseResNet
from .networks.dlav0 import DLASegv0
from .networks.generic_network import GenericNetwork_network_factory = {'resdcn': PoseResDCN,'dla': DLASeg,'res': PoseResNet,'dlav0': DLASegv0,'generic': GenericNetwork
}def create_model(arch, head, head_conv, opt=None):num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0'''处理字符串 arch = dla_34 ,将下划线后半部分取出最后 num_layers = 34'''arch = arch[:arch.find('_')] if '_' in arch else arch'''将 arch = dla_34 中下划线前半部分取出最后 arch = 'dla''''model_class = _network_factory[arch]'''根据 arch = 'dla' 获取 _network_factory 中的值最后 model_class = DLASegDLASeg 类定义在 CenterFusion/src/lib/model/networks/dla.py 第 594 行'''model = model_class(num_layers, heads=head, head_convs=head_conv, opt=opt)'''配置模型'''return modeldef load_model(model, model_path, opt, optimizer=None):start_epoch = 0'''设定初始轮次 = 0'''checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))'''torch.load() 函数:用来加载 torch.save() 保存的模型文件'''state_dict_ = checkpoint['state_dict']'''获取 checkpoint 模型文件中的 state_dict 属性这个属性存放训练过程中需要学习的权重和偏执系数state_dict 作为 python 的字典对象将每一层的参数映射成 tensor 张量需要注意的是 torch.nn.Module 模块中的 state_dict 只包含卷积层和全连接层的参数'''state_dict = {}for k in state_dict_:if k.startswith('module') and not k.startswith('module_list'):state_dict[k[7:]] = state_dict_[k]else:state_dict[k] = state_dict_[k]'''startswith(str) 函数:检测字符串 str,检测到返回 True,否则返回 False这里只执行了 else 语句,相当于保存导入模型的网络参数'''model_state_dict = model.state_dict()'''浅拷贝 main.py 中创建的新模型 DLA 的网络参数'''for k in state_dict:'''遍历导入的模型中的每层网络参数'''if k in model_state_dict:'''判断新模型的网络参数中是否有导入的模型的参数是有的,因为导入的模型也是 DLA 模型'''if (state_dict[k].shape != model_state_dict[k].shape) or \(opt.reset_hm and k.startswith('hm') and (state_dict[k].shape[0] in [80, 1])):'''第一个条件为 True其余条件全部为 False'''if opt.reuse_hm:'''不执行'''print('Reusing parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))# todo: bug in next line: both sides of < are the sameif state_dict[k].shape[0] < state_dict[k].shape[0]:model_state_dict[k][:state_dict[k].shape[0]] = state_dict[k]else:model_state_dict[k] = state_dict[k][:model_state_dict[k].shape[0]]state_dict[k] = model_state_dict[k]elif opt.warm_start_weights:'''不执行'''try:print('Partially loading parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))if state_dict[k].shape[1] < model_state_dict[k].shape[1]:model_state_dict[k][:,:state_dict[k].shape[1]] = state_dict[k]else:model_state_dict[k] = state_dict[k][:,:model_state_dict[k].shape[1]]state_dict[k] = model_state_dict[k]except:print('Skip loading parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))state_dict[k] = model_state_dict[k]else:'''执行该 else 中的语句'''print('Skip loading parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))state_dict[k] = model_state_dict[k]'''将新模型的网络参数赋值给导入的模型中'''else:print('Drop parameter {}.'.format(k))for k in model_state_dict:if not (k in state_dict):print('No param {}.'.format(k))state_dict[k] = model_state_dict[k]'''给导入的模型添加没有的参数'''model.load_state_dict(state_dict, strict=False)'''使用 state_dict 反序列化模型参数字字典,用来加载模型参数将 state_dict 中的 parameters 和 buffers 复制到此 module 及其子节点中简述:给模型对象加载训练好的模型参数,即加载模型参数'''#冻结骨干网,没有执行if opt.freeze_backbone:for (name, module) in model.named_children():if name in opt.layers_to_freeze:for (name, layer) in module.named_children():for param in layer.parameters():param.requires_grad = False# 恢复优化器参数,没有执行if optimizer is not None and opt.resume:if 'optimizer' in checkpoint:start_epoch = checkpoint['epoch']start_lr = opt.lrfor step in opt.lr_step:if start_epoch >= step:start_lr *= 0.1for param_group in optimizer.param_groups:param_group['lr'] = start_lrprint('Resumed optimizer with start lr', start_lr)else:print('No optimizer parameters in checkpoint.')if optimizer is not None:'''执行该 if 语句'''return model, optimizer, start_epochelse:return modeldef save_model(path, epoch, model, optimizer=None):if isinstance(model, torch.nn.DataParallel):'''isinstance(object, classinfo) 判断一个函数 object 是否是一个已知的类型 classinfo是则返回 True,反之返回 False'''state_dict = model.module.state_dict()else:state_dict = model.state_dict()'''获取模型的参数矩阵'''data = {'epoch': epoch,'state_dict': state_dict}if not (optimizer is None):data['optimizer'] = optimizer.state_dict()'''获取模型的优化器'''torch.save(data, path)'''保存模型'''
这篇关于【CenterFusion】模型的创建、导入、保存CenterFusion/src/lib/model/model.py的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!