DeepSnake Source Code Reading Notes: Processing the SBD Dataset

2023-10-15 04:40


DeepSnake Source Code Reading Notes

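  • This post walks through the code behind the paper "Deep Snake for Real-Time Instance Segmentation". The open-source code is at https://github.com/zju3dv/snake.

  • The DeepSnake code-reading notes by good18Levin (CSDN blog: 论文阅读_fighterflyfly的博客) were a great inspiration to me; that blog mainly analyzes the code used during model testing.

  • These notes record my own analysis of how DeepSnake is trained on the SBD dataset. My understanding is limited and some points are bound to be imprecise; questions and discussion are welcome: dongximing@webmail.hzau.edu.cn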

Part 0: Environment Setup

The original code requires pytorch1.1 + cuda9.0. My GPU is an RTX 3090, which does not support the old cuda9.0, so I used pytorch1.8.1 + cuda11.1 with Python 3.7.

  1. Installing pytorch and cuda

    pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
    

    Installing with conda is also possible, but in my setup it caused errors later when compiling the CUDA extensions.

  2. Installing and building apex

    git clone https://github.com/NVIDIA/apex.git
    cd apex 
    python setup.py install --cuda_ext --cpp_ext
    

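    Note that the system NVIDIA CUDA version must match the CUDA version pytorch was built with; otherwise the build may fail.

  3. Installing dcn_v2

    The dcn_v2 bundled with the source does not support newer pytorch and cuda versions, so download and install a newer dcn_v2 instead: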

    cd lib/csrc
    rm -rf dcn_v2
    git clone https://github.com/jinfagang/DCNv2_latest.git
    mv DCNv2_latest dcn_v2
    cd dcn_v2
    git checkout fa9b2fd740ced2a22e0e7e913c3bf3934fd08098
    python setup.py build develop
    

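    The latest version targets pytorch1.11 and may not be backward compatible, so git checkout is used to roll back to an earlier commit that supports pytorch1.8.

  4. extreme_utils and roi_align_layer can be installed exactly as described in the source repository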

    cd lib/csrc/extreme_utils
    python setup.py build_ext --inplace
    cd ../roi_align_layer
    python setup.py build_ext --inplace
    

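Part 1: Code Structure

  • The project supports many datasets and networks and therefore has many parameters to configure, so it uses a cfg object that reads a yaml file and combines it with a command-line parser to store global parameters. From the arguments that README.md passes for SBD training, together with configs/sbd_snake.yaml, the resulting cfg is: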

    ct_score: 0.05
    demo_path:
    det_dir: data/model/snake/
    eval_ep: 5
    gpus: [0, 1, 2, 3]
    head_conv: 256
    heads:
      ct_hm: 20
      wh: 2
    model: sbd_snake
    model_dir: data/model/snake/sbd
    network: ro_34
    record_dir: data/record/snake/sbd
    result_dir: data/result/snake/sbd
    resume: True
    save_ep: 5
    segm_or_bbox: segm
    skip_eval: False
    task: snake
    test:
      batch_size: 1
      dataset: SbdMini
      epoch: -1
    train:
      batch_size: 80
      dataset: SbdTrain
      epoch: 150
      gamma: 0.5
      lr: 0.0001
      milestones: [80, 120, 150, 170]
      num_workers: 32
      optim: adam
      scheduler:
      warmup: False
      weight_decay: 0.0005
    use_gt_det: False

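  • The training code is heavily wrapped, which makes it convenient to call but somewhat harder to read. In the walkthrough below I use the cfg values above to trace through the underlying construction code.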

    # train_net.py
    from lib.config import cfg, args
    from lib.networks import make_network
    from lib.train import make_trainer, make_optimizer, make_lr_scheduler, make_recorder, set_lr_scheduler
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_model, save_model, load_network
    from lib.evaluators import make_evaluator
    import torch.multiprocessing


    def train(cfg, network):
        trainer = make_trainer(cfg, network)  # wraps the loss function together with the network
        optimizer = make_optimizer(cfg, network)  # builds the optimizer
        scheduler = make_lr_scheduler(cfg, optimizer)  # builds the learning-rate scheduler
        recorder = make_recorder(cfg)  # logging
        evaluator = make_evaluator(cfg)  # evaluation

        begin_epoch = load_model(network, optimizer, scheduler, recorder, cfg.model_dir, resume=cfg.resume)
        # set_lr_scheduler(cfg, scheduler)

        train_loader = make_data_loader(cfg, is_train=True)  # wraps dataset construction
        val_loader = make_data_loader(cfg, is_train=False)

        for epoch in range(begin_epoch, cfg.train.epoch):
            recorder.epoch = epoch
            trainer.train(epoch, train_loader, optimizer, recorder)  # one training epoch
            scheduler.step()

            # save a checkpoint
            if (epoch + 1) % cfg.save_ep == 0:
                save_model(network, optimizer, scheduler, recorder, epoch, cfg.model_dir)

            if (epoch + 1) % cfg.eval_ep == 0:
                trainer.val(epoch, val_loader, evaluator, recorder)

        return network


    def main():
        network = make_network(cfg)  # build the network from cfg
        if args.test:
            test(cfg, network)  # test() is not shown in this excerpt
        else:
            train(cfg, network)


    if __name__ == "__main__":
        main()
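The train.gamma and train.milestones entries in cfg suggest a step-decay learning-rate schedule. The sketch below is a minimal pure-Python model of that idea, assuming the rate is multiplied by gamma each time a milestone epoch is reached (the behaviour of a standard multi-step scheduler). lr_at_epoch is a hypothetical helper for illustration, not a function from the repository.

```python
from bisect import bisect_right

# Values copied from the cfg dump above.
BASE_LR = 1e-4
GAMMA = 0.5
MILESTONES = [80, 120, 150, 170]


def lr_at_epoch(epoch, base_lr=BASE_LR, gamma=GAMMA, milestones=MILESTONES):
    """Learning rate in effect during `epoch` (0-based), assuming multi-step decay."""
    # bisect_right counts how many milestones are <= epoch.
    n_decays = bisect_right(milestones, epoch)
    return base_lr * gamma ** n_decays


if __name__ == "__main__":
    for e in (0, 79, 80, 120, 149):
        print(e, lr_at_epoch(e))
```

Under this assumption, with train.epoch set to 150 only the milestones at 80 and 120 take effect during training; the entries 150 and 170 would matter only if training ran longer.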

Part 2: Dataset Processing

Start from the wrapper that builds the dataset, lib.datasets.make_data_loader:

    # lib/datasets/make_data_loader.py
    from .transforms import make_transforms
    from . import samplers
    from .dataset_catalog import DatasetCatalog
    import torch
    import torch.utils.data
    import imp
    import os
    from .collate_batch import make_collator

    torch.multiprocessing.set_sharing_strategy('file_system')


    def _dataset_factory(data_source, task):
        module = '.'.join(['lib.datasets', data_source, task])
        path = os.path.join('lib/datasets', data_source, task+'.py')
        dataset = imp.load_source(module, path).Dataset  # loads the lib.datasets.sbd.snake.Dataset class
        return dataset


    def make_dataset(cfg, dataset_name, transforms, is_train=True):
        args = DatasetCatalog.get(dataset_name)  # SbdTrain
        data_source = args['id']
        dataset = _dataset_factory(data_source, cfg.task)  # ('sbd', 'snake')
        del args['id']
        # args['cfg'] = cfg
        # args['transforms'] = transforms
        # args['is_train'] = is_train
        dataset = dataset(**args)
        return dataset


    def make_data_sampler(dataset, shuffle):
        if shuffle:
            sampler = torch.utils.data.sampler.RandomSampler(dataset)
        else:
            sampler = torch.utils.data.sampler.SequentialSampler(dataset)
        return sampler


    def make_batch_data_sampler(cfg, sampler, batch_size, drop_last, max_iter):
        batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, batch_size, drop_last)
        if max_iter != -1:
            batch_sampler = samplers.IterationBasedBatchSampler(batch_sampler, max_iter)
        return batch_sampler


    def make_data_loader(cfg, is_train=True, is_distributed=False, max_iter=-1):
        if is_train:
            batch_size = cfg.train.batch_size  # 80
            shuffle = True
            drop_last = False
        else:
            batch_size = cfg.test.batch_size
            shuffle = True if is_distributed else False
            drop_last = False

        dataset_name = cfg.train.dataset if is_train else cfg.test.dataset  # SbdTrain when is_train is True

        transforms = make_transforms(cfg, is_train)  # converts to Tensor and normalizes, but is not used afterwards
        dataset = make_dataset(cfg, dataset_name, transforms, is_train)
        sampler = make_data_sampler(dataset, shuffle)
        batch_sampler = make_batch_data_sampler(cfg, sampler, batch_size, drop_last, max_iter)
        num_workers = cfg.train.num_workers  # 32
        collator = make_collator(cfg)
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=collator
        )

        return data_loader

The loader uses the dataset metadata stored in DatasetCatalog and instantiates the lib.datasets.sbd.snake.Dataset class to build the dataset.

    # lib/datasets/dataset_catalog.py
    from lib.config import cfg


    class DatasetCatalog(object):
        dataset_attrs = {
            'SbdTrain': {
                'id': 'sbd',
                'data_root': 'data/sbd/img',
                'ann_file': 'data/sbd/annotations/sbd_train_instance.json',
                'split': 'train'
            }
        }

        @staticmethod
        def get(name):
            attrs = DatasetCatalog.dataset_attrs[name]
            return attrs.copy()
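_dataset_factory loads a module from a file path with imp.load_source. The imp module is deprecated and was removed in Python 3.12, so on newer interpreters an importlib-based replacement is needed. The sketch below shows one possible way to do that; load_source here is a hypothetical stand-in written for this note, and the demo writes a throwaway module into a temporary directory instead of touching the repository.

```python
import importlib.util
import os
import sys
import tempfile


def load_source(module_name, path):
    """Load a Python file as a module, similar in spirit to imp.load_source."""
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module  # register before executing the module body
    spec.loader.exec_module(module)
    return module


def demo():
    # Write a tiny stand-in for lib/datasets/<data_source>/<task>.py.
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "snake.py")
        with open(path, "w") as f:
            f.write(
                "class Dataset:\n"
                "    def __init__(self, split):\n"
                "        self.split = split\n"
            )
        dataset_cls = load_source("demo_datasets.sbd.snake", path).Dataset
        return dataset_cls(split="train").split


if __name__ == "__main__":
    print(demo())
```

The pattern mirrors make_dataset: look up the class by file path, then call it with the keyword arguments taken from DatasetCatalog.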

lib.datasets.sbd.snake.Dataset in turn relies on the lib.datasets.voc.snake.Dataset class.

    # lib/datasets/voc/snake.py
    import os
    from lib.utils.snake import snake_voc_utils, snake_config, visualize_utils
    import cv2
    import numpy as np
    import math
    from lib.utils import data_utils
    import torch.utils.data as data
    from pycocotools.coco import COCO
    from lib.config import cfg


    class Dataset(data.Dataset):
        def __init__(self, ann_file, data_root, split):
            super(Dataset, self).__init__()
            """
            data_root='data/sbd/img'
            ann_file='data/sbd/annotations/sbd_train_instance.json'
            split='train'
            """
            self.data_root = data_root
            self.split = split
            """Build the COCO object"""
            self.coco = COCO(ann_file)
            self.anns = np.array(sorted(self.coco.getImgIds()))  # shape (5623,)
            self.anns = self.anns[:500] if split == 'mini' else self.anns
            self.json_category_id_to_contiguous_id = {v: i for i, v in enumerate(self.coco.getCatIds())}

        def process_info(self, img_id):
            """Given an image id, return its annotation objects and the image path."""
            ann_ids = self.coco.getAnnIds(imgIds=img_id)  # annotation ids
            anno = self.coco.loadAnns(ann_ids)  # annotation records
            path = os.path.join(self.data_root, self.coco.loadImgs(int(img_id))[0]['file_name'])
            return anno, path, img_id

        def read_original_data(self, anno, path):
            """
            Read the image.
            Reshape each segmentation from a flat xyxy... list to [[x, y], [x, y], ...] -> instance_polys
            Look up the contiguous class id of each annotation -> cls_ids
            """
            img = cv2.imread(path)
            instance_polys = [[np.array(poly).reshape(-1, 2) for poly in obj['segmentation']] for obj in anno]
            cls_ids = [self.json_category_id_to_contiguous_id[obj['category_id']] for obj in anno]
            return img, instance_polys, cls_ids

        def transform_original_data(self, instance_polys, flipped, width, trans_output, inp_out_hw):
            """
            Because of data augmentation, the original annotations must be transformed too.
            trans_output is applied as an affine transform so that the annotations
            stay consistent with the transformed image.
            """
            output_h, output_w = inp_out_hw[2:]
            instance_polys_ = []
            for instance in instance_polys:
                polys = [poly.reshape(-1, 2) for poly in instance]

                if flipped:
                    polys_ = []
                    for poly in polys:
                        poly[:, 0] = width - np.array(poly[:, 0]) - 1
                        polys_.append(poly.copy())
                    polys = polys_

                polys = snake_voc_utils.transform_polys(polys, trans_output, output_h, output_w)
                instance_polys_.append(polys)
            return instance_polys_

        def get_valid_polys(self, instance_polys, inp_out_hw):
            """
            Filter the transformed polygon annotations:
            1. drop polygons with fewer than 4 vertices
            2. clip polygon coordinates so they stay inside the output size
            3. filter_tiny_polys removes polygons whose area is below 5
            4. get_cw_polys returns the polygons in clockwise order
            5. remove duplicate vertices while preserving vertex order
            """
            output_h, output_w = inp_out_hw[2:]
            instance_polys_ = []
            for instance in instance_polys:
                instance = [poly for poly in instance if len(poly) >= 4]
                for poly in instance:
                    poly[:, 0] = np.clip(poly[:, 0], 0, output_w - 1)
                    poly[:, 1] = np.clip(poly[:, 1], 0, output_h - 1)
                polys = snake_voc_utils.filter_tiny_polys(instance)
                polys = snake_voc_utils.get_cw_polys(polys)
                polys = [poly[np.sort(np.unique(poly, axis=0, return_index=True)[1])] for poly in polys]
                instance_polys_.append(polys)
            return instance_polys_

        def get_extreme_points(self, instance_polys):
            """Get the extreme points of each polygon (top, left, bottom, right; counter-clockwise)."""
            extreme_points = []
            for instance in instance_polys:
                points = [snake_voc_utils.get_extreme_points(poly) for poly in instance]
                extreme_points.append(points)
            return extreme_points

        def prepare_detection(self, box, poly, ct_hm, cls_id, wh, ct_cls, ct_ind):
            """
            Compute the center.
            Record the class id in ct_cls and the flattened center index in ct_ind.
            Compute the Gaussian radius and update ct_hm.
            """
            ct_hm = ct_hm[cls_id]  # ct_hm.shape: (class, 128, 128) -> (128, 128)
            ct_cls.append(cls_id)  # record the class that has been prepared

            x_min, y_min, x_max, y_max = box
            ct = np.array([(x_min + x_max) / 2, (y_min + y_max) / 2], dtype=np.float32)
            ct = np.round(ct).astype(np.int32)

            h, w = y_max - y_min, x_max - x_min
            radius = data_utils.gaussian_radius((math.ceil(h), math.ceil(w)))
            radius = max(0, int(radius))
            data_utils.draw_umich_gaussian(ct_hm, ct, radius)

            wh.append([w, h])
            ct_ind.append(ct[1] * ct_hm.shape[1] + ct[0])

            x_min, y_min = ct[0] - w / 2, ct[1] - h / 2
            x_max, y_max = ct[0] + w / 2, ct[1] + h / 2
            decode_box = [x_min, y_min, x_max, y_max]

            return decode_box

        def prepare_init(self, box, extreme_point, i_it_4pys, c_it_4pys, i_gt_4pys, c_gt_4pys, h, w):
            x_min, y_min = np.min(extreme_point[:, 0]), np.min(extreme_point[:, 1])
            x_max, y_max = np.max(extreme_point[:, 0]), np.max(extreme_point[:, 1])

            img_init_poly = snake_voc_utils.get_init(box)  # quadrilateral connecting the bbox edge midpoints
            """
            uniformsample resamples the midpoint quadrilateral uniformly to 40 points;
            if the input already has more than 40 points, the shortest segments are removed.
            """
            img_init_poly = snake_voc_utils.uniformsample(img_init_poly, snake_config.init_poly_num)  # polygon in image coordinates
            can_init_poly = snake_voc_utils.img_poly_to_can_poly(img_init_poly, x_min, y_min, x_max, y_max)  # polygon in canonical (normalized) coordinates
            img_gt_poly = extreme_point
            can_gt_poly = snake_voc_utils.img_poly_to_can_poly(img_gt_poly, x_min, y_min, x_max, y_max)

            i_it_4pys.append(img_init_poly)
            c_it_4pys.append(can_init_poly)
            i_gt_4pys.append(img_gt_poly)
            c_gt_4pys.append(can_gt_poly)

        def prepare_evolution(self, poly, extreme_point, img_init_polys, can_init_polys, img_gt_polys, can_gt_polys):
            x_min, y_min = np.min(extreme_point[:, 0]), np.min(extreme_point[:, 1])
            x_max, y_max = np.max(extreme_point[:, 0]), np.max(extreme_point[:, 1])

            octagon = snake_voc_utils.get_octagon(extreme_point)  # build an octagon from the extreme points
            img_init_poly = snake_voc_utils.uniformsample(octagon, snake_config.poly_num)
            can_init_poly = snake_voc_utils.img_poly_to_can_poly(img_init_poly, x_min, y_min, x_max, y_max)

            img_gt_poly = snake_voc_utils.uniformsample(poly, len(poly) * snake_config.gt_poly_num)
            tt_idx = np.argmin(np.power(img_gt_poly - img_init_poly[0], 2).sum(axis=1))
            # rotate the gt polygon so that the vertex closest to img_init_poly[0] comes first
            img_gt_poly = np.roll(img_gt_poly, -tt_idx, axis=0)[::len(poly)]
            can_gt_poly = snake_voc_utils.img_poly_to_can_poly(img_gt_poly, x_min, y_min, x_max, y_max)

            img_init_polys.append(img_init_poly)
            can_init_polys.append(can_init_poly)
            img_gt_polys.append(img_gt_poly)
            can_gt_polys.append(can_gt_poly)

        def prepare_merge(self, is_id, cls_id, cp_id, cp_cls):
            cp_id.append(is_id)
            cp_cls.append(cls_id)

        def __getitem__(self, index):
            ann = self.anns[index]

            anno, path, img_id = self.process_info(ann)
            img, instance_polys, cls_ids = self.read_original_data(anno, path)

            height, width = img.shape[0], img.shape[1]
            """
            image size 442*500 -> input size 512*512
            Augmentation: random crop / flip / scale, color augmentation, normalization, affine transform
            trans_input / trans_output: affine transform matrices
            flipped: whether the image was flipped
            inp_out_hw: input and output height/width, (512, 512, 128, 128), because the
                        down-sampling ratio is 4 (the output is 4x smaller than the input)
            scale: image extent after random cropping and before the affine transform (the scaling factor)
            """
            orig_img, inp, trans_input, trans_output, flipped, center, scale, inp_out_hw = \
                snake_voc_utils.augment(
                    img, self.split,
                    snake_config.data_rng, snake_config.eig_val, snake_config.eig_vec,
                    snake_config.mean, snake_config.std, instance_polys
                )
            instance_polys = self.transform_original_data(instance_polys, flipped, width, trans_output, inp_out_hw)
            instance_polys = self.get_valid_polys(instance_polys, inp_out_hw)
            extreme_points = self.get_extreme_points(instance_polys)

            # detection
            output_h, output_w = inp_out_hw[2:]
            ct_hm = np.zeros([cfg.heads.ct_hm, output_h, output_w], dtype=np.float32)
            wh = []
            ct_cls = []
            ct_ind = []

            # init
            i_it_4pys = []
            c_it_4pys = []
            i_gt_4pys = []
            c_gt_4pys = []

            # evolution
            i_it_pys = []
            c_it_pys = []
            i_gt_pys = []
            c_gt_pys = []

            for i in range(len(anno)):
                cls_id = cls_ids[i]
                instance_poly = instance_polys[i]
                instance_points = extreme_points[i]

                for j in range(len(instance_poly)):
                    poly = instance_poly[j]
                    extreme_point = instance_points[j]

                    x_min, y_min = np.min(poly[:, 0]), np.min(poly[:, 1])
                    x_max, y_max = np.max(poly[:, 0]), np.max(poly[:, 1])
                    bbox = [x_min, y_min, x_max, y_max]
                    h, w = y_max - y_min + 1, x_max - x_min + 1
                    if h <= 1 or w <= 1:
                        continue
                    """
                    Compute the center; record the class id in ct_cls and the flattened
                    center index in ct_ind; compute the Gaussian radius and update ct_hm.
                    """
                    self.prepare_detection(bbox, poly, ct_hm, cls_id, wh, ct_cls, ct_ind)
                    """
                    i_it_4pys: 40 points sampled uniformly on the quadrilateral connecting the
                    bbox edge midpoints; c_it_4pys: the same points in normalized coordinates.
                    i_gt_4pys: the extreme points; c_gt_4pys: the same points in normalized coordinates.
                    """
                    self.prepare_init(bbox, extreme_point, i_it_4pys, c_it_4pys, i_gt_4pys, c_gt_4pys, output_h, output_w)
                    """
                    i_it_pys: 128 points sampled on the octagon built from the extreme points.
                    i_gt_pys: 128 points sampled on the annotated contour.
                    """
                    self.prepare_evolution(poly, extreme_point, i_it_pys, c_it_pys, i_gt_pys, c_gt_pys)

            ret = {'inp': inp}
            detection = {'ct_hm': ct_hm, 'wh': wh, 'ct_cls': ct_cls, 'ct_ind': ct_ind}
            init = {'i_it_4py': i_it_4pys, 'c_it_4py': c_it_4pys, 'i_gt_4py': i_gt_4pys, 'c_gt_4py': c_gt_4pys}
            evolution = {'i_it_py': i_it_pys, 'c_it_py': c_it_pys, 'i_gt_py': i_gt_pys, 'c_gt_py': c_gt_pys}
            ret.update(detection)
            ret.update(init)
            ret.update(evolution)
            # visualize_utils.visualize_snake_detection(orig_img, ret)
            # visualize_utils.visualize_snake_evolution(orig_img, ret)

            ct_num = len(ct_ind)
            meta = {'center': center, 'scale': scale, 'img_id': img_id, 'ann': ann, 'ct_num': ct_num}
            ret.update({'meta': meta})

            return ret

        def __len__(self):
            return len(self.anns)
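prepare_init and prepare_evolution both resample a contour to a fixed number of vertices through snake_voc_utils.uniformsample. The sketch below is a simplified stand-in for that idea, not the repository's implementation: it assumes a closed polygon given as a list of (x, y) tuples and simply places n points at equal arc-length spacing, without trying to keep the original vertices. resample_closed_polygon is a hypothetical name.

```python
import math


def resample_closed_polygon(points, n):
    """Return n points spaced evenly (by arc length) along a closed polygon."""
    m = len(points)
    # Length of every edge, including the closing edge back to the first vertex.
    edges = [
        math.hypot(points[(i + 1) % m][0] - points[i][0],
                   points[(i + 1) % m][1] - points[i][1])
        for i in range(m)
    ]
    step = sum(edges) / n

    out = []
    edge_idx = 0
    edge_start = 0.0  # arc length at the start of the current edge
    for k in range(n):
        target = k * step
        # Advance to the edge that contains the target arc length.
        while edge_idx < m - 1 and edge_start + edges[edge_idx] <= target:
            edge_start += edges[edge_idx]
            edge_idx += 1
        (x0, y0), (x1, y1) = points[edge_idx], points[(edge_idx + 1) % m]
        length = edges[edge_idx]
        t = (target - edge_start) / length if length > 0 else 0.0
        out.append((x0 + t * (x1 - x0), y0 + t * (y1 - y0)))
    return out


if __name__ == "__main__":
    square = [(0, 0), (4, 0), (4, 4), (0, 4)]
    print(resample_closed_polygon(square, 8))
```

For the 4x4 square the eight samples fall on the corners and the edge midpoints, which is the kind of evenly spaced initial contour the snake expects as input.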

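The batches produced by the final data loader can be inspected with for x, y in enumerate(train_loader): break, after which y holds the first batch:

[Figure: screenshot of the contents of one batch]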
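One detail of prepare_detection worth spelling out is ct_ind: the integer center (x, y) on the output heatmap is flattened to a single row-major index, ct[1] * width + ct[0], and the decoded box is rebuilt around the rounded center. The sketch below reproduces that arithmetic in plain Python; the function names are hypothetical and exist only for this illustration.

```python
def encode_center(ct_x, ct_y, width):
    """Flatten an integer (x, y) heatmap location to one row-major index."""
    return ct_y * width + ct_x


def decode_center(ct_ind, width):
    """Invert encode_center: recover (x, y) from the flat index."""
    return ct_ind % width, ct_ind // width


def decode_box(ct_x, ct_y, w, h):
    """Rebuild [x_min, y_min, x_max, y_max] around a center, as prepare_detection does."""
    return [ct_x - w / 2, ct_y - h / 2, ct_x + w / 2, ct_y + h / 2]


if __name__ == "__main__":
    idx = encode_center(5, 3, 128)
    print(idx, decode_center(idx, 128), decode_box(5, 3, 4, 2))
```

On a 128-wide heatmap the center (5, 3) maps to index 389, and the same index is what a gather over the flattened [128*128] feature map would use to pick out the wh prediction for that object.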

Part 3: Network Structure

    # train_net.py
    from lib.networks import make_network


    def main():
        network = make_network(cfg)
        if args.test:
            test(cfg, network)
        else:
            train(cfg, network)

From train_net.py, follow the call to lib.networks.make_network:

    # lib/networks/make_network.py
    import os
    import imp


    def make_network(cfg):
        module = '.'.join(['lib.networks', cfg.task])
        path = os.path.join('lib/networks', cfg.task, '__init__.py')
        return imp.load_source(module, path).get_network(cfg)

cfg.task is snake, which leads to the get_network function in lib/networks/snake/__init__.py:

    # lib/networks/snake/__init__.py
    from lib.utils.snake import snake_config
    from .ct_snake import get_network as get_ro

    _network_factory = {
        'ro': get_ro
    }


    def get_network(cfg):
        arch = cfg.network
        heads = cfg.heads
        head_conv = cfg.head_conv
        num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0
        arch = arch[:arch.find('_')] if '_' in arch else arch
        get_model = _network_factory[arch]
        network = get_model(num_layers, heads, head_conv, snake_config.down_ratio, cfg.det_dir)
        return network
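The string handling in get_network is compact, so here is the same split written out as a small standalone function. parse_arch is a hypothetical helper that mirrors the two conditional expressions above; it shows how the cfg value network: ro_34 becomes the factory key 'ro' and the layer count 34.

```python
def parse_arch(network):
    """Split a name like 'ro_34' into (arch, num_layers); no suffix means 0 layers."""
    if '_' in network:
        cut = network.find('_')
        return network[:cut], int(network[cut + 1:])
    return network, 0


if __name__ == "__main__":
    arch, num_layers = parse_arch('ro_34')
    print(arch, num_layers, 'dla{}'.format(num_layers))
```

The layer count is later formatted into the backbone name, so 34 ends up selecting dla34.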

So the function that finally builds the network is lib.networks.snake.ct_snake.get_network:

    # lib/networks/snake/ct_snake.py
    import torch.nn as nn
    from .dla import DLASeg
    from .evolve import Evolution
    from lib.utils import net_utils, data_utils
    from lib.utils.snake import snake_decode
    import torch
    from lib.config import cfg


    class Network(nn.Module):
        def __init__(self, num_layers, heads, head_conv=256, down_ratio=4, det_dir=''):
            super(Network, self).__init__()
            """
            DLASeg('dla34', heads={'ct_hm':20,'wh':2}, pretrained=True, down_ratio=4,
                   final_kernel=1, last_level=5, head_conv=256)
            """
            self.dla = DLASeg('dla{}'.format(num_layers), heads,
                              pretrained=True,
                              down_ratio=down_ratio,
                              final_kernel=1,
                              last_level=5,
                              head_conv=head_conv)
            self.gcn = Evolution()

        def decode_detection(self, output, h, w):
            ct_hm = output['ct_hm']
            wh = output['wh']
            ct, detection = snake_decode.decode_ct_hm(torch.sigmoid(ct_hm), wh)
            detection[..., :4] = data_utils.clip_to_image(detection[..., :4], h, w)
            output.update({'ct': ct, 'detection': detection})
            return ct, detection

        def forward(self, x, batch=None):
            output, cnn_feature = self.dla(x)
            with torch.no_grad():
                ct, detection = self.decode_detection(output, cnn_feature.size(2), cnn_feature.size(3))

            if cfg.use_gt_det:
                self.use_gt_detection(output, batch)

            output = self.gcn(output, cnn_feature, batch)
            return output


    def get_network(num_layers, heads, head_conv=256, down_ratio=4, det_dir=''):
        network = Network(num_layers, heads, head_conv, down_ratio, det_dir)
        return network

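As the code above shows, the network has three main parts:

  • DLASeg, which uses dla34 as the feature-extraction backbone;
  • decode_detection, which decodes the head outputs into detections;
  • Evolution, which handles contour evolution.

1. DLASeg

DLASeg consists of:

  • the base network (dla34)
  • DLAUp and IDAUp
  • the detection heads

Final output: a dictionary holding the outputs of the two detection heads, plus the CNN feature map.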

    class DLASeg(nn.Module):
        def __init__(self, base_name, heads, pretrained, down_ratio, final_kernel,
                     last_level, head_conv, out_channel=0):
            super(DLASeg, self).__init__()
            """
            base_name=dla34
            heads={'ct_hm':20,'wh':2}
            pretrained=True
            down_ratio=4
            final_kernel=1
            last_level=5
            head_conv=256
            """
            assert down_ratio in [2, 4, 8, 16]
            self.first_level = int(np.log2(down_ratio))  # 2
            self.last_level = last_level
            self.base = globals()[base_name](pretrained=pretrained)  # dla34(pretrained=True)
            channels = self.base.channels
            scales = [2 ** i for i in range(len(channels[self.first_level:]))]
            self.dla_up = DLAUp(self.first_level, channels[self.first_level:], scales)

            if out_channel == 0:
                out_channel = channels[self.first_level]

            self.ida_up = IDAUp(out_channel, channels[self.first_level:self.last_level],
                                [2 ** i for i in range(self.last_level - self.first_level)])

            self.heads = heads
            """
            Build one detection head per entry in heads. This task has two:
            ct_hm: center heatmap; wh: bbox width and height.
            """
            for head in self.heads:
                classes = self.heads[head]
                if head_conv > 0:
                    fc = nn.Sequential(
                        nn.Conv2d(channels[self.first_level], head_conv,
                                  kernel_size=3, padding=1, bias=True),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(head_conv, classes,
                                  kernel_size=final_kernel, stride=1,
                                  padding=final_kernel // 2, bias=True))
                    if 'hm' in head:
                        fc[-1].bias.data.fill_(-2.19)
                    else:
                        fill_fc_weights(fc)
                else:
                    fc = nn.Conv2d(channels[self.first_level], classes,
                                   kernel_size=final_kernel, stride=1,
                                   padding=final_kernel // 2, bias=True)
                    if 'hm' in head:
                        fc.bias.data.fill_(-2.19)
                    else:
                        fill_fc_weights(fc)
                self.__setattr__(head, fc)

        def forward(self, x):
            x = self.base(x)
            x = self.dla_up(x)

            y = []
            for i in range(self.last_level - self.first_level):
                y.append(x[i].clone())
            self.ida_up(y, 0, len(y))

            z = {}
            for head in self.heads:
                z[head] = self.__getattr__(head)(y[-1])
            return z, y[-1]

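The heatmap head in DLASeg fills the bias of its last convolution with -2.19, and the source does not say why. A common reading, stated here as my assumption rather than something the code confirms, is that this makes the initial sigmoid output of every heatmap cell roughly 0.1, so training does not start from a flood of confident false positives. The sketch below only checks that arithmetic; both helpers are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def bias_for_prior(p):
    """Bias b such that sigmoid(b) == p."""
    return -math.log((1.0 - p) / p)


if __name__ == "__main__":
    print(round(sigmoid(-2.19), 4))      # initial heatmap probability
    print(round(bias_for_prior(0.1), 4))  # bias that gives a 0.1 prior
```

The two values agree to two decimal places, which is consistent with -2.19 being a rounded form of -log(9).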
1.1 dla34

    def dla34(pretrained=True, **kwargs):  # DLA-34
        model = DLA([1, 1, 1, 2, 2, 1],
                    [16, 32, 64, 128, 256, 512],
                    block=BasicBlock, **kwargs)
        if pretrained:
            model.load_pretrained_model(data='imagenet', name='dla34', hash='ba72cf86')
        return model


    class DLA(nn.Module):
        def __init__(self, levels, channels, num_classes=1000,
                     block=BasicBlock, residual_root=False, linear_root=False):
            super(DLA, self).__init__()
            self.channels = channels
            self.num_classes = num_classes
            self.base_layer = nn.Sequential(
                nn.Conv2d(3, channels[0], kernel_size=7, stride=1,
                          padding=3, bias=False),
                nn.BatchNorm2d(channels[0], momentum=BN_MOMENTUM),
                nn.ReLU(inplace=True))
            self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
            self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
            """
            level_root is a flag: it indicates whether the Tree also feeds the down-sampled
            output of the previous level into its root (the blue arrows between the red
            boxes in the figure below). The default is False; Trees created recursively do
            not set level_root, meaning they belong to the same level as the preceding Tree.
            """
            self.level2 = Tree(levels[2], block, channels[1], channels[2], 2,
                               level_root=False,
                               root_residual=residual_root)
            self.level3 = Tree(levels[3], block, channels[2], channels[3], 2,
                               level_root=True, root_residual=residual_root)
            self.level4 = Tree(levels[4], block, channels[3], channels[4], 2,
                               level_root=True, root_residual=residual_root)
            self.level5 = Tree(levels[5], block, channels[4], channels[5], 2,
                               level_root=True, root_residual=residual_root)

        def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
            modules = []
            for i in range(convs):
                modules.extend([
                    nn.Conv2d(inplanes, planes, kernel_size=3,
                              stride=stride if i == 0 else 1,
                              padding=dilation, bias=False, dilation=dilation),
                    nn.BatchNorm2d(planes, momentum=BN_MOMENTUM),
                    nn.ReLU(inplace=True)])
                inplanes = planes
            return nn.Sequential(*modules)

        def forward(self, x):
            y = []
            x = self.base_layer(x)
            for i in range(6):
                x = getattr(self, 'level{}'.format(i))(x)
                y.append(x)
            return y

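dla34 first passes the input through a Conv-BN-ReLU base layer and then extracts features through six levels (level0 to level5). level0 and level1 are plain Conv-BN-ReLU stacks, while level2, 3, 4 and 5 are built from Tree objects. The levels argument of DLA gives the number of convolutions for level0 and level1, and the recursion depth of the Tree for level2 to level5.

levels: [1, 1, 1, 2, 2, 1]

So level3 and level4 (depth 2) build their Trees recursively.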
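To make the recursion concrete, the sketch below counts how many BasicBlocks each Tree level contains, following the construction rule in the Tree class shown in this section: a Tree with levels == 1 holds two blocks (tree1 and tree2), and a deeper Tree holds two Trees of depth levels - 1. count_basic_blocks is a hypothetical helper, and the count ignores the Root, downsample and project modules.

```python
def count_basic_blocks(levels):
    """Number of BasicBlocks inside Tree(levels)."""
    if levels == 1:
        return 2  # tree1 and tree2 are both BasicBlocks
    # tree1 and tree2 are both Tree(levels - 1)
    return 2 * count_basic_blocks(levels - 1)


DLA34_LEVELS = [1, 1, 1, 2, 2, 1]

# Only level2..level5 are Trees; level0 and level1 are plain conv stacks.
blocks_per_level = {
    "level{}".format(i): count_basic_blocks(DLA34_LEVELS[i]) for i in range(2, 6)
}

if __name__ == "__main__":
    print(blocks_per_level)
```

With the dla34 setting this gives 2, 4, 4 and 2 blocks for level2 to level5, which is why only level3 and level4 involve a nested Tree.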

    class Tree(nn.Module):
        def __init__(self, levels, block, in_channels, out_channels, stride=1,
                     level_root=False, root_dim=0, root_kernel_size=1,
                     dilation=1, root_residual=False):
            super(Tree, self).__init__()
            if root_dim == 0:
                root_dim = 2 * out_channels
            if level_root:
                root_dim += in_channels
            """
            When levels == 1 the Tree is built directly from BasicBlocks;
            when levels > 1 it is built recursively from Trees.
            """
            if levels == 1:
                self.tree1 = block(in_channels, out_channels, stride,
                                   dilation=dilation)
                self.tree2 = block(out_channels, out_channels, 1,
                                   dilation=dilation)
            else:
                self.tree1 = Tree(levels - 1, block, in_channels, out_channels,
                                  stride, root_dim=0,
                                  root_kernel_size=root_kernel_size,
                                  dilation=dilation, root_residual=root_residual)
                self.tree2 = Tree(levels - 1, block, out_channels, out_channels,
                                  root_dim=root_dim + out_channels,
                                  root_kernel_size=root_kernel_size,
                                  dilation=dilation, root_residual=root_residual)
            if levels == 1:
                self.root = Root(root_dim, out_channels, root_kernel_size,
                                 root_residual)
            self.level_root = level_root
            self.root_dim = root_dim
            self.downsample = None
            self.project = None
            self.levels = levels
            if stride > 1:
                self.downsample = nn.MaxPool2d(stride, stride=stride)
            if in_channels != out_channels:
                self.project = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels,
                              kernel_size=1, stride=1, bias=False),
                    nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM)
                )

        def forward(self, x, residual=None, children=None):
            children = [] if children is None else children
            bottom = self.downsample(x) if self.downsample else x
            residual = self.project(bottom) if self.project else bottom
            if self.level_root:
                children.append(bottom)
            x1 = self.tree1(x, residual)
            if self.levels == 1:
                x2 = self.tree2(x1)
                x = self.root(x2, x1, *children)
            else:
                children.append(x1)
                x = self.tree2(x1, children=children)
            return x


    class Root(nn.Module):
        def __init__(self, in_channels, out_channels, kernel_size, residual):
            super(Root, self).__init__()
            self.conv = nn.Conv2d(
                in_channels, out_channels, 1,
                stride=1, bias=False, padding=(kernel_size - 1) // 2)
            self.bn = nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM)
            self.relu = nn.ReLU(inplace=True)
            self.residual = residual

        def forward(self, *x):
            children = x
            x = self.conv(torch.cat(x, 1))
            x = self.bn(x)
            if self.residual:
                x += children[0]
            x = self.relu(x)
            return x


    class BasicBlock(nn.Module):
        def __init__(self, inplanes, planes, stride=1, dilation=1):
            super(BasicBlock, self).__init__()
            self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3,
                                   stride=stride, padding=dilation,
                                   bias=False, dilation=dilation)
            self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                                   stride=1, padding=dilation,
                                   bias=False, dilation=dilation)
            self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
            self.stride = stride

        def forward(self, x, residual=None):
            if residual is None:
                residual = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out)

            out += residual
            out = self.relu(out)

            return out

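[Figure: DLA architecture diagram]

  • In the DLA architecture diagram above, black boxes are BasicBlocks, green boxes are roots, and each red box is one Tree level. dla34 actually uses only the first, second and fourth red boxes: the first is level2, the second covers level3 and level4, and the fourth is level5.
  • The dla34 code is fairly intricate and takes some working through by hand to understand. The pprp blog (CSDN post "DLA模型(分类模型+改进版分割模型) + 可变形卷积", BIT_Legend的博客) and 算法小乔's Zhihu article "DLA34主干网络介绍" explain it in more depth.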
1.2 DLAUp and IDAUp

    class DLAUp(nn.Module):
        def __init__(self, startp, channels, scales, in_channels=None):
            super(DLAUp, self).__init__()
            """
            startp=2
            channels=[64,128,256,512]
            scales=[1,2,4,8]
            """
            self.startp = startp
            if in_channels is None:
                in_channels = channels
            self.channels = channels
            channels = list(channels)
            scales = np.array(scales, dtype=int)
            for i in range(len(channels) - 1):  # range(3)
                j = -i - 2
                setattr(self, 'ida_{}'.format(i),
                        IDAUp(channels[j], in_channels[j:],
                              scales[j:] // scales[j]))
                scales[j + 1:] = scales[j]
                in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]]

        def forward(self, layers):
            out = [layers[-1]]  # start with 32
            for i in range(len(layers) - self.startp - 1):
                ida = getattr(self, 'ida_{}'.format(i))
                ida(layers, len(layers) - i - 2, len(layers))
                out.insert(0, layers[-1])
            return out


    class IDAUp(nn.Module):
        def __init__(self, o, channels, up_f):
            super(IDAUp, self).__init__()
            for i in range(1, len(channels)):
                c = channels[i]
                f = int(up_f[i])
                """Deformable convolutions"""
                proj = DeformConv(c, o)
                node = DeformConv(o, o)
                """Transposed convolution for up-sampling"""
                up = nn.ConvTranspose2d(o, o, f * 2, stride=f,
                                        padding=f // 2, output_padding=0,
                                        groups=o, bias=False)
                fill_up_weights(up)

                setattr(self, 'proj_' + str(i), proj)
                setattr(self, 'up_' + str(i), up)
                setattr(self, 'node_' + str(i), node)

        def forward(self, layers, startp, endp):
            for i in range(startp + 1, endp):
                upsample = getattr(self, 'up_' + str(i - startp))
                project = getattr(self, 'proj_' + str(i - startp))
                layers[i] = upsample(project(layers[i]))
                node = getattr(self, 'node_' + str(i - startp))
                layers[i] = node(layers[i] + layers[i - 1])

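[Figure: structure of DLAUp and IDAUp]

In the DLAUp/IDAUp diagram, the red box marks the part used in this project. 4s, 8s, 16s and 32s denote the outputs of level2, 3, 4 and 5 respectively; the features are fused and the final output is the 4s feature map.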
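The labels 4s, 8s, 16s and 32s follow directly from the strides in the DLA code: the base layer and level0 keep the resolution, while level1 and every Tree level down-sample by 2. The sketch below recomputes the cumulative strides and the slices that DLASeg passes to DLAUp. It is plain arithmetic on values taken from the code above, with hypothetical variable names.

```python
import math

CHANNELS = [16, 32, 64, 128, 256, 512]  # dla34 channels for level0..level5
LEVEL_STRIDES = [1, 2, 2, 2, 2, 2]       # stride applied by each level
DOWN_RATIO = 4


def cumulative_strides(level_strides):
    """Total down-sampling factor at the output of each level."""
    out, total = [], 1
    for stride in level_strides:
        total *= stride
        out.append(total)
    return out


first_level = int(math.log2(DOWN_RATIO))              # 2, as in DLASeg
strides = cumulative_strides(LEVEL_STRIDES)            # per-level output stride
used_channels = CHANNELS[first_level:]                 # channels passed to DLAUp
used_strides = strides[first_level:]                   # the 4s, 8s, 16s, 32s maps
scales = [2 ** i for i in range(len(used_channels))]   # scales passed to DLAUp

if __name__ == "__main__":
    print(used_channels, used_strides, scales, 512 // used_strides[0])
```

For a 512x512 input the 4s map is 128x128, matching the (512, 512, 128, 128) value of inp_out_hw noted in the dataset code.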

1.3 Detection heads

    """Construction of the detection heads in DLASeg"""
    for head in self.heads:
        classes = self.heads[head]
        if head_conv > 0:
            fc = nn.Sequential(
                nn.Conv2d(channels[self.first_level], head_conv,
                          kernel_size=3, padding=1, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(head_conv, classes,
                          kernel_size=final_kernel, stride=1,
                          padding=final_kernel // 2, bias=True))
            if 'hm' in head:
                fc[-1].bias.data.fill_(-2.19)
            else:
                fill_fc_weights(fc)
        else:
            fc = nn.Conv2d(channels[self.first_level], classes,
                           kernel_size=final_kernel, stride=1,
                           padding=final_kernel // 2, bias=True)
            if 'hm' in head:
                fc.bias.data.fill_(-2.19)
            else:
                fill_fc_weights(fc)
        self.__setattr__(head, fc)

    """In forward, y[-1] is the 4s feature map output by IDAUp"""
    z = {}
    for head in self.heads:
        z[head] = self.__getattr__(head)(y[-1])

2. decode_detection

First, look at how decode_detection is used in ct_snake:

    def forward(self, x, batch=None):
        output, cnn_feature = self.dla(x)
        """output: detection head outputs; cnn_feature: backbone features before the heads"""
        with torch.no_grad():
            ct, detection = self.decode_detection(output, cnn_feature.size(2), cnn_feature.size(3))

        if cfg.use_gt_det:  # False
            self.use_gt_detection(output, batch)

        output = self.gcn(output, cnn_feature, batch)
        return output

    def decode_detection(self, output, h, w):
        ct_hm = output['ct_hm']
        wh = output['wh']
        ct, detection = snake_decode.decode_ct_hm(torch.sigmoid(ct_hm), wh)
        detection[..., :4] = data_utils.clip_to_image(detection[..., :4], h, w)  # keep boxes inside the feature map
        output.update({'ct': ct, 'detection': detection})
        return ct, detection

The work is done mainly by snake_decode.decode_ct_hm:

    def decode_ct_hm(ct_hm, wh, reg=None, K=100):
        """Assume the inputs are ct_hm [80,20,128,128] and wh [80,2,128,128]"""
        batch, cat, height, width = ct_hm.size()
        ct_hm = nms(ct_hm)  # non-maximum suppression
        """Top-k scores, indices, classes and x/y values, each [80,100]"""
        scores, inds, clses, ys, xs = topk(ct_hm, K=K)
        wh = transpose_and_gather_feat(wh, inds)  # width/height at the top-k locations [80,100,2]
        wh = wh.view(batch, K, 2)

        if reg is not None:
            reg = transpose_and_gather_feat(reg, inds)
            reg = reg.view(batch, K, 2)
            xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
            ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
        else:
            xs = xs.view(batch, K, 1)
            ys = ys.view(batch, K, 1)

        clses = clses.view(batch, K, 1).float()  # [80,100,1]
        scores = scores.view(batch, K, 1)  # [80,100,1]
        ct = torch.cat([xs, ys], dim=2)  # centers [80,100,2]
        bboxes = torch.cat([xs - wh[..., 0:1] / 2,
                            ys - wh[..., 1:2] / 2,
                            xs + wh[..., 0:1] / 2,
                            ys + wh[..., 1:2] / 2], dim=2)  # bounding boxes [80,100,4]
        detection = torch.cat([bboxes, scores, clses], dim=2)  # [80,100,4(bbox)+1(scores)+1(clses)]

        return ct, detection


    def nms(heat, kernel=3):
        """Non-maximum suppression via 3*3 max pooling"""
        pad = (kernel - 1) // 2
        hmax = nn.functional.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
        keep = (hmax == heat).float()
        return heat * keep


    def topk(scores, K=40):
        # assume scores has shape [80,20,128,128] and K=100
        batch, cat, height, width = scores.size()
        """
        Take the K largest values within each class -> topk_scores [80,20,100]
        and their indices -> topk_inds [80,20,100].
        These indices refer to scores.view(batch, cat, -1): [80,20,128*128].
        """
        topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)

        topk_inds = topk_inds % (height * width)  # this line seems to have no effect?
        topk_ys = (topk_inds / width).int().float()  # row in the heatmap
        topk_xs = (topk_inds % width).int().float()  # column in the heatmap
        """
        Take the K largest values across all classes -> topk_score [80,100]
        and their indices -> topk_ind [80,100].
        These indices refer to topk_scores.view(batch, -1): [80,20*100].
        """
        topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
        topk_clses = (topk_ind / K).int()  # class of each top value [80,100]
        """Map the final top-K entries back to their heatmap index and x/y values"""
        topk_inds = gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)  # [80,100]
        topk_ys = gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)  # [80,100]
        topk_xs = gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)  # [80,100]

        return topk_score, topk_inds, topk_clses, topk_ys, topk_xs


    def gather_feat(feat, ind, mask=None):
        """
        Taking gather_feat(topk_inds.view(batch,-1,1), topk_ind) above as the example:
        feat: topk_inds.view(batch,-1,1), i.e. topk_inds [80,20,100] -> [80,20*100,1]
        ind:  topk_ind [80,100]
        """
        dim = feat.size(2)  # 1
        ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)  # [80,100] -> [80,100,1]
        """
        topk_inds indexes positions in the heatmap (one entry per topk_scores value),
        whereas topk_ind indexes the K largest values inside topk_scores.
        topk_inds.gather(1, topk_ind) therefore maps topk_ind back to heatmap positions.
        """
        feat = feat.gather(1, ind)
        if mask is not None:
            mask = mask.unsqueeze(2).expand_as(feat)
            feat = feat[mask]
            feat = feat.view(-1, dim)
        return feat  # [80,100,1]


    def transpose_and_gather_feat(feat, ind):
        """
        feat: wh [80,2,128,128]
        ind:  top-k indices [80,100]
        """
        feat = feat.permute(0, 2, 3, 1).contiguous()  # [80,128,128,2]
        feat = feat.view(feat.size(0), -1, feat.size(3))  # [80,128*128,2]
        feat = gather_feat(feat, ind)  # width/height at the top-k locations [80,100,2]
        return feat

  • The end result is two new entries in output: 'ct': [batch, topk, 2] and 'detection': [batch, topk, 6], where detection[..., :4] is the bounding box, [..., 4] is the top-k score and [..., 5] is the class.
  • ct and detection do not appear to be used during training; they seem to be used at test time instead.
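The decoding pipeline is easier to follow on a toy input. The sketch below re-implements the three ideas in plain Python lists: keep only local maxima of the heatmap, take the K best scores over all classes, and turn each (center, wh) pair into a box. It is an illustration under simplifying assumptions, not the repository code: it uses a single global top-K instead of the two-stage per-class then cross-class top-K (the selected set is the same apart from tie-breaking), and all function names are hypothetical.

```python
def heatmap_nms(heat, kernel=3):
    """Keep a score only if it equals the max of its kernel x kernel neighbourhood."""
    pad = (kernel - 1) // 2
    h, w = len(heat), len(heat[0])
    out = [[0.0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            window = [heat[j][i]
                      for j in range(max(0, y - pad), min(h, y + pad + 1))
                      for i in range(max(0, x - pad), min(w, x + pad + 1))]
            if heat[y][x] == max(window):
                out[y][x] = heat[y][x]
    return out


def topk_detections(heat, wh, k):
    """heat: [C][H][W] scores, wh: [2][H][W]; returns k rows [x1, y1, x2, y2, score, cls]."""
    candidates = []
    for cls, plane in enumerate(heat):
        for y, row in enumerate(heatmap_nms(plane)):
            for x, score in enumerate(row):
                candidates.append((score, cls, y, x))
    candidates.sort(key=lambda c: -c[0])  # highest score first
    detections = []
    for score, cls, y, x in candidates[:k]:
        bw, bh = wh[0][y][x], wh[1][y][x]
        detections.append([x - bw / 2, y - bh / 2, x + bw / 2, y + bh / 2, score, cls])
    return detections


def demo():
    heat = [[[0.0] * 4 for _ in range(4)]]  # one class, 4x4 heatmap
    heat[0][1][1] = 0.9   # a peak
    heat[0][1][2] = 0.8   # neighbour of the peak, suppressed by NMS
    heat[0][3][3] = 0.5   # a second, weaker peak
    wh = [[[2.0] * 4 for _ in range(4)], [[2.0] * 4 for _ in range(4)]]
    wh[1][1][1] = 4.0     # height predicted at the first peak
    return topk_detections(heat, wh, k=2)


if __name__ == "__main__":
    for det in demo():
        print(det)
```

In the demo the 0.8 response next to the 0.9 peak is removed by the NMS step, so the two returned detections come from the two true peaks.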

3. Evolution

The Evolution module can be divided into:

  • the snake layer
  • contour evolution:
    • training
    • testing
3.1 snake

The Snake layer corresponds to the circular convolution in the paper.

  • Input: [batch, 66, number of vertices].

  • First the head convolves the input from feature_dim 66 to state_dim 128.

  • Then seven residual layers are built from dilated circular convolutions. The circular convolution is implemented by concatenating the trailing features of the input in front of it and the leading features behind it, and then applying the Conv1d named fc.

  • The head output and the residual-layer outputs are concatenated along the channel axis and fed into the fusion layer, which outputs 256 channels.

  • The maximum over vertices is taken for each channel of the fusion output (shape [batch, 256, 1]) and expanded to [batch, 256, number of vertices].

  • The head output, the seven residual-layer outputs and the result of the previous step are concatenated along the channel axis and passed to the prediction layer, which outputs [batch, 2, number of vertices].

    """
    Dilated circular convolution.
    Construction is the same as an ordinary Conv1d (named fc): the kernel size is 9 and the
    dilation is set by the Snake constructor.
    The key part is forward, which first pads the input circularly:
    the last n_adj*dilation entries + input + the first n_adj*dilation entries,
    forming a ring, and then applies fc.
    """
    class DilatedCircConv(nn.Module):
        def __init__(self, state_dim, out_state_dim=None, n_adj=4, dilation=1):
            super(DilatedCircConv, self).__init__()

            self.n_adj = n_adj
            self.dilation = dilation
            out_state_dim = state_dim if out_state_dim is None else out_state_dim
            self.fc = nn.Conv1d(state_dim, out_state_dim, kernel_size=self.n_adj*2+1, dilation=self.dilation)

        def forward(self, input, adj):
            """
            An adjacency matrix adj is passed in but does not seem to be used;
            presumably it was originally intended for a GCN.
            """
            if self.n_adj != 0:
                input = torch.cat([input[..., -self.n_adj*self.dilation:], input, input[..., :self.n_adj*self.dilation]], dim=2)
            return self.fc(input)


    """
    Conv-ReLU-BN, where Conv is the dilated circular convolution
    """
    class BasicBlock(nn.Module):
        def __init__(self, state_dim, out_state_dim, conv_type, n_adj=4, dilation=1):
            super(BasicBlock, self).__init__()

            self.conv = _conv_factory[conv_type](state_dim, out_state_dim, n_adj, dilation)  # DilatedCircConv
            self.relu = nn.ReLU(inplace=True)
            self.norm = nn.BatchNorm1d(out_state_dim)

        def forward(self, x, adj=None):
            x = self.conv(x, adj)
            x = self.relu(x)
            x = self.norm(x)
            return x


    class Snake(nn.Module):
        def __init__(self, state_dim, feature_dim, conv_type='dgrid'):
            super(Snake, self).__init__()
            """
            Components:
            1. head: a plain BasicBlock
            2. seven residual layers
            3. fusion layer
            4. prediction layer
            """
            self.head = BasicBlock(feature_dim, state_dim, conv_type)

            self.res_layer_num = 7
            dilation = [1, 1, 1, 2, 2, 4, 4]  # dilation of each residual layer
            # build the seven residual layers
            for i in range(self.res_layer_num):
                conv = BasicBlock(state_dim, state_dim, conv_type, n_adj=4, dilation=dilation[i])
                self.__setattr__('res'+str(i), conv)

            fusion_state_dim = 256
            # fuse the features of the head and all residual layers
            self.fusion = nn.Conv1d(state_dim * (self.res_layer_num + 1), fusion_state_dim, 1)
            # outputs the per-vertex offsets
            self.prediction = nn.Sequential(
                nn.Conv1d(state_dim * (self.res_layer_num + 1) + fusion_state_dim, 256, 1),
                nn.ReLU(inplace=True),
                nn.Conv1d(256, 64, 1),
                nn.ReLU(inplace=True),
                nn.Conv1d(64, 2, 1)
            )

        def forward(self, x, adj):
            """Assume x is [80,66,40], state_dim=128, feature_dim=66"""
            states = []

            x = self.head(x, adj)  # [80,128,40]
            states.append(x)
            for i in range(self.res_layer_num):
                x = self.__getattr__('res'+str(i))(x, adj) + x
                states.append(x)  # [80,128,40]
            # concatenate head and residual outputs
            state = torch.cat(states, dim=1)  # [80,128*8,40]
            """
            torch.max([80,256,40], dim=2, keepdim=True)[0] -> [80,256,1]
            i.e. the maximum of each channel over the vertices
            """
            global_state = torch.max(self.fusion(state), dim=2, keepdim=True)[0]
            global_state = global_state.expand(global_state.size(0), global_state.size(1), state.size(2))  # [80,256,40]
            state = torch.cat([global_state, state], dim=1)  # [80,256+128*8,40]
            x = self.prediction(state)

            return x
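The wrap-around padding in DilatedCircConv.forward can be checked on a one-channel toy signal. The sketch below is a pure-Python analogue, assuming a single channel, no bias and a padding width no larger than the signal length. circ_conv1d is a hypothetical name; it pads the list exactly as the torch.cat call does and then applies a dilated sliding window, so the output has the same length as the input.

```python
def circ_conv1d(signal, kernel, dilation=1):
    """Single-channel circular convolution (cross-correlation) on a closed sequence."""
    n_adj = (len(kernel) - 1) // 2
    pad = n_adj * dilation  # assumed to be <= len(signal)
    # Wrap the tail in front and the head behind, as DilatedCircConv.forward does.
    padded = signal[-pad:] + signal + signal[:pad] if pad else list(signal)
    out = []
    for i in range(len(signal)):
        out.append(sum(kernel[j] * padded[i + j * dilation] for j in range(len(kernel))))
    return out


if __name__ == "__main__":
    contour_feature = [1, 2, 3, 4, 5]
    print(circ_conv1d(contour_feature, [1, 1, 1]))
    print(circ_conv1d(contour_feature, [1, 1, 1], dilation=2))
```

Because the padding is circular, rotating the input vertex order simply rotates the output, which is the property that makes this operator suitable for closed contours with no natural starting vertex.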
3.2 Evolution
    class Evolution(nn.Module):
        def __init__(self):
            super(Evolution, self).__init__()

            self.fuse = nn.Conv1d(128, 64, 1)
            self.init_gcn = Snake(state_dim=128, feature_dim=64 + 2, conv_type='dgrid')
            self.evolve_gcn = Snake(state_dim=128, feature_dim=64 + 2, conv_type='dgrid')
            self.iter = 2
            for i in range(self.iter):
                evolve_gcn = Snake(state_dim=128, feature_dim=64 + 2, conv_type='dgrid')
                self.__setattr__('evolve_gcn' + str(i), evolve_gcn)

            for m in self.modules():
                if isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv2d):
                    m.weight.data.normal_(0.0, 0.01)
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
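3.2.1 Training process
  • Training first collects the annotation data to initialize the contours, giving the initial coordinates of 40 points;
  • The initial coordinates (a quadrilateral built from the bbox edge midpoints) are fed to the snake, which returns the predicted extreme points (the ground truth at this stage is the extreme points); the prediction is written into ret;
  • The initial contour derived from the annotation (an octagon built from the extreme points and bbox midpoints) is evolved, yielding a 128-point contour;
  • The resulting contour is fed into the next evolution step, which is repeated for two iterations;
  • The evolved contours are written into ret.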
    def forward(self, output, cnn_feature, batch=None):
        ret = output

        if batch is not None and 'test' not in batch['meta']:
            with torch.no_grad():
                """
                init:
                  ct_01 marks how many polys each image in the batch has and where they are
                  collect_training concatenates all polys in the batch, size: (polynum,40,2)
                  ct_num records how many ct (centers) each image contains
                  4py_ind records which image each concatenated poly belongs to, size: (polynum,), e.g. [0,1,2,2,3,4,4,4]
                output:
                  gains four poly entries (i_it_4py, i_it_py, i_gt_4py, i_gt_py)
                """
                init = self.prepare_training(output, batch)

            # returns the predicted coordinates of the four extreme points (refined from the bbox edge midpoints)
            ex_pred = self.init_poly(self.init_gcn, cnn_feature, init['i_it_4py'], init['c_it_4py'], init['4py_ind'])
            ret.update({'ex_pred': ex_pred, 'i_gt_4py': output['i_gt_4py']})  # ret and output are the same object, so why update the gt again?

            # with torch.no_grad():
            #     init = self.prepare_training_evolve(output, batch, init)

            py_pred = self.evolve_poly(self.evolve_gcn, cnn_feature, init['i_it_py'], init['c_it_py'], init['py_ind'])
            py_preds = [py_pred]
            for i in range(self.iter):  # range(2)
                py_pred = py_pred / snake_config.ro
                c_py_pred = snake_gcn_utils.img_poly_to_can_poly(py_pred)
                evolve_gcn = self.__getattr__('evolve_gcn' + str(i))
                py_pred = self.evolve_poly(evolve_gcn, cnn_feature, py_pred, c_py_pred, init['py_ind'])
                py_preds.append(py_pred)
            ret.update({'py_pred': py_preds, 'i_gt_py': output['i_gt_py'] * snake_config.ro})
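The scale bookkeeping in this loop is easy to lose track of: evolve_poly takes a contour at feature-map scale and returns one multiplied by snake_config.ro, so each iteration first divides by ro again. Below is a toy sketch of just that bookkeeping. toy_evolve is a made-up stand-in for the Snake network (it simply moves each vertex halfway towards a target), and to_canonical reflects my assumption that img_poly_to_can_poly shifts a polygon so that its minimum x and y become 0; neither is the repository's code.

```python
RO = 4  # stands in for snake_config.ro, the feature-map stride


def to_canonical(poly):
    """Assumed behaviour of img_poly_to_can_poly: shift so min x and min y are 0."""
    x_min = min(x for x, _ in poly)
    y_min = min(y for _, y in poly)
    return [(x - x_min, y - y_min) for x, y in poly]


def toy_evolve(i_poly, c_poly, target):
    """Stand-in for evolve_poly: input at feature scale, output at RO times that scale.

    c_poly is accepted only to mirror the call signature. The real offsets come
    from a Snake network; here each vertex moves halfway towards `target`.
    """
    assert len(c_poly) == len(i_poly)
    out = []
    for (x, y), (tx, ty) in zip(i_poly, target):
        sx, sy = x * RO, y * RO
        out.append((sx + 0.5 * (tx - sx), sy + 0.5 * (ty - sy)))
    return out


gt_scaled = [(40.0, 8.0), (8.0, 40.0), (40.0, 72.0), (72.0, 40.0)]   # gt already times RO
init_feat = [(10.0, 4.0), (4.0, 10.0), (10.0, 16.0), (16.0, 10.0)]   # feature-map scale

py = toy_evolve(init_feat, to_canonical(init_feat), gt_scaled)
preds = [py]
for _ in range(2):                                   # self.iter = 2
    py_feat = [(x / RO, y / RO) for x, y in py]      # back to feature-map scale
    py = toy_evolve(py_feat, to_canonical(py_feat), gt_scaled)
    preds.append(py)
```

The list preds ends up with three contours (one from evolve_gcn plus two refinements), all at the same scale as the ground truth multiplied by ro, which is why the loss can compare them directly.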

Training therefore relies on the ground-truth annotations: prepare_training gathers them from the batch.

    def prepare_training(self, output, batch):
        """
        ct_01 marks how many polys each image in the batch has and where they are
        collect_training concatenates all polys in the batch, size: (polynum,40,2)
        ct_num records how many ct (centers) each image contains
        4py_ind records which image each concatenated poly belongs to, size: (polynum,), e.g. [0,1,2,2,3,4,4,4]
        """
        init = snake_gcn_utils.prepare_training(output, batch)
        output.update({'i_it_4py': init['i_it_4py'], 'i_it_py': init['i_it_py']})
        output.update({'i_gt_4py': init['i_gt_4py'], 'i_gt_py': init['i_gt_py']})
        return init

    # snake_gcn_utils.prepare_training
    def prepare_training(ret, batch):
        """
        ct_01 marks how many polys each image in the batch has and where they are,
        size (80, max polys per image); e.g. if the busiest of 80 images has 13 polys: (80,13)
        polynum = ct_01.sum
        collect_training concatenates all polys in the batch, size: (polynum,40,2)
        ct_num records how many ct (centers) each image contains
        4py_ind records which image each concatenated poly belongs to, shape: (polynum,),
        e.g. [0,1,2,2,3,4,4,4] means: image 1 has one poly at position 0; image 2 has one poly
        at position 1; image 3 has two polys at positions 2,3; image 4 has one poly at
        position 4; image 5 has three polys at positions 5,6,7.
        """
        ct_01 = batch['ct_01'].byte()
        init = {}
        init.update({'i_it_4py': collect_training(batch['i_it_4py'], ct_01)})
        init.update({'c_it_4py': collect_training(batch['c_it_4py'], ct_01)})
        init.update({'i_gt_4py': collect_training(batch['i_gt_4py'], ct_01)})
        init.update({'c_gt_4py': collect_training(batch['c_gt_4py'], ct_01)})

        init.update({'i_it_py': collect_training(batch['i_it_py'], ct_01)})
        init.update({'c_it_py': collect_training(batch['c_it_py'], ct_01)})
        init.update({'i_gt_py': collect_training(batch['i_gt_py'], ct_01)})
        init.update({'c_gt_py': collect_training(batch['c_gt_py'], ct_01)})

        ct_num = batch['meta']['ct_num']
        init.update({'4py_ind': torch.cat([torch.full([ct_num[i]], i) for i in range(ct_01.size(0))], dim=0)})
        init.update({'py_ind': init['4py_ind']})

        if snake_config.train_pred_box:  # False
            prepare_training_box(ret, batch, init)

        init['4py_ind'] = init['4py_ind'].to(ct_01.device)
        init['py_ind'] = init['py_ind'].to(ct_01.device)

        return init


    def collect_training(poly, ct_01):
        """
        Concatenate all polys in the batch, size: (polynum,40,2)
        """
        batch_size = ct_01.size(0)
        poly = torch.cat([poly[i][ct_01[i]] for i in range(batch_size)], dim=0)
        return poly
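The two pieces of bookkeeping here, flattening the padded per-image polygons with the ct_01 mask and building the image-index vector from ct_num, can be reproduced without tensors. This is only a sketch with plain lists; collect and build_ind are hypothetical names, and the strings stand in for the (40,2) polygon tensors.

```python
def collect(per_image, mask):
    """Concatenate the valid (mask == 1) polygons of every image in the batch."""
    flat = []
    for polys, flags in zip(per_image, mask):
        flat.extend(p for p, keep in zip(polys, flags) if keep)
    return flat


def build_ind(ct_num):
    """Image index of every collected polygon, e.g. [1, 2] -> [0, 1, 1]."""
    ind = []
    for img_id, n in enumerate(ct_num):
        ind.extend([img_id] * n)
    return ind


# 3 images, each padded to at most 3 polygons
batch_polys = [["a0", None, None], ["b0", "b1", None], ["c0", "c1", "c2"]]
ct_01 = [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
ct_num = [sum(row) for row in ct_01]

flat_polys = collect(batch_polys, ct_01)
py_ind = build_ind(ct_num)
doc_example = build_ind([1, 1, 2, 1, 3])   # the example from the docstring above
```

flat_polys and py_ind have the same length, so py_ind[k] tells get_gcn_feature which image's feature map the k-th polygon should sample from.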

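Contour initialization

  • Sample the features at i_it_4py (the 40 points sampled along the polygon that connects the bbox edge midpoints): [polynum,64,40]
  • Sample the feature at the polygon center, [polynum,64,1], and concatenate it with the i_it_4py features: [polynum,128,40]
  • Pass the result through fuse to obtain init_feature: [polynum,64,40]
  • Append the normalized coordinates to get [polynum,66,40], feed this to the snake, and obtain the coordinate offsets [polynum,40,2]
  • Add the offsets to the original i_it_4py and subsample to obtain the predicted extreme points (the annotation uses the extreme points as gt): [polynum,4,2].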
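The channel arithmetic in these steps can be checked with plain tuples. This is only shape bookkeeping under the sizes quoted in these notes (199 polygons, 64 feature channels, 40 vertices); no tensors are created and the variable names are mine.

```python
polynum, feat_c, n_pts = 199, 64, 40

init_feature = (polynum, feat_c, n_pts)        # features sampled at the 40 vertices
ct_feature = (polynum, feat_c, 1)              # feature sampled at the polygon centre
expanded_ct = (polynum, ct_feature[1], n_pts)  # ct_feature.expand_as(init_feature)
concat = (polynum, init_feature[1] + expanded_ct[1], n_pts)   # torch.cat(..., dim=1)
fused = (polynum, 64, n_pts)                   # self.fuse = nn.Conv1d(128, 64, 1)
snake_input = (polynum, fused[1] + 2, n_pts)   # append the (x, y) coordinate channels
offsets = (polynum, n_pts, 2)                  # Snake output after permute(0, 2, 1)

step = n_pts // 4                              # snake_config.init_poly_num // 4
kept_vertices = list(range(0, n_pts, step))    # indices kept by i_poly[:, ::step]
extreme = (polynum, len(kept_vertices), 2)
```

The 66 input channels explain feature_dim=64 + 2 in the Snake constructor, and the stride-10 slice keeps vertices 0, 10, 20 and 30 as the four extreme-point predictions.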
    def init_poly(self, snake, cnn_feature, i_it_poly, c_it_poly, ind):
        """
        snake:       Snake(state_dim=128, feature_dim=64+2, conv_type='dgrid')
        cnn_feature: size (80,64,128,128)
        i_it_poly:   init['i_it_4py'] (199,40,2), polygon sampled along the connected edge midpoints
        ind:         init['4py_ind'] (199,)
        """
        if len(i_it_poly) == 0:
            return torch.zeros([0, 4, 2]).to(i_it_poly)

        h, w = cnn_feature.size(2), cnn_feature.size(3)
        init_feature = snake_gcn_utils.get_gcn_feature(cnn_feature, i_it_poly, ind, h, w)  # size (199,64,40)
        # midpoint of the per-coordinate min and max over the 40 points, size (199,2)
        center = (torch.min(i_it_poly, dim=1)[0] + torch.max(i_it_poly, dim=1)[0]) * 0.5
        ct_feature = snake_gcn_utils.get_gcn_feature(cnn_feature, center[:, None], ind, h, w)  # feature at the poly center, size (199,64,1)
        init_feature = torch.cat([init_feature, ct_feature.expand_as(init_feature)], dim=1)  # size (199,128,40)
        init_feature = self.fuse(init_feature)  # size (199,64,40)

        # append the normalized coordinates to the features: (199,64,40)+(199,2,40) -> (199,66,40)
        init_input = torch.cat([init_feature, c_it_poly.permute(0, 2, 1)], dim=1)
        adj = snake_gcn_utils.get_adj_ind(snake_config.adj_num, init_input.size(2), init_input.device)
        # adj is not actually used here; snake(...).permute -> (199,40,2), which can be read as offsets
        i_poly = i_it_poly + snake(init_input, adj).permute(0, 2, 1)
        i_poly = i_poly[:, ::snake_config.init_poly_num // 4]  # size (199,4,2)

        return i_poly


    def get_gcn_feature(cnn_feature, img_poly, ind, h, w):
        img_poly = img_poly.clone()
        """
        Normalizing the poly: why multiply by 2 and subtract 1 instead of just dividing by w and h?
        img_poly[...,0]/w lies in [0,1]; after *2-1 it lies in [-1,1],
        which is the range grid_sample expects.
        """
        img_poly[..., 0] = img_poly[..., 0] / (w / 2.) - 1
        img_poly[..., 1] = img_poly[..., 1] / (h / 2.) - 1

        batch_size = cnn_feature.size(0)
        """
        gcn_feature size: (number of polygons, feature channels, number of polygon vertices)
        """
        gcn_feature = torch.zeros([img_poly.size(0), cnn_feature.size(1), img_poly.size(1)]).to(img_poly.device)
        for i in range(batch_size):
            poly = img_poly[ind == i].unsqueeze(0)
            feature = torch.nn.functional.grid_sample(cnn_feature[i:i + 1], poly)[0].permute(1, 0, 2)
            gcn_feature[ind == i] = feature

        return gcn_feature
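The coordinate normalization in get_gcn_feature is a plain affine map from pixel units to the [-1, 1] range that grid_sample works in, with x in the first slot and y in the second. A small sketch with a hypothetical helper to_grid_coords makes the end points explicit:

```python
def to_grid_coords(poly, h, w):
    """Map (x, y) pixel coordinates on an h x w feature map to the [-1, 1] range."""
    return [(x / (w / 2.0) - 1.0, y / (h / 2.0) - 1.0) for x, y in poly]


h, w = 128, 128
pixels = [(0.0, 0.0), (64.0, 32.0), (128.0, 128.0)]
grid = to_grid_coords(pixels, h, w)
```

The origin maps to (-1, -1), the far corner to (1, 1), and the map centre to 0 along each axis, which is the same as x / w * 2 - 1 written with one division.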

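evolve_poly is very similar to init_poly, except that it does not fuse in the center-point feature. The coordinates are multiplied by 4 (snake_config.ro) because the feature map is downsampled by a stride of 4; the gt is multiplied by 4 later as well.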

    def evolve_poly(self, snake, cnn_feature, i_it_poly, c_it_poly, ind):
        if len(i_it_poly) == 0:
            return torch.zeros_like(i_it_poly)
        h, w = cnn_feature.size(2), cnn_feature.size(3)
        init_feature = snake_gcn_utils.get_gcn_feature(cnn_feature, i_it_poly, ind, h, w)  # (199,64,128)
        c_it_poly = c_it_poly * snake_config.ro  # scale the normalized coordinates by 4
        init_input = torch.cat([init_feature, c_it_poly.permute(0, 2, 1)], dim=1)  # append the scaled coordinates to the features
        adj = snake_gcn_utils.get_adj_ind(snake_config.adj_num, init_input.size(2), init_input.device)
        i_poly = i_it_poly * snake_config.ro + snake(init_input, adj).permute(0, 2, 1)  # (199,128,2)
        return i_poly
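3.2.2 Loss function

The loss is straightforward: loss = ct_loss + 0.1 * wh_loss + ex_loss + py_loss. Note that the code keeps only the last element of output['py_pred'], so py_loss is computed on the final evolution step alone.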

    class NetworkWrapper(nn.Module):
        def __init__(self, net):
            super(NetworkWrapper, self).__init__()

            self.net = net

            self.ct_crit = net_utils.FocalLoss()
            self.wh_crit = net_utils.IndL1Loss1d('smooth_l1')
            self.reg_crit = net_utils.IndL1Loss1d('smooth_l1')
            self.ex_crit = torch.nn.functional.smooth_l1_loss
            self.py_crit = torch.nn.functional.smooth_l1_loss

        def forward(self, batch):
            output = self.net(batch['inp'], batch)

            scalar_stats = {}
            loss = 0

            ct_loss = self.ct_crit(net_utils.sigmoid(output['ct_hm']), batch['ct_hm'])
            scalar_stats.update({'ct_loss': ct_loss})
            loss += ct_loss

            wh_loss = self.wh_crit(output['wh'], batch['wh'], batch['ct_ind'], batch['ct_01'])
            scalar_stats.update({'wh_loss': wh_loss})
            loss += 0.1 * wh_loss

            # reg_loss = self.reg_crit(output['reg'], batch['reg'], batch['ct_ind'], batch['ct_01'])
            # scalar_stats.update({'reg_loss': reg_loss})
            # loss += reg_loss

            ex_loss = self.ex_crit(output['ex_pred'], output['i_gt_4py'])
            scalar_stats.update({'ex_loss': ex_loss})
            loss += ex_loss

            py_loss = 0
            output['py_pred'] = [output['py_pred'][-1]]
            for i in range(len(output['py_pred'])):
                py_loss += self.py_crit(output['py_pred'][i], output['i_gt_py']) / len(output['py_pred'])
            scalar_stats.update({'py_loss': py_loss})
            loss += py_loss

            scalar_stats.update({'loss': loss})
            image_stats = {}

            return output, loss, scalar_stats, image_stats
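For reference, ex_crit and py_crit are PyTorch's smooth-L1 loss, which with default arguments (beta of 1, mean reduction) is quadratic for small errors and linear for large ones. The sketch below reimplements that formula on flat Python lists and applies the weighting from forward; smooth_l1 and total_loss are illustrative names and the numbers are made up.

```python
def smooth_l1(pred, target, beta=1.0):
    """Mean smooth-L1 loss over two flat lists of numbers."""
    total = 0.0
    for p, t in zip(pred, target):
        d = abs(p - t)
        # quadratic below beta, linear above it
        total += 0.5 * d * d / beta if d < beta else d - 0.5 * beta
    return total / len(pred)


def total_loss(ct_loss, wh_loss, ex_loss, py_loss):
    """Weighted sum from NetworkWrapper.forward: only wh_loss is down-weighted."""
    return ct_loss + 0.1 * wh_loss + ex_loss + py_loss


ex_loss = smooth_l1([0.5, 3.0], [0.0, 0.0])   # (0.125 + 2.5) / 2
loss = total_loss(ct_loss=1.0, wh_loss=2.0, ex_loss=ex_loss, py_loss=0.5)
```

The linear branch keeps a few badly placed vertices from dominating the gradient, which seems a sensible choice for contour regression.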
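3.2.3 Testing process

The main differences from training are:

  • output['detection'] is used to initialize the quadrilateral that connects the bbox edge midpoints, whereas training uses the annotations directly;
  • An octagon is built from the predicted ex and sampled to serve as the initial contour for evolution, whereas training again uses the annotations directly;
  • The final output consists of the polygons whose ct score exceeds ct_score.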
    def forward(self, output, cnn_feature, batch=None):
        ret = output

        if not self.training:
            with torch.no_grad():
                """
                Initialization: filter output['detection'] by ct_score, shape (>ct_score,6);
                store the predicted i_it_4py in it_ex, shape [>ct_score,40,2]
                """
                init = self.prepare_testing_init(output)
                ex = self.init_poly(self.init_gcn, cnn_feature, init['i_it_4py'], init['c_it_4py'], init['ind'])  # [>ct_score,4,2]
                ret.update({'ex': ex})
                """
                Build an octagon from ex and sample it to get the contour it_py ([>ct_score,128,2]),
                which is stored in output['it_py'];
                evolve contains two keys, 'i_it_py' and 'c_it_py', with the same shape as it_py
                """
                evolve = self.prepare_testing_evolve(output, cnn_feature.size(2), cnn_feature.size(3))
                py = self.evolve_poly(self.evolve_gcn, cnn_feature, evolve['i_it_py'], evolve['c_it_py'], init['ind'])
                pys = [py / snake_config.ro]
                for i in range(self.iter):
                    py = py / snake_config.ro
                    c_py = snake_gcn_utils.img_poly_to_can_poly(py)
                    evolve_gcn = self.__getattr__('evolve_gcn' + str(i))
                    py = self.evolve_poly(evolve_gcn, cnn_feature, py, c_py, init['ind'])
                    pys.append(py / snake_config.ro)
                ret.update({'py': pys})

        return output

Test-time initialization: i_it_4py (the initial polygon used to predict ex) is obtained from the detections.

    def prepare_testing_init(self, output):
        """
        output['detection'] shape (batch,100,6)
        init: from the detections, keeps the i_it_4py and c_it_4py whose score exceeds ct_score,
        together with the corresponding index ind
        output is filtered by ct_score as well, and i_it_4py is stored under the key it_ex
        """
        init = snake_gcn_utils.prepare_testing_init(output['detection'][..., :4], output['detection'][..., 4])
        output['detection'] = output['detection'][output['detection'][..., 4] > snake_config.ct_score]
        output.update({'it_ex': init['i_it_4py']})
        return init

    # snake_gcn_utils.prepare_testing_init
    def prepare_testing_init(box, score):
        """
        box = output['detection'][..., :4]   shape (batch,100,4)
        score = output['detection'][..., 4]  shape (batch,100)
        """
        i_it_4pys = snake_decode.get_init(box)  # quadrilateral from the edge midpoints [batch,100,4,2]
        i_it_4pys = uniform_upsample(i_it_4pys, snake_config.init_poly_num)  # sample 40 points [batch,100,40,2]
        c_it_4pys = img_poly_to_can_poly(i_it_4pys)  # normalize

        ind = score > snake_config.ct_score  # filter by threshold
        i_it_4pys = i_it_4pys[ind]  # [number above threshold,40,2]
        c_it_4pys = c_it_4pys[ind]  # [number above threshold,40,2]
        """
        ind shape [number above threshold,]
        for example: [0,0,0,1,2,2] means the first image in the batch has 3 detections
        above the threshold, the second has 1, and the third has 2.
        """
        ind = torch.cat([torch.full([ind[i].sum()], i) for i in range(ind.size(0))], dim=0)
        init = {'i_it_4py': i_it_4pys, 'c_it_4py': c_it_4pys, 'ind': ind}

        return init
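Two steps of this function are easy to reproduce by hand: turning a box into the quadrilateral of edge midpoints, and turning the score mask into the per-image index vector. The sketch below does both with plain lists. box_to_diamond and kept_index are hypothetical helpers, and the vertex order (top, left, bottom, right) is my assumption about snake_decode.get_init, which is not shown in these notes.

```python
def box_to_diamond(box):
    """Connect the midpoints of the four box edges (assumed order: top, left, bottom, right)."""
    x_min, y_min, x_max, y_max = box
    cx, cy = (x_min + x_max) / 2.0, (y_min + y_max) / 2.0
    return [(cx, y_min), (x_min, cy), (cx, y_max), (x_max, cy)]


def kept_index(scores, threshold):
    """Image index of every detection whose score exceeds the threshold."""
    ind = []
    for img_id, row in enumerate(scores):
        ind.extend([img_id] * sum(s > threshold for s in row))
    return ind


ct_score = 0.05   # threshold from the cfg listed at the start of these notes
scores = [[0.9, 0.4, 0.06, 0.01], [0.5, 0.02, 0.01, 0.0], [0.3, 0.2, 0.04, 0.0]]
ind = kept_index(scores, ct_score)
diamond = box_to_diamond((10.0, 20.0, 30.0, 60.0))
```

With these scores the index vector is exactly the [0,0,0,1,2,2] example from the docstring: three detections kept in the first image, one in the second, two in the third.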

Test-time initialization of the contour evolution: an octagonal contour is built from ex.

    def prepare_testing_evolve(self, output, h, w):
        """
        Build an octagon from the extreme points and sample it to get the contour
        i_py ([>ct_score,128,2]), which is stored in output['it_py']
        """
        ex = output['ex']  # [>ct_score,4,2]
        ex[..., 0] = torch.clamp(ex[..., 0], min=0, max=w - 1)  # clamp out-of-bounds coordinates
        ex[..., 1] = torch.clamp(ex[..., 1], min=0, max=h - 1)
        evolve = snake_gcn_utils.prepare_testing_evolve(ex)
        output.update({'it_py': evolve['i_it_py']})
        return evolve

    # snake_gcn_utils.prepare_testing_evolve
    def prepare_testing_evolve(ex):
        if len(ex) == 0:
            i_it_pys = torch.zeros([0, snake_config.poly_num, 2]).to(ex)
            c_it_pys = torch.zeros_like(i_it_pys)
        else:
            i_it_pys = snake_decode.get_octagon(ex[None])  # octagon from the extreme points [1,>ct_score,12,2]
            i_it_pys = uniform_upsample(i_it_pys, snake_config.poly_num)[0]  # [>ct_score,128,2]
            c_it_pys = img_poly_to_can_poly(i_it_pys)
        evolve = {'i_it_py': i_it_pys, 'c_it_py': c_it_pys}
        return evolve
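uniform_upsample is not shown in these notes, so as a rough mental model here is a simplified substitute that places n points at equal arc-length spacing along a closed polygon. It is not the repository's implementation, which may treat the original vertices differently; resample_closed is a hypothetical name, and the sketch assumes the polygon has no zero-length edges.

```python
import math


def resample_closed(poly, n):
    """Place n points at equal arc-length spacing along a closed polygon.

    Simplified stand-in for illustration; assumes every edge has non-zero length.
    """
    edges = list(zip(poly, poly[1:] + poly[:1]))          # last vertex connects to first
    lengths = [math.hypot(b[0] - a[0], b[1] - a[1]) for a, b in edges]
    perimeter = sum(lengths)
    out, edge_id, walked = [], 0, 0.0
    for k in range(n):
        target = perimeter * k / n
        # advance to the edge that contains arc-length position `target`
        while edge_id < len(edges) - 1 and walked + lengths[edge_id] <= target:
            walked += lengths[edge_id]
            edge_id += 1
        (ax, ay), (bx, by) = edges[edge_id]
        t = (target - walked) / lengths[edge_id]
        out.append((ax + t * (bx - ax), ay + t * (by - ay)))
    return out


square = [(0.0, 0.0), (4.0, 0.0), (4.0, 4.0), (0.0, 4.0)]
pts = resample_closed(square, 8)
```

On the 4 x 4 square this yields the four corners plus the four edge midpoints. In the notes' setting the same idea turns the 12-vertex octagon into the 128-point contour that evolve_poly refines.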

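4. Summary

  • The project relies heavily on high-level wrappers. They make the code easy to call but harder to read, although I did learn a lot of encapsulation techniques from them.
  • The project's environment is fairly old, so it has to be adjusted when moving to a local machine; it took me a good deal of searching to get the CUDA extensions compiled.
  • The project uses DLA34 as the feature extraction network. I had never worked with DLA before and found it daunting at first, since the code is written quite abstractly; understanding the backbone took me about a week. The DLA family is interesting, and worth a deeper look if you are curious.
  • The dataset processing and the contour evolution are designed to work together and need to be studied side by side.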
