本文主要介绍 easy-Fpn 源码解读（二）：train，希望能为大家解决编程问题提供一定的参考价值。需要的开发者们，一起跟随小编来学习吧！
目录
- easy-Fpn源码解读(二):train
- train.py完整代码
- 代码解析
easy-Fpn源码解读(二):train
train.py完整代码
import argparse
import os
import time
import uuid
from collections import deque
from typing import Optionalfrom tensorboardX import SummaryWriter
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoaderfrom backbone.base import Base as BackboneBase
from config.train_config import TrainConfig as Config
from dataset.base import Base as DatasetBase
from logger import Logger as Log
from model import Model
from roi.wrapper import Wrapper as ROIWrapperdef _train(dataset_name: str, backbone_name: str, path_to_data_dir: str, path_to_checkpoints_dir: str, path_to_resuming_checkpoint: Optional[str]):dataset = DatasetBase.from_name(dataset_name)(path_to_data_dir, DatasetBase.Mode.TRAIN, Config.IMAGE_MIN_SIDE, Config.IMAGE_MAX_SIDE)dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=8, pin_memory=True)Log.i('Found {:d} samples'.format(len(dataset)))backbone = BackboneBase.from_name(backbone_name)(pretrained=True)model = Model(backbone, dataset.num_classes(), pooling_mode=Config.POOLING_MODE,anchor_ratios=Config.ANCHOR_RATIOS, anchor_scales=Config.ANCHOR_SCALES,rpn_pre_nms_top_n=Config.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=Config.RPN_POST_NMS_TOP_N).cuda()optimizer = optim.SGD(model.parameters(), lr=Config.LEARNING_RATE,momentum=Config.MOMENTUM, weight_decay=Config.WEIGHT_DECAY)scheduler = MultiStepLR(optimizer, milestones=Config.STEP_LR_SIZES, gamma=Config.STEP_LR_GAMMA)step = 0time_checkpoint = time.time()losses = deque(maxlen=100)summary_writer = SummaryWriter(os.path.join(path_to_checkpoints_dir, 'summaries'))should_stop = Falsenum_steps_to_display = Config.NUM_STEPS_TO_DISPLAYnum_steps_to_snapshot = Config.NUM_STEPS_TO_SNAPSHOTnum_steps_to_finish = Config.NUM_STEPS_TO_FINISHif path_to_resuming_checkpoint is not None:step = model.load(path_to_resuming_checkpoint, optimizer, scheduler)Log.i(f'Model has been restored from file: {path_to_resuming_checkpoint}')Log.i('Start training')while not should_stop:for batch_index, (_, image_batch, _, bboxes_batch, labels_batch) in enumerate(dataloader):assert image_batch.shape[0] == 1, 'only batch size of 1 is supported'image = image_batch[0].cuda()bboxes = bboxes_batch[0].cuda()labels = labels_batch[0].cuda()forward_input = Model.ForwardInput.Train(image, gt_classes=labels, gt_bboxes=bboxes)forward_output: Model.ForwardOutput.Train = model.train().forward(forward_input)anchor_objectness_loss, 
anchor_transformer_loss, proposal_class_loss, proposal_transformer_loss = forward_outputloss = anchor_objectness_loss + anchor_transformer_loss + proposal_class_loss + proposal_transformer_lossoptimizer.zero_grad()loss.backward()optimizer.step()scheduler.step()losses.append(loss.item())summary_writer.add_scalar('train/anchor_objectness_loss'
这篇关于 easy-Fpn 源码解读（二）：train 的文章就介绍到这里，希望我们推荐的文章对程序员们有所帮助！