easy-Fpn源码解读(二):train

2024-04-02 05:08
文章标签 源码 解读 train easy fpn

本文主要是介绍easy-Fpn源码解读(二):train,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

  • easy-Fpn源码解读(二):train
    • train.py完整代码
    • 代码解析

easy-Fpn源码解读(二):train

train.py完整代码

import argparse
import os
import time
import uuid
from collections import deque
from typing import Optionalfrom tensorboardX import SummaryWriter
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoaderfrom backbone.base import Base as BackboneBase
from config.train_config import TrainConfig as Config
from dataset.base import Base as DatasetBase
from logger import Logger as Log
from model import Model
from roi.wrapper import Wrapper as ROIWrapperdef _train(dataset_name: str, backbone_name: str, path_to_data_dir: str, path_to_checkpoints_dir: str, path_to_resuming_checkpoint: Optional[str]):dataset = DatasetBase.from_name(dataset_name)(path_to_data_dir, DatasetBase.Mode.TRAIN, Config.IMAGE_MIN_SIDE, Config.IMAGE_MAX_SIDE)dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=8, pin_memory=True)Log.i('Found {:d} samples'.format(len(dataset)))backbone = BackboneBase.from_name(backbone_name)(pretrained=True)model = Model(backbone, dataset.num_classes(), pooling_mode=Config.POOLING_MODE,anchor_ratios=Config.ANCHOR_RATIOS, anchor_scales=Config.ANCHOR_SCALES,rpn_pre_nms_top_n=Config.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=Config.RPN_POST_NMS_TOP_N).cuda()optimizer = optim.SGD(model.parameters(), lr=Config.LEARNING_RATE,momentum=Config.MOMENTUM, weight_decay=Config.WEIGHT_DECAY)scheduler = MultiStepLR(optimizer, milestones=Config.STEP_LR_SIZES, gamma=Config.STEP_LR_GAMMA)step = 0time_checkpoint = time.time()losses = deque(maxlen=100)summary_writer = SummaryWriter(os.path.join(path_to_checkpoints_dir, 'summaries'))should_stop = Falsenum_steps_to_display = Config.NUM_STEPS_TO_DISPLAYnum_steps_to_snapshot = Config.NUM_STEPS_TO_SNAPSHOTnum_steps_to_finish = Config.NUM_STEPS_TO_FINISHif path_to_resuming_checkpoint is not None:step = model.load(path_to_resuming_checkpoint, optimizer, scheduler)Log.i(f'Model has been restored from file: {path_to_resuming_checkpoint}')Log.i('Start training')while not should_stop:for batch_index, (_, image_batch, _, bboxes_batch, labels_batch) in enumerate(dataloader):assert image_batch.shape[0] == 1, 'only batch size of 1 is supported'image = image_batch[0].cuda()bboxes = bboxes_batch[0].cuda()labels = labels_batch[0].cuda()forward_input = Model.ForwardInput.Train(image, gt_classes=labels, gt_bboxes=bboxes)forward_output: Model.ForwardOutput.Train = model.train().forward(forward_input)anchor_objectness_loss, anchor_transformer_loss, proposal_class_loss, proposal_transformer_loss = forward_outputloss = anchor_objectness_loss + anchor_transformer_loss + proposal_class_loss + proposal_transformer_lossoptimizer.zero_grad()loss.backward()optimizer.step()scheduler.step()losses.append(loss.item())summary_writer.add_scalar('train/anchor_objectness_loss'

这篇关于easy-Fpn源码解读(二):train的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/869042

相关文章

java之Objects.nonNull用法代码解读

《java之Objects.nonNull用法代码解读》:本文主要介绍java之Objects.nonNull用法代码,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录Java之Objects.nonwww.chinasem.cnNull用法代码Objects.nonN

Python实现无痛修改第三方库源码的方法详解

《Python实现无痛修改第三方库源码的方法详解》很多时候,我们下载的第三方库是不会有需求不满足的情况,但也有极少的情况,第三方库没有兼顾到需求,本文将介绍几个修改源码的操作,大家可以根据需求进行选择... 目录需求不符合模拟示例 1. 修改源文件2. 继承修改3. 猴子补丁4. 追踪局部变量需求不符合很

SpringCloud负载均衡spring-cloud-starter-loadbalancer解读

《SpringCloud负载均衡spring-cloud-starter-loadbalancer解读》:本文主要介绍SpringCloud负载均衡spring-cloud-starter-loa... 目录简述主要特点使用负载均衡算法1. 轮询负载均衡策略(Round Robin)2. 随机负载均衡策略(

解读spring.factories文件配置详情

《解读spring.factories文件配置详情》:本文主要介绍解读spring.factories文件配置详情,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录使用场景作用内部原理机制SPI机制Spring Factories 实现原理用法及配置spring.f

Spring MVC使用视图解析的问题解读

《SpringMVC使用视图解析的问题解读》:本文主要介绍SpringMVC使用视图解析的问题解读,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Spring MVC使用视图解析1. 会使用视图解析的情况2. 不会使用视图解析的情况总结Spring MVC使用视图

Linux中的进程间通信之匿名管道解读

《Linux中的进程间通信之匿名管道解读》:本文主要介绍Linux中的进程间通信之匿名管道解读,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、基本概念二、管道1、温故知新2、实现方式3、匿名管道(一)管道中的四种情况(二)管道的特性总结一、基本概念我们知道多

Spring 中 BeanFactoryPostProcessor 的作用和示例源码分析

《Spring中BeanFactoryPostProcessor的作用和示例源码分析》Spring的BeanFactoryPostProcessor是容器初始化的扩展接口,允许在Bean实例化前... 目录一、概览1. 核心定位2. 核心功能详解3. 关键特性二、Spring 内置的 BeanFactory

Linux系统之authconfig命令的使用解读

《Linux系统之authconfig命令的使用解读》authconfig是一个用于配置Linux系统身份验证和账户管理设置的命令行工具,主要用于RedHat系列的Linux发行版,它提供了一系列选项... 目录linux authconfig命令的使用基本语法常用选项示例总结Linux authconfi

SpringBoot集成图片验证码框架easy-captcha的详细过程

《SpringBoot集成图片验证码框架easy-captcha的详细过程》本文介绍了如何将Easy-Captcha框架集成到SpringBoot项目中,实现图片验证码功能,Easy-Captcha是... 目录SpringBoot集成图片验证码框架easy-captcha一、引言二、依赖三、代码1. Ea

解读docker运行时-itd参数是什么意思

《解读docker运行时-itd参数是什么意思》在Docker中,-itd参数组合用于在后台运行一个交互式容器,同时保持标准输入和分配伪终端,这种方式适合需要在后台运行容器并保持交互能力的场景... 目录docker运行时-itd参数是什么意思1. -i(或 --interactive)2. -t(或 --