本文主要是介绍FCOS 计算loss源码解读,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
FCOS loss计算源码解读
最近在看FCOS论文总觉得不够具体,特此调试了源代码解读源代码以供自己以后查看。其中有很多技巧如果不是读作者源码是很难想到的。
包含一下内容:
- 如何根据原始数据的box坐标生成loss函数需要的box样式
- 如何根据大小不同box的分配不同level的特征图
"""
This file contains specific functions for computing losses of FCOS
file
"""import torch
from torch.nn import functional as F
from torch import nn
import os
from ..utils import concat_box_prediction_layers
from fcos_core.layers import IOULoss
from fcos_core.layers import SigmoidFocalLoss
from fcos_core.modeling.matcher import Matcher
from fcos_core.modeling.utils import cat
from fcos_core.structures.boxlist_ops import boxlist_iou
from fcos_core.structures.boxlist_ops import cat_boxlistINF = 100000000def get_num_gpus():return int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1def reduce_sum(tensor):if get_num_gpus() <= 1:return tensorimport torch.distributed as disttensor = tensor.clone()dist.all_reduce(tensor, op=dist.reduce_op.SUM)return tensorclass FCOSLossComputation(object):"""This class computes the FCOS losses."""def __init__(self, cfg):self.cls_loss_func = SigmoidFocalLoss(cfg.MODEL.FCOS.LOSS_GAMMA,cfg.MODEL.FCOS.LOSS_ALPHA)self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDESself.center_sampling_radius = cfg.MODEL.FCOS.CENTER_SAMPLING_RADIUSself.iou_loss_type = cfg.MODEL.FCOS.IOU_LOSS_TYPEself.norm_reg_targets = cfg.MODEL.FCOS.NORM_REG_TARGETS# we make use of IOU Loss for bounding boxes regression,# but we found that L1 in log scale can yield a similar performanceself.box_reg_loss_func = IOULoss(self.iou_loss_type)self.centerness_loss_func = nn.BCEWithLogitsLoss(reduction="sum")def get_sample_region(self, gt, strides, num_points_per, gt_xs, gt_ys, radius=1.0):'''This code is fromhttps://github.com/yqyao/FCOS_PLUS/blob/0d20ba34ccc316650d8c30febb2eb40cb6eaae37/maskrcnn_benchmark/modeling/rpn/fcos/loss.py#L42'''num_gts = gt.shape[0]K = len(gt_xs)gt = gt[None].expand(K, num_gts, 4)center_x = (gt[..., 0] + gt[..., 2]) / 2center_y = (gt[..., 1] + gt[..., 3]) / 2center_gt = gt.new_zeros(gt.shape)# no gtif center_x[..., 0].sum() == 0:return gt_xs.new_zeros(gt_xs.shape, dtype=torch.uint8)beg = 0for level, n_p in enumerate(num_points_per):end = beg + n_pstride = strides[level] * radiusxmin = center_x[beg:end] - strideymin = center_y[beg:end] - stride
这篇关于FCOS 计算loss源码解读的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!