【CenterFusion】run_epoch()函数-训练一轮epoch-CenterFusion/src/lib/trainer.py

2024-03-18 02:12

本文主要是介绍【CenterFusion】run_epoch()函数-训练一轮epoch-CenterFusion/src/lib/trainer.py,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文件位置:CenterFusion/src/lib/trainer.py
run_epoch作用:CenterFusion 项目训练一轮epoch过程

  • 在 main.py 函数中,生成了训练器,然后再使用训练器训练一个 epoch
  • run_epoch()函数的定义在src\lib\trainer.py150行左右,它的主要过程如下所示:
  def run_epoch(self, phase, epoch, data_loader):model_with_loss = self.model_with_loss'''self.model_with_loss 是 ModelWithLoss 类,这个类又继承 torch.nn.Module 类'''if phase == 'train':model_with_loss.train()'''启用 Batch Normalization 和 Dropout如果模型中有 BN 层 (Batch Normalization) 和 Dropout需要在训练时添加 model.train()model.train()是保证 BN 层能够用到每一批数据的均值和方差对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数'''else:if len(self.opt.gpus) > 1:model_with_loss = self.model_with_loss.modulemodel_with_loss.eval()'''不启用 Batch Normalization 和 Dropout如果模型中有 BN 层 (Batch Normalization) 和Dropout在测试时添加 model.eval()model.eval() 是保证 BN 层能够用全部训练数据的均值和方差即测试过程中要保证 BN 层的均值和方差不变对于 Dropout,model.eval() 是利用到了所有网络连接,即不进行随机舍弃神经元。'''torch.cuda.empty_cache()'''释放空间'''opt = self.optresults = {}data_time, batch_time = AverageMeter(), AverageMeter()'''新建两个 AverageMeter 对象'''avg_loss_stats = {l: AverageMeter() for l in self.loss_stats \if l == 'tot' or opt.weights[l] > 0}'''为 loss 列表的每个属性赋值一个 AverageMeter 对象'''num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters'''获取数据长度'''bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)end = time.time()'''设置进度条'''for iter_id, batch in enumerate(data_loader):if iter_id >= num_iters:break'''遍历完'''data_time.update(time.time() - end)'''更新 data_time 的值'''for k in batch:if k != 'meta':batch[k] = batch[k].to(device=opt.device, non_blocking=True)'''这里的 batch 是一个 Tensor 对象将其配置到 gpu 上'''output, loss, loss_stats = model_with_loss(batch, phase)'''运行第一阶段(模型训练)'''# backpropagate and step optimizer 反向传播和步进优化器loss = loss.mean()'''求每一层损失值的平均值'''if phase == 'train':self.optimizer.zero_grad()'''将模型的参数梯度初始化为0'''loss.backward()'''反向传播计算梯度'''self.optimizer.step()'''更新所有参数''''''根据 pytorch 中 backward() 函数的计算当网络参量进行反馈时,梯度是累积计算而不是被替换但在处理每一个 batch 时并不需要与其他 batch的梯度混合起来累积计算因此需要对每个 batch 调用一遍 zero_grad() 将参数梯度置 0.'''batch_time.update(time.time() - end)'''更新 batch_time 的值'''end = time.time()Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(epoch, iter_id, num_iters, phase=phase,total=bar.elapsed_td, eta=bar.eta_td)'''bar.elapsed_td : 经过的时间增量eta=bar.eta_td : 时间间隔'''for l in avg_loss_stats:avg_loss_stats[l].update(loss_stats[l].mean().item(), batch['image'].size(0))Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)'''更新平均损失'''Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \'|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)if opt.print_iter > 0:if iter_id % opt.print_iter == 0:print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix)) else:bar.next()'''opt.print_iter = 0 执行 else 语句,显示进度条'''if opt.debug > 0:self.debug(batch, output, iter_id, dataset=data_loader.dataset)'''debug 默认为 0,没有执行 if 语句'''if (phase == 'val' and (opt.run_dataset_eval or opt.eval)):meta = batch['meta']dets = fusion_decode(output, K=opt.K, opt=opt)'''解码器和雷达点云融合调用的这个函数位于 CenterFusion\src\lib\model\decode.py 中这个函数具体实现的功能就是将前面模型训练得到的结果,也就是一些特征图,这些特征图为多维矩阵将特征图与毫米波雷达点云进行映射,映射过程就是将特征图进行维度转换、升维等操作,然后再点乘旋转矩阵'''for k in dets:dets[k] = dets[k].detach().cpu().numpy()'''detach() 阻断反向传播,返回值仍为 tensorcpu() 将变量放在 cpu 上,仍为 tensornumpy() 将 tensor 转换为 numpy'''calib = meta['calib'].detach().numpy() if 'calib' in meta else Nonedets = generic_post_process(opt, dets, meta['c'].cpu().numpy(), meta['s'].cpu().numpy(),output['hm'].shape[2], output['hm'].shape[3], self.opt.num_classes,calib)result = []for i in range(len(dets[0])):if dets[0][i]['score'] > self.opt.out_thresh and all(dets[0][i]['dim'] > 0):result.append(dets[0][i])'''筛选结果'''img_id = batch['meta']['img_id'].numpy().astype(np.int32)[0]'''强制类型转换图片 id'''results[img_id] = resultdel output, loss, loss_statsbar.finish()ret = {k: v.avg for k, v in avg_loss_stats.items()}'''平均损失结果'''ret['time'] = bar.elapsed_td.total_seconds() / 60.return ret, results

这篇关于【CenterFusion】run_epoch()函数-训练一轮epoch-CenterFusion/src/lib/trainer.py的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/820868

相关文章

springboot将lib和jar分离的操作方法

《springboot将lib和jar分离的操作方法》本文介绍了如何通过优化pom.xml配置来减小SpringBoot项目的jar包大小,主要通过使用spring-boot-maven-plugin... 遇到一个问题,就是每次maven package或者maven install后target中的ja

Python itertools中accumulate函数用法及使用运用详细讲解

《Pythonitertools中accumulate函数用法及使用运用详细讲解》:本文主要介绍Python的itertools库中的accumulate函数,该函数可以计算累积和或通过指定函数... 目录1.1前言:1.2定义:1.3衍生用法:1.3Leetcode的实际运用:总结 1.1前言:本文将详

轻松上手MYSQL之JSON函数实现高效数据查询与操作

《轻松上手MYSQL之JSON函数实现高效数据查询与操作》:本文主要介绍轻松上手MYSQL之JSON函数实现高效数据查询与操作的相关资料,MySQL提供了多个JSON函数,用于处理和查询JSON数... 目录一、jsON_EXTRACT 提取指定数据二、JSON_UNQUOTE 取消双引号三、JSON_KE

MySQL数据库函数之JSON_EXTRACT示例代码

《MySQL数据库函数之JSON_EXTRACT示例代码》:本文主要介绍MySQL数据库函数之JSON_EXTRACT的相关资料,JSON_EXTRACT()函数用于从JSON文档中提取值,支持对... 目录前言基本语法路径表达式示例示例 1: 提取简单值示例 2: 提取嵌套值示例 3: 提取数组中的值注意

配置springboot项目动静分离打包分离lib方式

《配置springboot项目动静分离打包分离lib方式》本文介绍了如何将SpringBoot工程中的静态资源和配置文件分离出来,以减少jar包大小,方便修改配置文件,通过在jar包同级目录创建co... 目录前言1、分离配置文件原理2、pom文件配置3、使用package命令打包4、总结前言默认情况下,

Java function函数式接口的使用方法与实例

《Javafunction函数式接口的使用方法与实例》:本文主要介绍Javafunction函数式接口的使用方法与实例,函数式接口如一支未完成的诗篇,用Lambda表达式作韵脚,将代码的机械美感... 目录引言-当代码遇见诗性一、函数式接口的生物学解构1.1 函数式接口的基因密码1.2 六大核心接口的形态学

Python调用另一个py文件并传递参数常见的方法及其应用场景

《Python调用另一个py文件并传递参数常见的方法及其应用场景》:本文主要介绍在Python中调用另一个py文件并传递参数的几种常见方法,包括使用import语句、exec函数、subproce... 目录前言1. 使用import语句1.1 基本用法1.2 导入特定函数1.3 处理文件路径2. 使用ex

Oracle的to_date()函数详解

《Oracle的to_date()函数详解》Oracle的to_date()函数用于日期格式转换,需要注意Oracle中不区分大小写的MM和mm格式代码,应使用mi代替分钟,此外,Oracle还支持毫... 目录oracle的to_date()函数一.在使用Oracle的to_date函数来做日期转换二.日

python subprocess.run中的具体使用

《pythonsubprocess.run中的具体使用》subprocess.run是Python3.5及以上版本中用于运行子进程的函数,它提供了更简单和更强大的方式来创建和管理子进程,本文就来详细... 目录一、详解1.1、基本用法1.2、参数详解1.3、返回值1.4、示例1.5、总结二、subproce

C++11的函数包装器std::function使用示例

《C++11的函数包装器std::function使用示例》C++11引入的std::function是最常用的函数包装器,它可以存储任何可调用对象并提供统一的调用接口,以下是关于函数包装器的详细讲解... 目录一、std::function 的基本用法1. 基本语法二、如何使用 std::function