利用MMSegmentation微调Mask2Former模型

2023-10-20 16:59

本文主要是介绍利用MMSegmentation微调Mask2Former模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前言

  • 本文介绍了专用于语义分隔模型的pythonmmsegmentationgithub项目地址,运行环境为Kaggle notebookGPUP100
  • 针对环境配置、预训练模型推理、在西瓜数据集上微调新sota模型mask2former模型,数据说明
  • 由于西瓜数据集较小,我们最后在组织病理切片肾小球数据集上微调了mask2former模型,数据说明
  • 该教程有部分参考github项目MMSegmentation_Tutorials,项目地址

环境配置

  • 跑通代码需要openmimmmsegmentationmmenginemmdetectionmmcv环境,mmcv环境在kaggle配置比较麻烦,需要预配置包,这里我将所有预配置包都打包好了,放到了数据集frozen-packages-mmdetection中,详情页
import IPython.display as display
!pip install -U openmim!rm -rf mmsegmentation
!git clone https://github.com/open-mmlab/mmsegmentation.git
%cd mmsegmentation
!pip install -v -e .!pip install "mmdet>=3.0.0rc4"!pip install -q /kaggle/input/frozen-packages-mmdetection/mmcv-2.0.1-cp310-cp310-linux_x86_64.whl!pip install wandb
display.clear_output()
  • 实测运行上述代码,在kaggle中可以达到运行项目需求,无报错(2023年7月13日)。
  • 导入常用基础包
import io
import os
import cv2
import glob
import time
import torch
import shutil
import mmcv
import wandb
import random
import mmengine
import numpy as np
from PIL import Image
from tqdm import tqdm
from mmengine import Configimport matplotlib.pyplot as plt
%matplotlib inlinefrom mmseg.datasets import cityscapes
from mmseg.utils import register_all_modules
register_all_modules()from mmseg.datasets import CityscapesDataset
from mmengine.model.utils import revert_sync_batchnorm
from mmseg.apis import init_model, inference_model, show_result_pyplot# 忽略警告
import warnings
warnings.filterwarnings('ignore')display.clear_output()
  • 创建文件夹,用于放置数据集、模型预训练权重和模型推理输出
# 创建 checkpoint 文件夹,用于存放预训练模型权重文件
os.mkdir('checkpoint')# 创建 outputs 文件夹,用于存放预测结果
os.mkdir('outputs')# 创建 data 文件夹,用于存放图片和视频素材
os.mkdir('data')
  • 分别下载pspnet、segformer、mask2former在cityscapes上的预训练权重,并保存在checkpoint文件夹中
# 从Model Zoo预训练模型,下载并保存在 checkpoint 文件夹中
!wget https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth -P checkpoint
!wget https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_8x1_1024x1024_160k_cityscapes/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth -P checkpoint
!wget https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth -P checkpoint
display.clear_output()
  • 下载一些测试模型用的图片以及视频,并存放到data文件夹中。
# 伦敦街景图片
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220713-mmdetection/images/street_uk.jpeg -P data# 上海驾车街景视频,视频来源:https://www.youtube.com/watch?v=ll8TgCZ0plk
!wget https://zihao-download.obs.cn-east-3.myhuaweicloud.com/detectron2/traffic.mp4 -P data# 街拍视频,2022年3月30日
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220713-mmdetection/images/street_20220330_174028.mp4 -P data
display.clear_output()

图片推理

命令行推理

  • 使用命令行对图片进行推理,并使用PIL对结果进行可视化
  • 分别使用了pspnet模型和segformer模型进行推理
# pspnet模型
!python demo/image_demo.py \data/street_uk.jpeg \configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py \checkpoint/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth \--out-file outputs/B1_uk_pspnet.jpg \--device cuda:0 \--opacity 0.5display.clear_output()
Image.open('outputs/B1_uk_pspnet.jpg')

请添加图片描述

# segformer模型
!python demo/image_demo.py \data/street_uk.jpeg \configs/segformer/segformer_mit-b5_8xb1-160k_cityscapes-1024x1024.py \checkpoint/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth \--out-file outputs/B1_uk_segformer.jpg \--device cuda:0 \--opacity 0.5
display.clear_output()
Image.open('outputs/B1_uk_segformer.jpg')

请添加图片描述

  • 可以看到其实segformer的效果比pspnet模型效果要好,基本上能将不同物体分割开。

API推理

  • 使用mmsegmentation的Python API进行图片推理
  • 使用mask2former模型推理,并利用matplotlib对结果进行可视化
img_path = 'data/street_uk.jpeg'
img_pil = Image.open(img_path)
# 模型 config 配置文件
config_file = 'configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py'# 模型 checkpoint 权重文件
checkpoint_file = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'model = init_model(config_file, checkpoint_file, device='cuda:0')if not torch.cuda.is_available():model = revert_sync_batchnorm(model)result = inference_model(model, img_path)
pred_mask = result.pred_sem_seg.data[0].detach().cpu().numpy()display.clear_output()
img_bgr = cv2.imread(img_path)
plt.figure(figsize=(14, 8))
plt.imshow(img_bgr[:,:,::-1])
plt.imshow(pred_mask, alpha=0.55) # alpha 高亮区域透明度,越小越接近原图
plt.axis('off')
plt.savefig('outputs/B2-1.jpg')
plt.show()

请添加图片描述

  • mask2former作为sota模型,效果确实非常棒!

视频推理

命令行推理

  • 不推荐,速度很慢
!python demo/video_demo.py \data/street_20220330_174028.mp4 \configs/segformer/segformer_mit-b5_8xb1-160k_cityscapes-1024x1024.py \checkpoint/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth \--device cuda:0 \--output-file outputs/B3_video.mp4 \--opacity 0.5

API推理

  • mask2former模型使用API对视频进行推理
# 模型 config 配置文件
config_file = 'configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py'# 模型 checkpoint 权重文件
checkpoint_file = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'model = init_model(config_file, checkpoint_file, device='cuda:0')if not torch.cuda.is_available():model = revert_sync_batchnorm(model)display.clear_output()input_video = 'data/street_20220330_174028.mp4'temp_out_dir = time.strftime('%Y%m%d%H%M%S')
os.mkdir(temp_out_dir)
print('创建临时文件夹 {} 用于存放每帧预测结果'.format(temp_out_dir))# 获取 Cityscapes 街景数据集 类别名和调色板
classes = cityscapes.CityscapesDataset.METAINFO['classes']
palette = cityscapes.CityscapesDataset.METAINFO['palette']def pridict_single_frame(img, opacity=0.2):result = inference_model(model, img)# 将分割图按调色板染色seg_map = np.array(result.pred_sem_seg.data[0].detach().cpu().numpy()).astype('uint8')seg_img = Image.fromarray(seg_map).convert('P')seg_img.putpalette(np.array(palette, dtype=np.uint8))show_img = (np.array(seg_img.convert('RGB')))*(1-opacity) + img*opacityreturn show_img# 读入待预测视频
imgs = mmcv.VideoReader(input_video)prog_bar = mmengine.ProgressBar(len(imgs))# 对视频逐帧处理
for frame_id, img in enumerate(imgs):## 处理单帧画面show_img = pridict_single_frame(img, opacity=0.15)temp_path = f'{temp_out_dir}/{frame_id:06d}.jpg' # 保存语义分割预测结果图像至临时文件夹cv2.imwrite(temp_path, show_img)prog_bar.update() # 更新进度条# 把每一帧串成视频文件
mmcv.frames2video(temp_out_dir, 'outputs/B3_video.mp4', fps=imgs.fps, fourcc='mp4v')shutil.rmtree(temp_out_dir) # 删除存放每帧画面的临时文件夹
print('删除临时文件夹', temp_out_dir)

小样本数据集微调mask2former

  • 在西瓜语义分隔数据集上对模型进行微调

下载数据集

!rm -rf Watermelon87_Semantic_Seg_Mask.zip Watermelon87_Semantic_Seg_Mask!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/dataset/watermelon/Watermelon87_Semantic_Seg_Mask.zip!unzip Watermelon87_Semantic_Seg_Mask.zip >> /dev/null # 解压!rm -rf Watermelon87_Semantic_Seg_Mask.zip # 删除压缩包!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/watermelon_test1.jpg -P data!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/video_watermelon_2.mp4 -P data!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/video_watermelon_3.mov -P data# 删除系统自动生成的多余文件
!find . -iname '__MACOSX'
!find . -iname '.DS_Store'
!find . -iname '.ipynb_checkpoints'# 删除多余文件
!for i in `find . -iname '__MACOSX'`; do rm -rf $i;done
!for i in `find . -iname '.DS_Store'`; do rm -rf $i;done
!for i in `find . -iname '.ipynb_checkpoints'`; do rm -rf $i;done# 验证多余文件已删除
!find . -iname '__MACOSX'
!find . -iname '.DS_Store'
!find . -iname '.ipynb_checkpoints'display.clear_output()

可视化探索语义分割数据集

  • 可视化语义信息
# 指定单张图像路径
img_path = 'Watermelon87_Semantic_Seg_Mask/img_dir/train/04_35-2.jpg'
mask_path = 'Watermelon87_Semantic_Seg_Mask/ann_dir/train/04_35-2.png'img = cv2.imread(img_path)
mask = cv2.imread(mask_path)# 可视化原图叠加
plt.figure(figsize=(8, 8))
plt.imshow(img[:,:,::-1])
plt.imshow(mask[:,:,0], alpha=0.6) # alpha 高亮区域透明度,越小越接近原图
plt.axis('off')
plt.show()

请添加图片描述

定义Dataset和Pipeline

  • Dataset部分,可以设定数值对应的具体类别,以及不同类别的标注颜色。图像格式,是否忽略类别0
  • Pipeline部分,可以设定训练、验证的数据处理步骤。以及规定图像裁剪尺寸
custom_dataset = """
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset@DATASETS.register_module()
class MyCustomDataset(BaseSegDataset):# 类别和对应的 RGB配色METAINFO = {'classes':['background', 'red', 'green', 'white', 'seed-black', 'seed-white'],'palette':[[127,127,127], [200,0,0], [0,200,0], [144,238,144], [30,30,30], [251,189,8]]}# 指定图像扩展名、标注扩展名def __init__(self,seg_map_suffix='.png',   # 标注mask图像的格式reduce_zero_label=False, # 类别ID为0的类别是否需要除去**kwargs) -> None:super().__init__(seg_map_suffix=seg_map_suffix,reduce_zero_label=reduce_zero_label,**kwargs)
"""with io.open('mmseg/datasets/MyCustomDataset.py', 'w', encoding='utf-8') as f:f.write(custom_dataset)
  • custom_dataset加入__init__.py文件
custom_init = """
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .ade import ADE20KDataset
from .basesegdataset import BaseSegDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import MultiImageMixDataset
from .decathlon import DecathlonDataset
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .lip import LIPDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .synapse import SynapseDataset
from .MyCustomDataset import MyCustomDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,BioMedical3DRandomCrop, BioMedical3DRandomFlip,BioMedicalGaussianBlur, BioMedicalGaussianNoise,BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,LoadBiomedicalAnnotation, LoadBiomedicalData,LoadBiomedicalImageFromFile, LoadImageFromNDArray,PackSegInputs, PhotoMetricDistortion, RandomCrop,RandomCutOut, RandomMosaic, RandomRotate,RandomRotFlip, Rerange, ResizeShortestEdge,ResizeToMultiple, RGB2Gray, SegRescale)
from .voc import PascalVOCDataset# yapf: enable
__all__ = ['BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip','CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset','PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset','DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset','NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset','MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset','LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion','RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray','RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple','LoadImageFromNDArray', 'LoadBiomedicalImageFromFile','LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge','DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge','BioMedicalGaussianNoise', 'BioMedicalGaussianBlur','BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip','SynapseDataset', 'MyCustomDataset'
]"""with io.open('mmseg/datasets/__init__.py', 'w', encoding='utf-8') as f:f.write(custom_init)
  • 定义数据集预处理通道
custom_pipeline = """
# 数据集路径
dataset_type = 'MyCustomDataset' # 数据集类名
data_root = 'Watermelon87_Semantic_Seg_Mask/' # 数据集路径(相对于mmsegmentation主目录)# 输入模型的图像裁剪尺寸,一般是 128 的倍数,越小显存开销越少
crop_size = (640, 640)# 训练预处理
train_pipeline = [dict(type='LoadImageFromFile'),dict(type='LoadAnnotations'),dict(type='RandomResize',scale=(2048, 1024),ratio_range=(0.5, 2.0),keep_ratio=True),dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),dict(type='RandomFlip', prob=0.5),dict(type='PhotoMetricDistortion'),dict(type='PackSegInputs')
]# 测试预处理
test_pipeline = [dict(type='LoadImageFromFile'),dict(type='Resize', scale=(2048, 1024), keep_ratio=True),dict(type='LoadAnnotations'),dict(type='PackSegInputs')
]# TTA后处理
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),dict(type='TestTimeAug',transforms=[[dict(type='Resize', scale_factor=r, keep_ratio=True)for r in img_ratios],[dict(type='RandomFlip', prob=0., direction='horizontal'),dict(type='RandomFlip', prob=1., direction='horizontal')], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]])
]# 训练 Dataloader
train_dataloader = dict(batch_size=2,num_workers=4,persistent_workers=True,sampler=dict(type='InfiniteSampler', shuffle=True),dataset=dict(type=dataset_type,data_root=data_root,data_prefix=dict(img_path='img_dir/train', seg_map_path='ann_dir/train'),pipeline=train_pipeline))# 验证 Dataloader
val_dataloader = dict(batch_size=1,num_workers=4,persistent_workers=True,sampler=dict(type='DefaultSampler', shuffle=False),dataset=dict(type=dataset_type,data_root=data_root,data_prefix=dict(img_path='img_dir/val', seg_map_path='ann_dir/val'),pipeline=test_pipeline))# 测试 Dataloader
test_dataloader = val_dataloader# 验证 Evaluator
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice', 'mFscore'])# 测试 Evaluator
test_evaluator = val_evaluator
"""with io.open('configs/_base_/datasets/custom_pipeline.py', 'w', encoding='utf-8') as f:f.write(custom_pipeline)

修改配置文件

  • 主要修改类别个数、预训练权重路径、初始化图片尺寸(一般为128的整数倍)、batch_size、缩放学习率(修改的比例是 base_lr_default * (your_bs / default_bs))、更改学习率衰减策略
  • 关于学习率:主要修改optimizer中的lr,不用修改optim_wrapper
  • 冻结模型的骨干网络,对mask2former来说可以加快训练
cfg = Config.fromfile('configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py')
dataset_cfg = Config.fromfile('configs/_base_/datasets/custom_pipeline.py')
cfg.merge_from_dict(dataset_cfg)
# 类别个数
NUM_CLASS = 6
# 单卡训练时,需要把 SyncBN 改成 BN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.crop_size = (640, 640)
cfg.model.data_preprocessor.size = cfg.crop_size# 预训练模型权重
cfg.load_from = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'# 模型 decode/auxiliary 输出头,指定为类别个数
cfg.model.decode_head.num_classes = NUM_CLASS
cfg.model.decode_head.loss_cls.class_weight = [1.0] * NUM_CLASS + [0.1]
cfg.model.backbone.frozen_stages = 4# 训练 Batch Size
cfg.train_dataloader.batch_size = 2
cfg.test_dataloader = cfg.val_dataloadercfg.optimizer.lr = cfg.optimizer.lr / 8# 结果保存目录
cfg.work_dir = './work_dirs'cfg.train_cfg.max_iters = 4000 # 训练迭代次数
cfg.train_cfg.val_interval = 50 # 评估模型间隔
cfg.default_hooks.logger.interval = 50 # 日志记录间隔
cfg.default_hooks.checkpoint.interval = 50 # 模型权重保存间隔
cfg.default_hooks.checkpoint.max_keep_ckpts = 2 # 最多保留几个模型权重
cfg.default_hooks.checkpoint.save_best = 'mIoU' # 保留指标最高的模型权重cfg.param_scheduler[0].end = cfg.train_cfg.max_iters
# 随机数种子
cfg['randomness'] = dict(seed=0)cfg.visualizer.vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
  • 保存配置文件
cfg.dump('custom_mask2former.py')
  • 开始训练
!python tools/train.py custom_mask2former.py
  • 选取最优模型,测试模型精度
# 取最佳模型权重
best_pth = glob.glob('work_dirs/best_mIoU*.pth')[0]
# 测试精度
!python tools/test.py custom_mask2former.py '{best_pth}'
  • 输出:
+------------+-------+-------+-------+--------+-----------+--------+
|   Class    |  IoU  |  Acc  |  Dice | Fscore | Precision | Recall |
+------------+-------+-------+-------+--------+-----------+--------+
| background | 98.55 | 99.12 | 99.27 | 99.27  |   99.42   | 99.12  |
|    red     | 96.54 | 98.83 | 98.24 | 98.24  |   97.65   | 98.83  |
|   green    | 94.37 | 96.08 |  97.1 |  97.1  |   98.14   | 96.08  |
|   white    | 85.96 | 92.67 | 92.45 | 92.45  |   92.24   | 92.67  |
| seed-black | 81.98 | 90.87 |  90.1 |  90.1  |   89.34   | 90.87  |
| seed-white | 65.57 | 69.98 | 79.21 | 79.21  |   91.24   | 69.98  |
+------------+-------+-------+-------+--------+-----------+--------+

可视化训练指标

在这里插入图片描述

肾小球数据集微调模型

  • 在单类别数据集(组织病理切片肾小球)上微调mask2former模型
  • 首先清空工作目录、data文件夹和outputs文件
# 清空工作目录
!rm -r work_dirs/*
# 清空data文件夹
!rm -r data/*
# 清空outputs文件夹
!rm -r outputs/*

可视化探索语义分割数据集

# 指定图像和标注路径
PATH_IMAGE = '/kaggle/input/glomeruli-hubmap-external-1024x1024/images_1024'
PATH_MASKS = '/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024'mask = cv2.imread('/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024/VUHSK_1762_29.png')
# 查看类别
np.unique(mask)
  • 输出
array([0, 1], dtype=uint8)
  • 可视化语义分割信息
# n行n列可视化
n = 5# 标注区域透明度,透明度越小,越接近原图
opacity = 0.65fig, axes = plt.subplots(nrows=n, ncols=n, sharex=True, figsize=(12,12))for i, file_name in enumerate(os.listdir(PATH_IMAGE)[:n**2]):# 载入图像和标注img_path = os.path.join(PATH_IMAGE, file_name)mask_path = os.path.join(PATH_MASKS, file_name.split('.')[0]+'.png')img = cv2.imread(img_path)mask = cv2.imread(mask_path)# 可视化axes[i//n, i%n].imshow(img[:,:,::-1])axes[i//n, i%n].imshow(mask[:,:,0], alpha=opacity)axes[i//n, i%n].axis('off') # 关闭坐标轴显示
fig.suptitle('Image and Semantic Label', fontsize=20)
plt.tight_layout()
plt.savefig('outputs/C2-1.jpg')
plt.show()

请添加图片描述

分割训练集与测试集

  • 新建各类训练、验证文件夹
# 新建图片训练、验证文件夹
!mkdir -p data/images/train
!mkdir -p data/images/val# 新建mask训练、验证文件夹
!mkdir -p data/masks/train
!mkdir -p data/masks/val
  • 随机打乱数据,并按照90%训练集、10%测试集分割
def copy_file(og_images, og_masks, tr_images, tr_masks, thor):# 获取源文件夹中的所有文件名file_names = os.listdir(og_images)# 随机打乱文件名列表random.shuffle(file_names)# 计算分割点split_index = int(thor * len(file_names))# 复制训练集文件for file_name in file_names[:split_index]:og_image = os.path.join(og_images, file_name)og_mask = os.path.join(og_masks, file_name)tr_image = os.path.join(tr_images, 'train', file_name)tr_mask = os.path.join(tr_masks, 'train', file_name)shutil.copyfile(og_image, tr_image)shutil.copyfile(og_mask, tr_mask)# 复制验证集文件for file_name in file_names[split_index:]:og_image = os.path.join(og_images, file_name)og_mask = os.path.join(og_masks, file_name)tr_image = os.path.join(tr_images, 'val', file_name)tr_mask = os.path.join(tr_masks, 'val', file_name)shutil.copyfile(og_image, tr_image)shutil.copyfile(og_mask, tr_mask)og_images = '/kaggle/input/glomeruli-hubmap-external-1024x1024/images_1024'
og_masks = '/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024'tr_images = 'data/images'
tr_masks = 'data/masks'copy_file(og_images, og_masks, tr_images, tr_masks, 0.9)

重新定义Dataset和Pipeline

  • 主要是修改类别及对应RGB配色
  • 以及dataload的路径信息
custom_dataset = """
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset@DATASETS.register_module()
class MyCustomDataset(BaseSegDataset):# 类别和对应的RGB配色METAINFO = {'classes':['normal','sclerotic'],'palette':[[127,127,127],[251,189,8]]}# 指定图像扩展名、标注扩展名def __init__(self,img_suffix='.png',seg_map_suffix='.png',   # 标注mask图像的格式reduce_zero_label=False, # 类别ID为0的类别是否需要除去**kwargs) -> None:super().__init__(img_suffix=img_suffix,seg_map_suffix=seg_map_suffix,reduce_zero_label=reduce_zero_label,**kwargs)
"""with io.open('mmseg/datasets/MyCustomDataset.py', 'w', encoding='utf-8') as f:f.write(custom_dataset)
custom_init = """
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .ade import ADE20KDataset
from .basesegdataset import BaseSegDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import MultiImageMixDataset
from .decathlon import DecathlonDataset
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .lip import LIPDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .synapse import SynapseDataset
from .MyCustomDataset import MyCustomDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,BioMedical3DRandomCrop, BioMedical3DRandomFlip,BioMedicalGaussianBlur, BioMedicalGaussianNoise,BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,LoadBiomedicalAnnotation, LoadBiomedicalData,LoadBiomedicalImageFromFile, LoadImageFromNDArray,PackSegInputs, PhotoMetricDistortion, RandomCrop,RandomCutOut, RandomMosaic, RandomRotate,RandomRotFlip, Rerange, ResizeShortestEdge,ResizeToMultiple, RGB2Gray, SegRescale)
from .voc import PascalVOCDataset# yapf: enable
__all__ = ['BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip','CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset','PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset','DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset','NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset','MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset','LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion','RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray','RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple','LoadImageFromNDArray', 'LoadBiomedicalImageFromFile','LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge','DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge','BioMedicalGaussianNoise', 'BioMedicalGaussianBlur','BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip','SynapseDataset', 'MyCustomDataset'
]"""with io.open('mmseg/datasets/__init__.py', 'w', encoding='utf-8') as f:f.write(custom_init)
  • 定义数据预处理管道
custom_pipeline = """
# 数据集路径
dataset_type = 'MyCustomDataset' # 数据集类名
data_root = 'data/' # 数据集路径(相对于mmsegmentation主目录)# 输入模型的图像裁剪尺寸,一般是 128 的倍数,越小显存开销越少
crop_size = (640, 640)# 训练预处理
train_pipeline = [dict(type='LoadImageFromFile'),dict(type='LoadAnnotations'),dict(type='RandomResize',scale=(2048, 1024),ratio_range=(0.5, 2.0),keep_ratio=True),dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),dict(type='RandomFlip', prob=0.5),dict(type='PhotoMetricDistortion'),dict(type='PackSegInputs')
]# 测试预处理
test_pipeline = [dict(type='LoadImageFromFile'),dict(type='Resize', scale=(2048, 1024), keep_ratio=True),dict(type='LoadAnnotations'),dict(type='PackSegInputs')
]# TTA后处理
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),dict(type='TestTimeAug',transforms=[[dict(type='Resize', scale_factor=r, keep_ratio=True)for r in img_ratios],[dict(type='RandomFlip', prob=0., direction='horizontal'),dict(type='RandomFlip', prob=1., direction='horizontal')], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]])
]# 训练 Dataloader
train_dataloader = dict(batch_size=2,num_workers=4,persistent_workers=True,sampler=dict(type='InfiniteSampler', shuffle=True),dataset=dict(type=dataset_type,data_root=data_root,data_prefix=dict(img_path='images/train', seg_map_path='masks/train'),pipeline=train_pipeline))# 验证 Dataloader
val_dataloader = dict(batch_size=1,num_workers=4,persistent_workers=True,sampler=dict(type='DefaultSampler', shuffle=False),dataset=dict(type=dataset_type,data_root=data_root,data_prefix=dict(img_path='images/val', seg_map_path='masks/val'),pipeline=test_pipeline))# 测试 Dataloader
test_dataloader = val_dataloader# 验证 Evaluator
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice', 'mFscore'])# 测试 Evaluator
test_evaluator = val_evaluator
"""with io.open('configs/_base_/datasets/custom_pipeline.py', 'w', encoding='utf-8') as f:f.write(custom_pipeline)

修改配置文件

cfg = Config.fromfile('configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py')
dataset_cfg = Config.fromfile('configs/_base_/datasets/custom_pipeline.py')
cfg.merge_from_dict(dataset_cfg)
  • 更改配置文件
# 类别个数
NUM_CLASS = 2
# 单卡训练时,需要把 SyncBN 改成 BN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.crop_size = (640, 640)
cfg.model.data_preprocessor.size = cfg.crop_size# 预训练模型权重
cfg.load_from = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'# 模型 decode/auxiliary 输出头,指定为类别个数
cfg.model.decode_head.num_classes = NUM_CLASS
cfg.model.decode_head.loss_cls.class_weight = [1.0] * NUM_CLASS + [0.1]
cfg.model.backbone.frozen_stages = 4# 训练 Batch Size
cfg.train_dataloader.batch_size = 2
cfg.test_dataloader = cfg.val_dataloadercfg.optimizer.lr = cfg.optimizer.lr / 8# 结果保存目录
cfg.work_dir = './work_dirs'cfg.train_cfg.max_iters = 40000 # 训练迭代次数
cfg.train_cfg.val_interval = 500 # 评估模型间隔
cfg.default_hooks.logger.interval = 50 # 日志记录间隔
cfg.default_hooks.checkpoint.interval = 2500 # 模型权重保存间隔
cfg.default_hooks.checkpoint.max_keep_ckpts = 2 # 最多保留几个模型权重
cfg.default_hooks.checkpoint.save_best = 'mIoU' # 保留指标最高的模型权重# 随机数种子
cfg['randomness'] = dict(seed=0)cfg.visualizer.vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
  • 保存配置文件,并开始训练
cfg.dump('custom_mask2former.py')
!python tools/train.py custom_mask2former.py

可视化训练指标

在这里插入图片描述

评估模型以及测试推理速度

  • 评估模型精度
# 取最佳模型权重
best_pth = glob.glob('work_dirs/best_mIoU*.pth')[0]
# 测试精度
!python tools/test.py custom_mask2former.py '{best_pth}'
  • 输出:
+-----------+-------+-------+-------+--------+-----------+--------+
|   Class   |  IoU  |  Acc  |  Dice | Fscore | Precision | Recall |
+-----------+-------+-------+-------+--------+-----------+--------+
|   normal  | 99.74 | 99.89 | 99.87 | 99.87  |   99.86   | 99.89  |
| sclerotic | 86.41 | 91.87 | 92.71 | 92.71  |   93.57   | 91.87  |
+-----------+-------+-------+-------+--------+-----------+--------+
  • 测试模型推理速度
# 测试FPS
!python tools/analysis_tools/benchmark.py custom_mask2former.py '{best_pth}'
  • 输出:
Done image [50 / 200], fps: 2.24 img / s
Done image [100/ 200], fps: 2.24 img / s
Done image [150/ 200], fps: 2.24 img / s
Done image [200/ 200], fps: 2.24 img / s
Overall fps: 2.24 img / sAverage fps of 1 evaluations: 2.24
The variance of 1 evaluations: 0.0

这篇关于利用MMSegmentation微调Mask2Former模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/248448

相关文章

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了

透彻!驯服大型语言模型(LLMs)的五种方法,及具体方法选择思路

引言 随着时间的发展,大型语言模型不再停留在演示阶段而是逐步面向生产系统的应用,随着人们期望的不断增加,目标也发生了巨大的变化。在短短的几个月的时间里,人们对大模型的认识已经从对其zero-shot能力感到惊讶,转变为考虑改进模型质量、提高模型可用性。 「大语言模型(LLMs)其实就是利用高容量的模型架构(例如Transformer)对海量的、多种多样的数据分布进行建模得到,它包含了大量的先验

图神经网络模型介绍(1)

我们将图神经网络分为基于谱域的模型和基于空域的模型,并按照发展顺序详解每个类别中的重要模型。 1.1基于谱域的图神经网络         谱域上的图卷积在图学习迈向深度学习的发展历程中起到了关键的作用。本节主要介绍三个具有代表性的谱域图神经网络:谱图卷积网络、切比雪夫网络和图卷积网络。 (1)谱图卷积网络 卷积定理:函数卷积的傅里叶变换是函数傅里叶变换的乘积,即F{f*g}

秋招最新大模型算法面试,熬夜都要肝完它

💥大家在面试大模型LLM这个板块的时候,不知道面试完会不会复盘、总结,做笔记的习惯,这份大模型算法岗面试八股笔记也帮助不少人拿到过offer ✨对于面试大模型算法工程师会有一定的帮助,都附有完整答案,熬夜也要看完,祝大家一臂之力 这份《大模型算法工程师面试题》已经上传CSDN,还有完整版的大模型 AI 学习资料,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

AI Toolkit + H100 GPU,一小时内微调最新热门文生图模型 FLUX

上个月,FLUX 席卷了互联网,这并非没有原因。他们声称优于 DALLE 3、Ideogram 和 Stable Diffusion 3 等模型,而这一点已被证明是有依据的。随着越来越多的流行图像生成工具(如 Stable Diffusion Web UI Forge 和 ComyUI)开始支持这些模型,FLUX 在 Stable Diffusion 领域的扩展将会持续下去。 自 FLU

SWAP作物生长模型安装教程、数据制备、敏感性分析、气候变化影响、R模型敏感性分析与贝叶斯优化、Fortran源代码分析、气候数据降尺度与变化影响分析

查看原文>>>全流程SWAP农业模型数据制备、敏感性分析及气候变化影响实践技术应用 SWAP模型是由荷兰瓦赫宁根大学开发的先进农作物模型,它综合考虑了土壤-水分-大气以及植被间的相互作用;是一种描述作物生长过程的一种机理性作物生长模型。它不但运用Richard方程,使其能够精确的模拟土壤中水分的运动,而且耦合了WOFOST作物模型使作物的生长描述更为科学。 本文让更多的科研人员和农业工作者

线性因子模型 - 独立分量分析(ICA)篇

序言 线性因子模型是数据分析与机器学习中的一类重要模型,它们通过引入潜变量( latent variables \text{latent variables} latent variables)来更好地表征数据。其中,独立分量分析( ICA \text{ICA} ICA)作为线性因子模型的一种,以其独特的视角和广泛的应用领域而备受关注。 ICA \text{ICA} ICA旨在将观察到的复杂信号