This article is part 1 of an MMDetection tutorial: installing and configuring MMDetection, then running single-image inference with Faster R-CNN and RPN.
1. Installing and configuring MMDetection
# Create a new conda environment first if needed; here we list the existing ones
!conda env list
# conda environments:
#
/home/snnu/anaconda3
/home/snnu/anaconda3/envs/jxl
/home/snnu/anaconda3/envs/torch18
/home/snnu/anaconda3/envs/zt
base /home/snnu/miniconda3
mmcls /home/snnu/miniconda3/envs/mmcls
mmdet /home/snnu/miniconda3/envs/mmdet
mmdet17 * /home/snnu/miniconda3/envs/mmdet17
mmpose /home/snnu/miniconda3/envs/mmpose
paddleVedio /home/snnu/miniconda3/envs/paddleVedio
tensorflow /home/snnu/miniconda3/envs/tensorflow
torch18 /home/snnu/miniconda3/envs/torch18
!pip list
Package Version
------------------- ---------
argon2-cffi 21.1.0
attrs 21.2.0
backcall 0.2.0
bleach 4.1.0
certifi 2021.10.8
cffi 1.15.0
debugpy 1.5.1
decorator 5.1.0
defusedxml 0.7.1
entrypoints 0.3
importlib-resources 5.4.0
ipykernel 6.5.0
ipython 7.29.0
ipython-genutils 0.2.0
ipywidgets 7.6.5
jedi 0.18.0
Jinja2 3.0.3
jsonschema 4.2.1
jupyter 1.0.0
jupyter-client 7.0.6
jupyter-console 6.4.0
jupyter-core 4.9.1
jupyterlab-pygments 0.1.2
jupyterlab-widgets 1.0.2
MarkupSafe 2.0.1
matplotlib-inline 0.1.3
mistune 0.8.4
nbclient 0.5.8
nbconvert 6.3.0
nbformat 5.1.3
nest-asyncio 1.5.1
notebook 6.4.5
packaging 21.2
pandocfilters 1.5.0
parso 0.8.2
pexpect 4.8.0
pickleshare 0.7.5
pip 21.2.4
prometheus-client 0.12.0
prompt-toolkit 3.0.22
ptyprocess 0.7.0
pycparser 2.21
Pygments 2.10.0
pyparsing 2.4.7
pyrsistent 0.18.0
python-dateutil 2.8.2
pyzmq 22.3.0
qtconsole 5.2.0
QtPy 1.11.2
Send2Trash 1.8.0
setuptools 58.0.4
six 1.16.0
terminado 0.12.1
testpath 0.5.0
tornado 6.1
traitlets 5.1.1
wcwidth 0.2.5
webencodings 0.5.1
wheel 0.37.0
widgetsnbextension 3.5.2
zipp 3.6.0
# Install PyTorch (torchvision 0.9.1 pulls in the matching torch 1.8.1)
!pip install torchvision==0.9.1
Collecting torchvision==0.9.1
  Using cached torchvision-0.9.1-cp38-cp38-manylinux1_x86_64.whl (17.4 MB)
Collecting pillow>=4.1.1
  Using cached Pillow-8.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
Collecting numpy
  Using cached numpy-1.21.4-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (15.7 MB)
Collecting torch==1.8.1
  Using cached torch-1.8.1-cp38-cp38-manylinux1_x86_64.whl (804.1 MB)
Collecting typing-extensions
  Downloading typing_extensions-4.0.0-py3-none-any.whl (22 kB)
Installing collected packages: typing-extensions, numpy, torch, pillow, torchvision
Successfully installed numpy-1.21.4 pillow-8.4.0 torch-1.8.1 torchvision-0.9.1 typing-extensions-4.0.0
import torch
print(torch.__version__)
torch.cuda.is_available()
1.8.1+cu102
True
# Install mmcv-full; make sure the CUDA and PyTorch versions in the URL match your environment
# !pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu102/torch1.8.0/index.html  # may need sudo
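The index URL in the command above encodes the CUDA build and the PyTorch version. As a minimal sketch, the hypothetical helper `mmcv_index_url` below rebuilds that URL from a `torch.__version__` string such as `1.8.1+cu102`. It assumes the URL pattern seen in that one command also holds for other versions, and that the patch version is set to 0 (as above, where torch 1.8.1 uses the torch1.8.0 index). Check the MMCV installation guide before relying on it.

```python
def mmcv_index_url(torch_version: str) -> str:
    """Build an mmcv-full wheel index URL from a torch version string.

    Hypothetical helper: the URL pattern is inferred from the single
    install command in this tutorial and is an assumption elsewhere.
    """
    version, _, build = torch_version.partition('+')
    if not build:
        # Without a '+cuXXX' suffix we cannot tell which build is installed.
        raise ValueError(f'no build suffix in {torch_version!r}')
    major, minor = version.split('.')[:2]
    return (
        'https://download.openmmlab.com/mmcv/dist/'
        f'{build}/torch{major}.{minor}.0/index.html'
    )


print(mmcv_index_url('1.8.1+cu102'))
```

For the environment used here this prints the same URL as the install command above.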
# Download the MMDetection source (the environment check below reports version 2.18.0)
# !git clone https://github.com/open-mmlab/mmdetection.git
%cd mmdetection/
/home/snnu/chenkequan/notebooke/mmdet/mmdetection
# Install MMDetection and its dependencies in editable mode
# !pip install -e .
# Check the environment
# Check torch and the GPU
import torch, torchvision
print(torch.__version__, torch.cuda.is_available())
# Check mmdetection
import mmdet
print(mmdet.__version__)
# Check mmcv
from mmcv.ops import get_compiling_cuda_version, get_compiler_version
print(get_compiler_version())
print(get_compiling_cuda_version())
1.8.1+cu102 True
2.18.0
GCC 7.3
10.2
2. Single-image inference with the Faster R-CNN model
from PIL import Image
Image.open('demo/demo.jpg')
# Download the pretrained Faster R-CNN model into the checkpoints folder
!mkdir checkpoints
!wget -c https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth -O checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth
--2021-11-15 11:20:56-- https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth
Resolving download.openmmlab.com (download.openmmlab.com)... 47.108.5.70
Connecting to download.openmmlab.com (download.openmmlab.com)|47.108.5.70|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 167287506 (160M) [application/octet-stream]
Saving to: 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
checkpoints/faster_ 100%[===================>] 159.54M 10.9MB/s in 15s
2021-11-15 11:21:12 (10.3 MB/s) - 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth' saved [167287506/167287506]
from mmdet.apis import inference_detector, init_detector, show_result_pyplot
# Set the device
device1 = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Select the config file that matches the model
config = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
# Select the downloaded checkpoint
checkpoint = 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
# Initialise the model
model = init_detector(config, checkpoint, device=device1)
load checkpoint from local path: checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth
# Run inference on a single image
img = 'demo/demo.jpg'
result = inference_detector(model,img)
/home/snnu/chenkequan/notebooke/mmdet/mmdetection/mmdet/datasets/utils.py:65: UserWarning: "ImageToTensor" pipeline is replaced by "DefaultFormatBundle" for batch inference. It is recommended to manually replace it in the test data pipeline in your config file.
  warnings.warn(
# Draw the result on the original image
show_result_pyplot(model,img,result,score_thr=0.8)
result
[array([[3.75348572e+02, 1.19171005e+02, 3.81950867e+02, 1.34460617e+02,1.35454655e-01],[5.32362000e+02, 1.09554726e+02, 5.40526550e+02, 1.25222633e+02,8.88788104e-02],[3.61124298e+02, 1.09049202e+02, 3.68625610e+02, 1.22483063e+02,7.20716342e-02]], dtype=float32),array([], shape=(0, 5), dtype=float32),array([[6.09650024e+02, 1.13805901e+02, 6.34511658e+02, 1.36951904e+02,9.88766015e-01],[4.81773712e+02, 1.10480995e+02, 5.22459717e+02, 1.30407104e+02,9.87157285e-01],[1.01821303e+00, 1.12144722e+02, 6.04374390e+01, 1.44173752e+02,9.83206093e-01],[2.94623749e+02, 1.17035233e+02, 3.78022675e+02, 1.50550873e+02,9.71326888e-01],[3.96328979e+02, 1.11203323e+02, 4.32490540e+02, 1.32729263e+02,9.67802048e-01],[5.90976318e+02, 1.10802658e+02, 6.15401917e+02, 1.26493553e+02,9.59414959e-01],[2.67582001e+02, 1.05686005e+02, 3.28818756e+02, 1.28226547e+02,9.59253430e-01],[1.66856735e+02, 1.08006607e+02, 2.19100693e+02, 1.40194809e+02,9.56841230e-01],[1.89769592e+02, 1.09801109e+02, 3.00310822e+02, 1.53781891e+02,9.51012254e-01],[4.29822510e+02, 1.05655380e+02, 4.82741547e+02, 1.32376724e+02,9.45850194e-01],[5.55000916e+02, 1.09784981e+02, 5.92761780e+02, 1.27808495e+02,9.43992376e-01],[5.96790390e+01, 9.31827927e+01, 8.34545517e+01, 1.06242905e+02,9.33143973e-01],[9.78446579e+01, 8.96542969e+01, 1.18172356e+02, 1.01011108e+02,8.66324604e-01],[1.43899002e+02, 9.61869888e+01, 1.64599808e+02, 1.04979256e+02,8.26784194e-01],[8.55894241e+01, 8.99445801e+01, 9.88920746e+01, 9.85285416e+01,7.53480613e-01],[9.78282623e+01, 9.07443695e+01, 1.10298058e+02, 9.97373276e+01,7.16600537e-01],[2.23579224e+02, 9.85184631e+01, 2.49845108e+02, 1.07509857e+02,6.00782990e-01],[1.68928635e+02, 9.59468994e+01, 1.82843445e+02, 1.05694962e+02,5.91999710e-01],[1.35021347e+02, 9.08739395e+01, 1.50607025e+02, 1.02798874e+02,5.54030061e-01],[0.00000000e+00, 1.11521957e+02, 1.45326672e+01, 1.25850281e+02,5.43520629e-01],[5.53896606e+02, 1.16170540e+02, 5.62602295e+02, 
1.26390923e+02,4.76758868e-01],[3.75809753e+02, 1.19579056e+02, 3.82376495e+02, 1.32113892e+02,4.61191684e-01],[1.37924118e+02, 9.37975311e+01, 1.54497177e+02, 1.04659683e+02,4.00998443e-01],[5.55009033e+02, 1.10952698e+02, 5.74925659e+02, 1.26912033e+02,3.43850315e-01],[5.54043152e+02, 1.00959076e+02, 5.61297913e+02, 1.10927711e+02,2.87964016e-01],[6.14741028e+02, 1.01987068e+02, 6.35481628e+02, 1.12593704e+02,2.61201501e-01],[5.70760315e+02, 1.09679382e+02, 5.90286133e+02, 1.27248878e+02,2.58405149e-01],[4.78544116e-01, 1.11568169e+02, 2.25040913e+01, 1.42623535e+02,2.56050646e-01],[3.75093140e+02, 1.11696442e+02, 4.20536804e+02, 1.33691055e+02,2.55963594e-01],[2.62747253e+02, 1.07565620e+02, 3.26765961e+02, 1.43925293e+02,2.09969178e-01],[7.91312561e+01, 9.03788834e+01, 1.00247879e+02, 1.01080872e+02,2.03961387e-01],[6.09313477e+02, 1.13308517e+02, 6.25961975e+02, 1.25342506e+02,1.97424650e-01],[1.35304840e+02, 9.23771439e+01, 1.64080185e+02, 1.04992455e+02,1.49972796e-01],[6.73540573e+01, 8.85008087e+01, 8.29853516e+01, 9.73942108e+01,1.48384452e-01],[5.40852417e+02, 1.13848946e+02, 5.61855530e+02, 1.26198776e+02,1.47629425e-01],[3.51735046e+02, 1.09432648e+02, 4.39310089e+02, 1.34819733e+02,1.41735569e-01],[9.63179016e+01, 8.98780594e+01, 1.53287766e+02, 1.01776367e+02,1.32708758e-01],[4.54495049e+01, 1.17444977e+02, 6.18955803e+01, 1.44275055e+02,1.25890389e-01],[6.06407532e+02, 1.12215973e+02, 6.18935669e+02, 1.24957237e+02,1.10721953e-01],[1.02152626e+02, 9.36143646e+01, 1.41081863e+02, 1.01598961e+02,8.13646615e-02],[3.98364838e+02, 1.12081459e+02, 4.09389862e+02, 1.32897766e+02,7.64547288e-02],[5.39245911e+02, 1.12394836e+02, 5.48756714e+02, 1.21964462e+02,7.32634813e-02],[6.09156555e+02, 1.04017456e+02, 6.35472107e+02, 1.26777176e+02,6.47420660e-02],[3.75895119e+00, 9.85745163e+01, 7.45848312e+01, 1.35155014e+02,6.32170066e-02],[1.68166473e+02, 9.14260483e+01, 2.20303146e+02, 1.07955681e+02,5.16179651e-02],[7.09723892e+01, 9.02684860e+01, 1.05398132e+02, 
1.03825508e+02,5.15382327e-02]], dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([[1.8820006e+02, 1.0994707e+02, 3.0047134e+02, 1.5633583e+02,9.7509347e-02],[4.2774911e+02, 1.0511559e+02, 4.8345544e+02, 1.3294328e+02,9.6881956e-02],[2.9450482e+02, 1.1764229e+02, 3.7863281e+02, 1.5046356e+02,7.4364685e-02]], dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([[2.1920888e+02, 1.7456264e+02, 4.6010876e+02, 3.7704666e+02,9.7778010e-01],[3.7206638e+02, 1.3631430e+02, 4.3219525e+02, 1.8717291e+02,4.1699946e-01]], dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([[9.13259964e+01, 1.07155769e+02, 1.06029366e+02, 1.19777306e+02,1.15152769e-01]], dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 
5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([[3.7265186e+02, 1.3614308e+02, 4.3205383e+02, 1.8844647e+02,7.7787608e-01],[2.1840410e+02, 1.7513782e+02, 4.6210764e+02, 3.6554123e+02,1.0123572e-01]], dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32),array([], shape=(0, 5), dtype=float32)]
2.1 Interpreting the result
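The result is a list of ndarrays. The list has length 80, one entry per class (80 classes in total).
Each ndarray has one row per detection and 5 columns.
Row: one detected object of that class.
Columns: [box coordinates (x1, y1, x2, y2), confidence score], where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner of the detection box.
For example, the first row of the output above, roughly [375.3, 119.2, 382.0, 134.5, 0.135], has x2 > x1 and y2 > y1, as expected for corner coordinates, and a confidence of about 0.14.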
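To make the layout concrete, here is a minimal sketch that walks such a per-class list and keeps only confident detections, similar in spirit to the `score_thr` argument used above. The function name `filter_detections` and the toy data are illustrative only; the sketch assumes each row is (x1, y1, x2, y2, score) and uses plain lists as a stand-in for the ndarrays returned by `inference_detector`.

```python
def filter_detections(result, score_thr=0.8):
    """Return (class_index, x1, y1, x2, y2, score) for rows with score >= score_thr.

    `result` is a per-class sequence; each entry holds rows of five numbers.
    Hypothetical helper for illustration, not part of MMDetection.
    """
    kept = []
    for class_index, boxes in enumerate(result):
        for x1, y1, x2, y2, score in boxes:
            if score >= score_thr:
                kept.append((class_index, float(x1), float(y1),
                             float(x2), float(y2), float(score)))
    return kept


# Toy stand-in for a 3-class result (values loosely based on the output above).
toy_result = [
    [[375.3, 119.2, 381.9, 134.5, 0.135]],   # class 0: one weak detection
    [],                                       # class 1: nothing detected
    [[609.7, 113.8, 634.5, 137.0, 0.989],     # class 2: one strong detection
     [223.6, 98.5, 249.8, 107.5, 0.601]],     # class 2: one medium detection
]

detections = filter_detections(toy_result, score_thr=0.8)
for det in detections:
    print(det)
```

With the threshold at 0.8 only the strong class-2 box survives; lowering `score_thr` lets the weaker boxes through.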
2.2 Faster R-CNN model structure
# Inspect the Faster R-CNN model structure
# (the loop variable is named child so that it does not overwrite model)
for name, child in model.named_children():
    print(name)
    for n, _ in child.named_children():
        print(f'    {n}')
backbone
    conv1
    bn1
    relu
    maxpool
    layer1
    layer2
    layer3
    layer4
neck
    lateral_convs
    fpn_convs
rpn_head
    loss_cls
    loss_bbox
    rpn_conv
    rpn_cls
    rpn_reg
roi_head
    bbox_roi_extractor
    bbox_head
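Notes
backbone: the backbone network.
neck: lateral_convs are the lateral convolutions and fpn_convs are the FPN convolutions.
rpn_head: generates proposals; rpn_cls predicts whether an anchor contains an object and rpn_reg regresses its bounding box.
roi_head: classification and bounding-box regression.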
2.3 Trying out the RPN model
Inspect the proposals extracted by the RPN.
# Download the pretrained RPN model
!wget -c https://download.openmmlab.com/mmdetection/v2.0/rpn/rpn_r50_fpn_1x_coco/rpn_r50_fpn_1x_coco_20200218-5525fa2e.pth -O checkpoints/rpn_r50_fpn_1x_coco_20200218-5525fa2e.pth
--2021-11-15 12:48:01-- https://download.openmmlab.com/mmdetection/v2.0/rpn/rpn_r50_fpn_1x_coco/rpn_r50_fpn_1x_coco_20200218-5525fa2e.pth
Resolving download.openmmlab.com (download.openmmlab.com)... 47.104.39.18
Connecting to download.openmmlab.com (download.openmmlab.com)|47.104.39.18|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 110057751 (105M) [application/octet-stream]
Saving to: 'checkpoints/rpn_r50_fpn_1x_coco_20200218-5525fa2e.pth'
checkpoints/rpn_r50 100%[===================>] 104.96M 11.3MB/s in 9.8s
2021-11-15 12:48:16 (10.7 MB/s) - 'checkpoints/rpn_r50_fpn_1x_coco_20200218-5525fa2e.pth' saved [110057751/110057751]
# Same inference procedure as above
# Set the device
device1 = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Select the config file that matches the model
config = 'configs/rpn/rpn_r50_fpn_1x_coco.py'
# Select the downloaded checkpoint
checkpoint = 'checkpoints/rpn_r50_fpn_1x_coco_20200218-5525fa2e.pth'
# Initialise the model
model = init_detector(config, checkpoint, device=device1)
load checkpoint from local path: checkpoints/rpn_r50_fpn_1x_coco_20200218-5525fa2e.pth
# Run inference on a single image
img = 'demo/demo.jpg'
rpn_result = inference_detector(model,img)
/home/snnu/chenkequan/notebooke/mmdet/mmdetection/mmdet/datasets/utils.py:65: UserWarning: "ImageToTensor" pipeline is replaced by "DefaultFormatBundle" for batch inference. It is recommended to manually replace it in the test data pipeline in your config file.
  warnings.warn(
# Install the OpenCV module
! pip install opencv-python
Requirement already satisfied: opencv-python in /home/snnu/miniconda3/envs/mmdet17/lib/python3.8/site-packages (4.5.4.58)
Requirement already satisfied: numpy>=1.17.3 in /home/snnu/miniconda3/envs/mmdet17/lib/python3.8/site-packages (from opencv-python) (1.21.4)
!pip install google
Collecting google
  Downloading google-3.0.0-py2.py3-none-any.whl (45 kB)
Collecting beautifulsoup4
  Downloading beautifulsoup4-4.10.0-py3-none-any.whl (97 kB)
Collecting soupsieve>1.2
  Downloading soupsieve-2.3.1-py3-none-any.whl (37 kB)
Installing collected packages: soupsieve, beautifulsoup4, google
Successfully installed beautifulsoup4-4.10.0 google-3.0.0 soupsieve-2.3.1
# import cv2
# from google.colab.patches import cv2_imshow  # only available inside Google Colab
model.show_result(img,rpn_result,top_k=100)
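# Draw the result on the original image
# (pass rpn_result, the output of the RPN model loaded above)
show_result_pyplot(model, img, rpn_result, score_thr=0.01)  # this ended up being displayed in the terminal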
2.4 Notes on the model config file
model = dict(
    type='FasterRCNN',  # the detector
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,  # four groups of residual blocks
        out_indices=(0, 1, 2, 3),  # output the feature maps of all four groups
        frozen_stages=1,  # which parameters are frozen: the first group is not trained during fine-tuning
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',  # the FPN module
        in_channels=[256, 512, 1024, 2048],  # channels of the input feature maps
        out_channels=256,  # every level is compressed to 256 channels
        num_outs=5),  # four feature levels plus one more, pooled from the top level, for larger proposals
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,  # 256 input channels, matching the neck output
        feat_channels=256,  # the first convolution has 256 channels
        anchor_generator=dict(  # generates anchors of different sizes
            type='AnchorGenerator',
            scales=[8],  # FPN is already multi-scale, so one scale is enough: anchor side = 8 x stride
            ratios=[0.5, 1.0, 2.0],  # aspect ratios, giving anchors of different shapes
            strides=[4, 8, 16, 32, 64]),  # image pixels per feature-map pixel; downsampling doubles at each level
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',  # encodes the bounding boxes
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        # Classification uses CrossEntropyLoss; it is a binary problem, so use_sigmoid=True can be used
        loss_cls=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),  # regression uses L1Loss
    roi_head=dict(  # the RoI head works in two steps
        type='StandardRoIHead',
        bbox_roi_extractor=dict(  # step 1: crop the RoI features out of the feature maps
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),  # RoIAlign with 7x7 output
            out_channels=256,  # 256 output channels
            featmap_strides=[4, 8, 16, 32]),  # stride of each feature map relative to the input image
        bbox_head=dict(  # step 2: the cropped features go through Shared2FCBBoxHead (shared fully connected layers)
            type='Shared2FCBBoxHead',
            in_channels=256,  # input is a 256-channel 7x7 feature map
            fc_out_channels=1024,  # two fully connected layers turn it into a 1024-dimensional feature vector
            roi_feat_size=7,
            num_classes=80,  # the classification and regression heads produce results for 80 classes
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',  # encodes the bounding boxes
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            # 80 classes here, so sigmoid is not used
            loss_cls=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(  # does not describe the model structure; it configures how modules behave during training
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',  # anchors are labelled positive or negative by their IoU with the ground truth
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',  # thousands of anchors are generated; only 256 are sampled for training
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(  # take 2000 proposals first; at most 1000 remain after NMS
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=False,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(  # no target assignment is needed at test time, so only the NMS settings remain
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100)
        # soft-nms is also supported for rcnn testing
        # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
    ))
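To see what the `anchor_generator` settings in the config amount to, the following sketch computes the anchor width and height on each FPN level from `strides`, `scales` and `ratios`. It is a simplified illustration, not MMDetection's `AnchorGenerator` itself: `anchor_sizes` is a hypothetical helper, and it assumes that a ratio means height / width and that the base anchor on each level is a square whose side equals the stride.

```python
import math


def anchor_sizes(strides, scales, ratios):
    """Return {stride: [(width, height), ...]} for each feature level.

    Assumptions (illustrative only): ratio = height / width, and the base
    anchor on a level is a square with side equal to the stride.
    """
    sizes = {}
    for stride in strides:
        level = []
        for scale in scales:
            side = stride * scale  # side of the square anchor before reshaping
            for ratio in ratios:
                height = side * math.sqrt(ratio)
                width = side / math.sqrt(ratio)  # keeps the area at side * side
                level.append((round(width, 1), round(height, 1)))
        sizes[stride] = level
    return sizes


sizes = anchor_sizes(strides=[4, 8, 16, 32, 64], scales=[8], ratios=[0.5, 1.0, 2.0])
for stride, level in sizes.items():
    print(stride, level)
```

Under these assumptions the square anchors have sides of 32, 64, 128, 256 and 512 pixels across the five levels, so the single scale of 8 still covers a wide range of object sizes.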