【CenterFusion】run_epoch()函数-训练一轮epoch-CenterFusion/src/lib/trainer.py

2024-03-18 02:12

本文主要是介绍【CenterFusion】run_epoch()函数-训练一轮epoch-CenterFusion/src/lib/trainer.py,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文件位置:CenterFusion/src/lib/trainer.py
run_epoch作用:CenterFusion 项目训练一轮epoch过程

  • 在 main.py 函数中,生成了训练器,然后再使用训练器训练一个 epoch
  • run_epoch()函数的定义在src\lib\trainer.py150行左右,它的主要过程如下所示:
  def run_epoch(self, phase, epoch, data_loader):model_with_loss = self.model_with_loss'''self.model_with_loss 是 ModelWithLoss 类,这个类又继承 torch.nn.Module 类'''if phase == 'train':model_with_loss.train()'''启用 Batch Normalization 和 Dropout如果模型中有 BN 层 (Batch Normalization) 和 Dropout需要在训练时添加 model.train()model.train()是保证 BN 层能够用到每一批数据的均值和方差对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数'''else:if len(self.opt.gpus) > 1:model_with_loss = self.model_with_loss.modulemodel_with_loss.eval()'''不启用 Batch Normalization 和 Dropout如果模型中有 BN 层 (Batch Normalization) 和Dropout在测试时添加 model.eval()model.eval() 是保证 BN 层能够用全部训练数据的均值和方差即测试过程中要保证 BN 层的均值和方差不变对于 Dropout,model.eval() 是利用到了所有网络连接,即不进行随机舍弃神经元。'''torch.cuda.empty_cache()'''释放空间'''opt = self.optresults = {}data_time, batch_time = AverageMeter(), AverageMeter()'''新建两个 AverageMeter 对象'''avg_loss_stats = {l: AverageMeter() for l in self.loss_stats \if l == 'tot' or opt.weights[l] > 0}'''为 loss 列表的每个属性赋值一个 AverageMeter 对象'''num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters'''获取数据长度'''bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)end = time.time()'''设置进度条'''for iter_id, batch in enumerate(data_loader):if iter_id >= num_iters:break'''遍历完'''data_time.update(time.time() - end)'''更新 data_time 的值'''for k in batch:if k != 'meta':batch[k] = batch[k].to(device=opt.device, non_blocking=True)'''这里的 batch 是一个 Tensor 对象将其配置到 gpu 上'''output, loss, loss_stats = model_with_loss(batch, phase)'''运行第一阶段(模型训练)'''# backpropagate and step optimizer 反向传播和步进优化器loss = loss.mean()'''求每一层损失值的平均值'''if phase == 'train':self.optimizer.zero_grad()'''将模型的参数梯度初始化为0'''loss.backward()'''反向传播计算梯度'''self.optimizer.step()'''更新所有参数''''''根据 pytorch 中 backward() 函数的计算当网络参量进行反馈时,梯度是累积计算而不是被替换但在处理每一个 batch 时并不需要与其他 batch的梯度混合起来累积计算因此需要对每个 batch 调用一遍 zero_grad() 将参数梯度置 0.'''batch_time.update(time.time() - end)'''更新 batch_time 的值'''end = time.time()Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(epoch, iter_id, num_iters, phase=phase,total=bar.elapsed_td, eta=bar.eta_td)'''bar.elapsed_td : 经过的时间增量eta=bar.eta_td : 时间间隔'''for l in avg_loss_stats:avg_loss_stats[l].update(loss_stats[l].mean().item(), batch['image'].size(0))Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)'''更新平均损失'''Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \'|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)if opt.print_iter > 0:if iter_id % opt.print_iter == 0:print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix)) else:bar.next()'''opt.print_iter = 0 执行 else 语句,显示进度条'''if opt.debug > 0:self.debug(batch, output, iter_id, dataset=data_loader.dataset)'''debug 默认为 0,没有执行 if 语句'''if (phase == 'val' and (opt.run_dataset_eval or opt.eval)):meta = batch['meta']dets = fusion_decode(output, K=opt.K, opt=opt)'''解码器和雷达点云融合调用的这个函数位于 CenterFusion\src\lib\model\decode.py 中这个函数具体实现的功能就是将前面模型训练得到的结果,也就是一些特征图,这些特征图为多维矩阵将特征图与毫米波雷达点云进行映射,映射过程就是将特征图进行维度转换、升维等操作,然后再点乘旋转矩阵'''for k in dets:dets[k] = dets[k].detach().cpu().numpy()'''detach() 阻断反向传播,返回值仍为 tensorcpu() 将变量放在 cpu 上,仍为 tensornumpy() 将 tensor 转换为 numpy'''calib = meta['calib'].detach().numpy() if 'calib' in meta else Nonedets = generic_post_process(opt, dets, meta['c'].cpu().numpy(), meta['s'].cpu().numpy(),output['hm'].shape[2], output['hm'].shape[3], self.opt.num_classes,calib)result = []for i in range(len(dets[0])):if dets[0][i]['score'] > self.opt.out_thresh and all(dets[0][i]['dim'] > 0):result.append(dets[0][i])'''筛选结果'''img_id = batch['meta']['img_id'].numpy().astype(np.int32)[0]'''强制类型转换图片 id'''results[img_id] = resultdel output, loss, loss_statsbar.finish()ret = {k: v.avg for k, v in avg_loss_stats.items()}'''平均损失结果'''ret['time'] = bar.elapsed_td.total_seconds() / 60.return ret, results

这篇关于【CenterFusion】run_epoch()函数-训练一轮epoch-CenterFusion/src/lib/trainer.py的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/820868

相关文章

python: 多模块(.py)中全局变量的导入

文章目录 global关键字可变类型和不可变类型数据的内存地址单模块(单个py文件)的全局变量示例总结 多模块(多个py文件)的全局变量from x import x导入全局变量示例 import x导入全局变量示例 总结 global关键字 global 的作用范围是模块(.py)级别: 当你在一个模块(文件)中使用 global 声明变量时,这个变量只在该模块的全局命名空

hdu1171(母函数或多重背包)

题意:把物品分成两份,使得价值最接近 可以用背包,或者是母函数来解,母函数(1 + x^v+x^2v+.....+x^num*v)(1 + x^v+x^2v+.....+x^num*v)(1 + x^v+x^2v+.....+x^num*v) 其中指数为价值,每一项的数目为(该物品数+1)个 代码如下: #include<iostream>#include<algorithm>

C++操作符重载实例(独立函数)

C++操作符重载实例,我们把坐标值CVector的加法进行重载,计算c3=c1+c2时,也就是计算x3=x1+x2,y3=y1+y2,今天我们以独立函数的方式重载操作符+(加号),以下是C++代码: c1802.cpp源代码: D:\YcjWork\CppTour>vim c1802.cpp #include <iostream>using namespace std;/*** 以独立函数

函数式编程思想

我们经常会用到各种各样的编程思想,例如面向过程、面向对象。不过笔者在该博客简单介绍一下函数式编程思想. 如果对函数式编程思想进行概括,就是f(x) = na(x) , y=uf(x)…至于其他的编程思想,可能是y=a(x)+b(x)+c(x)…,也有可能是y=f(x)=f(x)/a + f(x)/b+f(x)/c… 面向过程的指令式编程 面向过程,简单理解就是y=a(x)+b(x)+c(x)

MiniGPT-3D, 首个高效的3D点云大语言模型,仅需一张RTX3090显卡,训练一天时间,已开源

项目主页:https://tangyuan96.github.io/minigpt_3d_project_page/ 代码:https://github.com/TangYuan96/MiniGPT-3D 论文:https://arxiv.org/pdf/2405.01413 MiniGPT-3D在多个任务上取得了SoTA,被ACM MM2024接收,只拥有47.8M的可训练参数,在一张RTX

利用matlab bar函数绘制较为复杂的柱状图,并在图中进行适当标注

示例代码和结果如下:小疑问:如何自动选择合适的坐标位置对柱状图的数值大小进行标注?😂 clear; close all;x = 1:3;aa=[28.6321521955954 26.2453660695847 21.69102348512086.93747104431360 6.25442246899816 3.342835958564245.51365061796319 4.87

OpenCV结构分析与形状描述符(11)椭圆拟合函数fitEllipse()的使用

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C++11 算法描述 围绕一组2D点拟合一个椭圆。 该函数计算出一个椭圆,该椭圆在最小二乘意义上最好地拟合一组2D点。它返回一个内切椭圆的旋转矩形。使用了由[90]描述的第一个算法。开发者应该注意,由于数据点靠近包含的 Mat 元素的边界,返回的椭圆/旋转矩形数据

Spark MLlib模型训练—聚类算法 PIC(Power Iteration Clustering)

Spark MLlib模型训练—聚类算法 PIC(Power Iteration Clustering) Power Iteration Clustering (PIC) 是一种基于图的聚类算法,用于在大规模数据集上进行高效的社区检测。PIC 算法的核心思想是通过迭代图的幂运算来发现数据中的潜在簇。该算法适用于处理大规模图数据,特别是在社交网络分析、推荐系统和生物信息学等领域具有广泛应用。Spa

script中的src

<script src="http://www.somewhere.com/afile.js"></script> 浏览器在解析这个资源时,会向 src 属性指定的路径发送一个 GET 请求,以取得相应资源,假定 是一个 JavaScript 文件。这个初始的请求不受浏览器同源策略限制,但返回并被执行的 JavaScript 则受限制。 当然,这个请求仍然受父页面 HTTP/HTTPS

Unity3D 运动之Move函数和translate

CharacterController.Move 移动 function Move (motion : Vector3) : CollisionFlags Description描述 A more complex move function taking absolute movement deltas. 一个更加复杂的运动函数,每次都绝对运动。 Attempts to