【CenterFusion】run_epoch()函数-训练一轮epoch-CenterFusion/src/lib/trainer.py

2024-03-18 02:12

本文主要是介绍【CenterFusion】run_epoch()函数-训练一轮epoch-CenterFusion/src/lib/trainer.py,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文件位置:CenterFusion/src/lib/trainer.py
run_epoch作用:CenterFusion 项目训练一轮epoch过程

  • 在 main.py 函数中,生成了训练器,然后再使用训练器训练一个 epoch
  • run_epoch()函数的定义在src\lib\trainer.py150行左右,它的主要过程如下所示:
  def run_epoch(self, phase, epoch, data_loader):model_with_loss = self.model_with_loss'''self.model_with_loss 是 ModelWithLoss 类,这个类又继承 torch.nn.Module 类'''if phase == 'train':model_with_loss.train()'''启用 Batch Normalization 和 Dropout如果模型中有 BN 层 (Batch Normalization) 和 Dropout需要在训练时添加 model.train()model.train()是保证 BN 层能够用到每一批数据的均值和方差对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数'''else:if len(self.opt.gpus) > 1:model_with_loss = self.model_with_loss.modulemodel_with_loss.eval()'''不启用 Batch Normalization 和 Dropout如果模型中有 BN 层 (Batch Normalization) 和Dropout在测试时添加 model.eval()model.eval() 是保证 BN 层能够用全部训练数据的均值和方差即测试过程中要保证 BN 层的均值和方差不变对于 Dropout,model.eval() 是利用到了所有网络连接,即不进行随机舍弃神经元。'''torch.cuda.empty_cache()'''释放空间'''opt = self.optresults = {}data_time, batch_time = AverageMeter(), AverageMeter()'''新建两个 AverageMeter 对象'''avg_loss_stats = {l: AverageMeter() for l in self.loss_stats \if l == 'tot' or opt.weights[l] > 0}'''为 loss 列表的每个属性赋值一个 AverageMeter 对象'''num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters'''获取数据长度'''bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)end = time.time()'''设置进度条'''for iter_id, batch in enumerate(data_loader):if iter_id >= num_iters:break'''遍历完'''data_time.update(time.time() - end)'''更新 data_time 的值'''for k in batch:if k != 'meta':batch[k] = batch[k].to(device=opt.device, non_blocking=True)'''这里的 batch 是一个 Tensor 对象将其配置到 gpu 上'''output, loss, loss_stats = model_with_loss(batch, phase)'''运行第一阶段(模型训练)'''# backpropagate and step optimizer 反向传播和步进优化器loss = loss.mean()'''求每一层损失值的平均值'''if phase == 'train':self.optimizer.zero_grad()'''将模型的参数梯度初始化为0'''loss.backward()'''反向传播计算梯度'''self.optimizer.step()'''更新所有参数''''''根据 pytorch 中 backward() 函数的计算当网络参量进行反馈时,梯度是累积计算而不是被替换但在处理每一个 batch 时并不需要与其他 batch的梯度混合起来累积计算因此需要对每个 batch 调用一遍 zero_grad() 将参数梯度置 0.'''batch_time.update(time.time() - end)'''更新 batch_time 的值'''end = time.time()Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(epoch, iter_id, num_iters, phase=phase,total=bar.elapsed_td, eta=bar.eta_td)'''bar.elapsed_td : 经过的时间增量eta=bar.eta_td : 时间间隔'''for l in avg_loss_stats:avg_loss_stats[l].update(loss_stats[l].mean().item(), batch['image'].size(0))Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)'''更新平均损失'''Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \'|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)if opt.print_iter > 0:if iter_id % opt.print_iter == 0:print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix)) else:bar.next()'''opt.print_iter = 0 执行 else 语句,显示进度条'''if opt.debug > 0:self.debug(batch, output, iter_id, dataset=data_loader.dataset)'''debug 默认为 0,没有执行 if 语句'''if (phase == 'val' and (opt.run_dataset_eval or opt.eval)):meta = batch['meta']dets = fusion_decode(output, K=opt.K, opt=opt)'''解码器和雷达点云融合调用的这个函数位于 CenterFusion\src\lib\model\decode.py 中这个函数具体实现的功能就是将前面模型训练得到的结果,也就是一些特征图,这些特征图为多维矩阵将特征图与毫米波雷达点云进行映射,映射过程就是将特征图进行维度转换、升维等操作,然后再点乘旋转矩阵'''for k in dets:dets[k] = dets[k].detach().cpu().numpy()'''detach() 阻断反向传播,返回值仍为 tensorcpu() 将变量放在 cpu 上,仍为 tensornumpy() 将 tensor 转换为 numpy'''calib = meta['calib'].detach().numpy() if 'calib' in meta else Nonedets = generic_post_process(opt, dets, meta['c'].cpu().numpy(), meta['s'].cpu().numpy(),output['hm'].shape[2], output['hm'].shape[3], self.opt.num_classes,calib)result = []for i in range(len(dets[0])):if dets[0][i]['score'] > self.opt.out_thresh and all(dets[0][i]['dim'] > 0):result.append(dets[0][i])'''筛选结果'''img_id = batch['meta']['img_id'].numpy().astype(np.int32)[0]'''强制类型转换图片 id'''results[img_id] = resultdel output, loss, loss_statsbar.finish()ret = {k: v.avg for k, v in avg_loss_stats.items()}'''平均损失结果'''ret['time'] = bar.elapsed_td.total_seconds() / 60.return ret, results

这篇关于【CenterFusion】run_epoch()函数-训练一轮epoch-CenterFusion/src/lib/trainer.py的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/820868

相关文章

Python的time模块一些常用功能(各种与时间相关的函数)

《Python的time模块一些常用功能(各种与时间相关的函数)》Python的time模块提供了各种与时间相关的函数,包括获取当前时间、处理时间间隔、执行时间测量等,:本文主要介绍Python的... 目录1. 获取当前时间2. 时间格式化3. 延时执行4. 时间戳运算5. 计算代码执行时间6. 转换为指

Python正则表达式语法及re模块中的常用函数详解

《Python正则表达式语法及re模块中的常用函数详解》这篇文章主要给大家介绍了关于Python正则表达式语法及re模块中常用函数的相关资料,正则表达式是一种强大的字符串处理工具,可以用于匹配、切分、... 目录概念、作用和步骤语法re模块中的常用函数总结 概念、作用和步骤概念: 本身也是一个字符串,其中

一文带你搞懂Python中__init__.py到底是什么

《一文带你搞懂Python中__init__.py到底是什么》朋友们,今天我们来聊聊Python里一个低调却至关重要的文件——__init__.py,有些人可能听说过它是“包的标志”,也有人觉得它“没... 目录先搞懂 python 模块(module)Python 包(package)是啥?那么 __in

shell编程之函数与数组的使用详解

《shell编程之函数与数组的使用详解》:本文主要介绍shell编程之函数与数组的使用,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录shell函数函数的用法俩个数求和系统资源监控并报警函数函数变量的作用范围函数的参数递归函数shell数组获取数组的长度读取某下的

MySQL高级查询之JOIN、子查询、窗口函数实际案例

《MySQL高级查询之JOIN、子查询、窗口函数实际案例》:本文主要介绍MySQL高级查询之JOIN、子查询、窗口函数实际案例的相关资料,JOIN用于多表关联查询,子查询用于数据筛选和过滤,窗口函... 目录前言1. JOIN(连接查询)1.1 内连接(INNER JOIN)1.2 左连接(LEFT JOI

MySQL中FIND_IN_SET函数与INSTR函数用法解析

《MySQL中FIND_IN_SET函数与INSTR函数用法解析》:本文主要介绍MySQL中FIND_IN_SET函数与INSTR函数用法解析,本文通过实例代码给大家介绍的非常详细,感兴趣的朋友一... 目录一、功能定义与语法1、FIND_IN_SET函数2、INSTR函数二、本质区别对比三、实际场景案例分

C++ Sort函数使用场景分析

《C++Sort函数使用场景分析》sort函数是algorithm库下的一个函数,sort函数是不稳定的,即大小相同的元素在排序后相对顺序可能发生改变,如果某些场景需要保持相同元素间的相对顺序,可使... 目录C++ Sort函数详解一、sort函数调用的两种方式二、sort函数使用场景三、sort函数排序

C语言函数递归实际应用举例详解

《C语言函数递归实际应用举例详解》程序调用自身的编程技巧称为递归,递归做为一种算法在程序设计语言中广泛应用,:本文主要介绍C语言函数递归实际应用举例的相关资料,文中通过代码介绍的非常详细,需要的朋... 目录前言一、递归的概念与思想二、递归的限制条件 三、递归的实际应用举例(一)求 n 的阶乘(二)顺序打印

C/C++错误信息处理的常见方法及函数

《C/C++错误信息处理的常见方法及函数》C/C++是两种广泛使用的编程语言,特别是在系统编程、嵌入式开发以及高性能计算领域,:本文主要介绍C/C++错误信息处理的常见方法及函数,文中通过代码介绍... 目录前言1. errno 和 perror()示例:2. strerror()示例:3. perror(

Kotlin 作用域函数apply、let、run、with、also使用指南

《Kotlin作用域函数apply、let、run、with、also使用指南》在Kotlin开发中,作用域函数(ScopeFunctions)是一组能让代码更简洁、更函数式的高阶函数,本文将... 目录一、引言:为什么需要作用域函数?二、作用域函China编程数详解1. apply:对象配置的 “流式构建器”最