Fine-tuning Mask2Former with MMSegmentation

2023-10-20 16:59

This article walks through fine-tuning a Mask2Former model with MMSegmentation, and hopefully serves as a useful reference for developers tackling similar segmentation tasks.

Introduction

  • This article introduces the mmsegmentation GitHub project, a Python toolbox dedicated to semantic segmentation (project link); the runtime environment is a Kaggle notebook with a P100 GPU
  • It covers environment setup, inference with pretrained models, and fine-tuning the new SOTA model Mask2Former on a watermelon dataset (dataset details)
  • Because the watermelon dataset is small, we finally also fine-tune Mask2Former on a histopathology glomeruli dataset (dataset details)
  • Parts of this tutorial follow the GitHub project MMSegmentation_Tutorials (project link)

Environment Setup

  • Running the code requires openmim, mmsegmentation, mmengine, mmdetection and mmcv. Setting up mmcv on Kaggle is fiddly and needs prebuilt packages, so I packaged all of them into the dataset frozen-packages-mmdetection (details page)
import IPython.display as display

!pip install -U openmim
!rm -rf mmsegmentation
!git clone https://github.com/open-mmlab/mmsegmentation.git
%cd mmsegmentation
!pip install -v -e .
!pip install "mmdet>=3.0.0rc4"
!pip install -q /kaggle/input/frozen-packages-mmdetection/mmcv-2.0.1-cp310-cp310-linux_x86_64.whl
!pip install wandb
display.clear_output()
  • Tested: running the code above on Kaggle satisfies the project's requirements with no errors (as of July 13, 2023).
  • Import the common base packages
import io
import os
import cv2
import glob
import time
import torch
import shutil
import mmcv
import wandb
import random
import mmengine
import numpy as np
from PIL import Image
from tqdm import tqdm
from mmengine import Config
import matplotlib.pyplot as plt
%matplotlib inline
from mmseg.datasets import cityscapes
from mmseg.utils import register_all_modules
register_all_modules()
from mmseg.datasets import CityscapesDataset
from mmengine.model.utils import revert_sync_batchnorm
from mmseg.apis import init_model, inference_model, show_result_pyplot

# Suppress warnings
import warnings
warnings.filterwarnings('ignore')
display.clear_output()
  • Create folders for the datasets, pretrained model weights and inference outputs
# Create a checkpoint folder for pretrained model weights
os.mkdir('checkpoint')
# Create an outputs folder for prediction results
os.mkdir('outputs')
# Create a data folder for image and video assets
os.mkdir('data')
  • Download the Cityscapes pretrained weights for PSPNet, SegFormer and Mask2Former into the checkpoint folder
# Download pretrained models from the Model Zoo into the checkpoint folder
!wget https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth -P checkpoint
!wget https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_8x1_1024x1024_160k_cityscapes/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth -P checkpoint
!wget https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth -P checkpoint
display.clear_output()
  • Download some test images and videos into the data folder.
# London street-view image
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220713-mmdetection/images/street_uk.jpeg -P data
# Shanghai driving street-view video, source: https://www.youtube.com/watch?v=ll8TgCZ0plk
!wget https://zihao-download.obs.cn-east-3.myhuaweicloud.com/detectron2/traffic.mp4 -P data
# Street video shot on March 30, 2022
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220713-mmdetection/images/street_20220330_174028.mp4 -P data
display.clear_output()

Image Inference

Command-line inference

  • Run inference on an image from the command line and visualize the result with PIL
  • Both the PSPNet and SegFormer models are used for inference
# PSPNet model
!python demo/image_demo.py \
    data/street_uk.jpeg \
    configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py \
    checkpoint/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth \
    --out-file outputs/B1_uk_pspnet.jpg \
    --device cuda:0 \
    --opacity 0.5
display.clear_output()
Image.open('outputs/B1_uk_pspnet.jpg')

[Image: PSPNet segmentation result]

# SegFormer model
!python demo/image_demo.py \
    data/street_uk.jpeg \
    configs/segformer/segformer_mit-b5_8xb1-160k_cityscapes-1024x1024.py \
    checkpoint/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth \
    --out-file outputs/B1_uk_segformer.jpg \
    --device cuda:0 \
    --opacity 0.5
display.clear_output()
Image.open('outputs/B1_uk_segformer.jpg')

[Image: SegFormer segmentation result]

  • SegFormer clearly outperforms PSPNet here: it separates most of the distinct objects cleanly.

API inference

  • Run image inference through the MMSegmentation Python API
  • Use the Mask2Former model and visualize the result with matplotlib
img_path = 'data/street_uk.jpeg'
img_pil = Image.open(img_path)
# Model config file
config_file = 'configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py'
# Model checkpoint file
checkpoint_file = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'

model = init_model(config_file, checkpoint_file, device='cuda:0')
if not torch.cuda.is_available():
    model = revert_sync_batchnorm(model)

result = inference_model(model, img_path)
pred_mask = result.pred_sem_seg.data[0].detach().cpu().numpy()
display.clear_output()
img_bgr = cv2.imread(img_path)
plt.figure(figsize=(14, 8))
plt.imshow(img_bgr[:,:,::-1])
plt.imshow(pred_mask, alpha=0.55) # overlay opacity; smaller is closer to the original image
plt.axis('off')
plt.savefig('outputs/B2-1.jpg')
plt.show()

[Image: Mask2Former segmentation result]

  • As a SOTA model, Mask2Former produces a genuinely impressive result!

Video Inference

Command-line inference

  • Not recommended: very slow
!python demo/video_demo.py \
    data/street_20220330_174028.mp4 \
    configs/segformer/segformer_mit-b5_8xb1-160k_cityscapes-1024x1024.py \
    checkpoint/segformer_mit-b5_8x1_1024x1024_160k_cityscapes_20211206_072934-87a052ec.pth \
    --device cuda:0 \
    --output-file outputs/B3_video.mp4 \
    --opacity 0.5

API inference

  • Run video inference with the Mask2Former model through the API
# Model config file
config_file = 'configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py'
# Model checkpoint file
checkpoint_file = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'

model = init_model(config_file, checkpoint_file, device='cuda:0')
if not torch.cuda.is_available():
    model = revert_sync_batchnorm(model)
display.clear_output()

input_video = 'data/street_20220330_174028.mp4'
temp_out_dir = time.strftime('%Y%m%d%H%M%S')
os.mkdir(temp_out_dir)
print('Created temporary folder {} for per-frame predictions'.format(temp_out_dir))

# Get the class names and palette of the Cityscapes dataset
classes = cityscapes.CityscapesDataset.METAINFO['classes']
palette = cityscapes.CityscapesDataset.METAINFO['palette']

def predict_single_frame(img, opacity=0.2):
    result = inference_model(model, img)
    # Color the segmentation map with the palette
    seg_map = np.array(result.pred_sem_seg.data[0].detach().cpu().numpy()).astype('uint8')
    seg_img = Image.fromarray(seg_map).convert('P')
    seg_img.putpalette(np.array(palette, dtype=np.uint8))
    show_img = (np.array(seg_img.convert('RGB'))) * (1 - opacity) + img * opacity
    return show_img

# Open the input video
imgs = mmcv.VideoReader(input_video)
prog_bar = mmengine.ProgressBar(len(imgs))

# Process the video frame by frame
for frame_id, img in enumerate(imgs):
    # Predict a single frame
    show_img = predict_single_frame(img, opacity=0.15)
    temp_path = f'{temp_out_dir}/{frame_id:06d}.jpg'  # save the per-frame prediction to the temp folder
    cv2.imwrite(temp_path, show_img)
    prog_bar.update()  # update the progress bar

# Stitch the frames into a video file
mmcv.frames2video(temp_out_dir, 'outputs/B3_video.mp4', fps=imgs.fps, fourcc='mp4v')
shutil.rmtree(temp_out_dir)  # delete the temporary per-frame folder
print('Removed temporary folder', temp_out_dir)

Fine-tuning Mask2Former on a Small Dataset

  • Fine-tune the model on a watermelon semantic segmentation dataset

Download the dataset

!rm -rf Watermelon87_Semantic_Seg_Mask.zip Watermelon87_Semantic_Seg_Mask
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/dataset/watermelon/Watermelon87_Semantic_Seg_Mask.zip
!unzip Watermelon87_Semantic_Seg_Mask.zip >> /dev/null # unzip
!rm -rf Watermelon87_Semantic_Seg_Mask.zip # remove the archive
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/watermelon_test1.jpg -P data
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/video_watermelon_2.mp4 -P data
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/watermelon/data/video_watermelon_3.mov -P data

# List the system-generated junk files
!find . -iname '__MACOSX'
!find . -iname '.DS_Store'
!find . -iname '.ipynb_checkpoints'
# Delete the junk files
!for i in `find . -iname '__MACOSX'`; do rm -rf $i;done
!for i in `find . -iname '.DS_Store'`; do rm -rf $i;done
!for i in `find . -iname '.ipynb_checkpoints'`; do rm -rf $i;done
# Verify the junk files are gone
!find . -iname '__MACOSX'
!find . -iname '.DS_Store'
!find . -iname '.ipynb_checkpoints'
display.clear_output()

Exploring the dataset visually

  • Visualize the semantic labels
# Paths of a single image and its mask
img_path = 'Watermelon87_Semantic_Seg_Mask/img_dir/train/04_35-2.jpg'
mask_path = 'Watermelon87_Semantic_Seg_Mask/ann_dir/train/04_35-2.png'

img = cv2.imread(img_path)
mask = cv2.imread(mask_path)

# Overlay the mask on the original image
plt.figure(figsize=(8, 8))
plt.imshow(img[:,:,::-1])
plt.imshow(mask[:,:,0], alpha=0.6) # overlay opacity; smaller is closer to the original image
plt.axis('off')
plt.show()

[Image: watermelon image with mask overlay]

Defining the Dataset and Pipeline

  • In the Dataset part you define which class each label value maps to, the annotation color of each class, the image format, and whether to ignore class 0
  • In the Pipeline part you define the preprocessing steps for training and validation, as well as the image crop size
custom_dataset = """
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset

@DATASETS.register_module()
class MyCustomDataset(BaseSegDataset):
    # Classes and their RGB palette colors
    METAINFO = {
        'classes': ['background', 'red', 'green', 'white', 'seed-black', 'seed-white'],
        'palette': [[127,127,127], [200,0,0], [0,200,0], [144,238,144], [30,30,30], [251,189,8]]
    }
    # Annotation suffix and zero-label handling
    def __init__(self,
                 seg_map_suffix='.png',   # suffix of the annotation mask images
                 reduce_zero_label=False, # whether to drop class ID 0
                 **kwargs) -> None:
        super().__init__(
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
"""

with io.open('mmseg/datasets/MyCustomDataset.py', 'w', encoding='utf-8') as f:
    f.write(custom_dataset)
  • Register custom_dataset in the __init__.py file
custom_init = """
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .ade import ADE20KDataset
from .basesegdataset import BaseSegDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import MultiImageMixDataset
from .decathlon import DecathlonDataset
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .lip import LIPDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .synapse import SynapseDataset
from .MyCustomDataset import MyCustomDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
                         BioMedical3DRandomCrop, BioMedical3DRandomFlip,
                         BioMedicalGaussianBlur, BioMedicalGaussianNoise,
                         BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
                         LoadBiomedicalAnnotation, LoadBiomedicalData,
                         LoadBiomedicalImageFromFile, LoadImageFromNDArray,
                         PackSegInputs, PhotoMetricDistortion, RandomCrop,
                         RandomCutOut, RandomMosaic, RandomRotate,
                         RandomRotFlip, Rerange, ResizeShortestEdge,
                         ResizeToMultiple, RGB2Gray, SegRescale)
from .voc import PascalVOCDataset
# yapf: enable
__all__ = [
    'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip',
    'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
    'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
    'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
    'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
    'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
    'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
    'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
    'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
    'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
    'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
    'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
    'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
    'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
    'SynapseDataset', 'MyCustomDataset'
]
"""

with io.open('mmseg/datasets/__init__.py', 'w', encoding='utf-8') as f:
    f.write(custom_init)
  • Define the dataset preprocessing pipeline
custom_pipeline = """
# Dataset settings
dataset_type = 'MyCustomDataset' # dataset class name
data_root = 'Watermelon87_Semantic_Seg_Mask/' # dataset path, relative to the mmsegmentation root

# Crop size fed to the model; usually a multiple of 128, smaller uses less GPU memory
crop_size = (640, 640)

# Training pipeline
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='RandomResize',
         scale=(2048, 1024),
         ratio_range=(0.5, 2.0),
         keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]

# Test pipeline
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

# Test-time augmentation (TTA) pipeline
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(type='TestTimeAug',
         transforms=[
             [dict(type='Resize', scale_factor=r, keep_ratio=True)
              for r in img_ratios],
             [dict(type='RandomFlip', prob=0., direction='horizontal'),
              dict(type='RandomFlip', prob=1., direction='horizontal')],
             [dict(type='LoadAnnotations')],
             [dict(type='PackSegInputs')]
         ])
]

# Training dataloader
train_dataloader = dict(
    batch_size=2,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(img_path='img_dir/train', seg_map_path='ann_dir/train'),
        pipeline=train_pipeline))

# Validation dataloader
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(img_path='img_dir/val', seg_map_path='ann_dir/val'),
        pipeline=test_pipeline))

# Test dataloader
test_dataloader = val_dataloader

# Validation evaluator
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice', 'mFscore'])
# Test evaluator
test_evaluator = val_evaluator
"""

with io.open('configs/_base_/datasets/custom_pipeline.py', 'w', encoding='utf-8') as f:
    f.write(custom_pipeline)

Modifying the config

  • The main changes: number of classes, pretrained weight path, input image size (usually a multiple of 128), batch size, learning rate scaling (scale by base_lr_default * (your_bs / default_bs)), and the learning rate decay schedule
  • About the learning rate: modify lr in optimizer only; optim_wrapper does not need to change
  • Freezing the model's backbone speeds up training for Mask2Former
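As a quick sanity check on the linear scaling rule above, here is a minimal sketch (`scale_lr` is a hypothetical helper, not part of MMSegmentation). The Mask2Former Cityscapes config is written for 8 GPUs x batch 2 = 16 images per step, so a single-GPU batch of 2 gives a factor of 2/16 = 1/8, which is exactly the `lr / 8` division this tutorial applies; the base LR value below is only a placeholder for illustration.

```python
def scale_lr(base_lr: float, default_bs: int, your_bs: int) -> float:
    """Linear LR scaling: base_lr_default * (your_bs / default_bs)."""
    return base_lr * your_bs / default_bs

# 16 images per step in the default config vs. a local batch size of 2
# reproduces the lr / 8 division used in this tutorial.
scaled = scale_lr(0.0001, default_bs=16, your_bs=2)
print(scaled)
```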
cfg = Config.fromfile('configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py')
dataset_cfg = Config.fromfile('configs/_base_/datasets/custom_pipeline.py')
cfg.merge_from_dict(dataset_cfg)
# Number of classes
NUM_CLASS = 6

# For single-GPU training, change SyncBN to BN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.crop_size = (640, 640)
cfg.model.data_preprocessor.size = cfg.crop_size

# Pretrained model weights
cfg.load_from = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'

# Set the decode head to the number of classes
cfg.model.decode_head.num_classes = NUM_CLASS
cfg.model.decode_head.loss_cls.class_weight = [1.0] * NUM_CLASS + [0.1]
cfg.model.backbone.frozen_stages = 4

# Training batch size
cfg.train_dataloader.batch_size = 2
cfg.test_dataloader = cfg.val_dataloader

# Scale the learning rate for the smaller total batch size
cfg.optimizer.lr = cfg.optimizer.lr / 8

# Output directory
cfg.work_dir = './work_dirs'

cfg.train_cfg.max_iters = 4000   # number of training iterations
cfg.train_cfg.val_interval = 50  # evaluation interval
cfg.default_hooks.logger.interval = 50          # logging interval
cfg.default_hooks.checkpoint.interval = 50      # checkpoint saving interval
cfg.default_hooks.checkpoint.max_keep_ckpts = 2 # keep at most this many checkpoints
cfg.default_hooks.checkpoint.save_best = 'mIoU' # also keep the checkpoint with the best mIoU
cfg.param_scheduler[0].end = cfg.train_cfg.max_iters

# Random seed
cfg['randomness'] = dict(seed=0)
cfg.visualizer.vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
  • Save the config file
cfg.dump('custom_mask2former.py')
  • Start training
!python tools/train.py custom_mask2former.py
  • Pick the best checkpoint and evaluate its accuracy
# Grab the best checkpoint by mIoU
best_pth = glob.glob('work_dirs/best_mIoU*.pth')[0]
# Evaluate accuracy
!python tools/test.py custom_mask2former.py '{best_pth}'
  • Output:
+------------+-------+-------+-------+--------+-----------+--------+
|   Class    |  IoU  |  Acc  |  Dice | Fscore | Precision | Recall |
+------------+-------+-------+-------+--------+-----------+--------+
| background | 98.55 | 99.12 | 99.27 | 99.27  |   99.42   | 99.12  |
|    red     | 96.54 | 98.83 | 98.24 | 98.24  |   97.65   | 98.83  |
|   green    | 94.37 | 96.08 |  97.1 |  97.1  |   98.14   | 96.08  |
|   white    | 85.96 | 92.67 | 92.45 | 92.45  |   92.24   | 92.67  |
| seed-black | 81.98 | 90.87 |  90.1 |  90.1  |   89.34   | 90.87  |
| seed-white | 65.57 | 69.98 | 79.21 | 79.21  |   91.24   | 69.98  |
+------------+-------+-------+-------+--------+-----------+--------+
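Every column in the table above derives from per-class pixel confusion counts. The sketch below shows how IoU, Dice/Fscore, precision and recall relate; it is illustrative only and is not MMSegmentation's `IoUMetric` implementation, which computes these internally.

```python
import numpy as np

def per_class_metrics(pred, gt, num_classes):
    """IoU, Dice/Fscore, precision and recall per class, from two label maps."""
    out = {}
    for c in range(num_classes):
        tp = np.sum((pred == c) & (gt == c))   # pixels correctly labeled c
        fp = np.sum((pred == c) & (gt != c))   # pixels wrongly labeled c
        fn = np.sum((pred != c) & (gt == c))   # pixels of class c that were missed
        out[c] = dict(
            iou=tp / (tp + fp + fn),
            dice=2 * tp / (2 * tp + fp + fn),  # Dice equals the Fscore column
            precision=tp / (tp + fp),
            recall=tp / (tp + fn),             # recall equals the Acc column
        )
    return out

# Toy 2x3 label maps with two classes
pred = np.array([[0, 0, 1], [1, 1, 1]])
gt = np.array([[0, 1, 1], [1, 1, 0]])
metrics = per_class_metrics(pred, gt, num_classes=2)
```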

Visualizing training metrics

[Image: training metric curves]

Fine-tuning on the Glomeruli Dataset

  • Fine-tune Mask2Former on a single-class dataset (histopathology glomeruli sections)
  • First clear the work directory, the data folder and the outputs folder
# Clear the work directory
!rm -r work_dirs/*
# Clear the data folder
!rm -r data/*
# Clear the outputs folder
!rm -r outputs/*

Exploring the dataset visually

# Image and mask directories
PATH_IMAGE = '/kaggle/input/glomeruli-hubmap-external-1024x1024/images_1024'
PATH_MASKS = '/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024'

mask = cv2.imread('/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024/VUHSK_1762_29.png')
# Inspect the class IDs
np.unique(mask)
  • Output
array([0, 1], dtype=uint8)
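Besides listing the class IDs, `np.unique(..., return_counts=True)` also reveals how imbalanced the classes are, which matters for a binary dataset like this one. A small sketch on a synthetic mask (the array below only stands in for a real loaded mask):

```python
import numpy as np

# Synthetic single-channel mask standing in for a real loaded mask
mask = np.array([[0, 0, 0, 1],
                 [0, 0, 1, 1]], dtype=np.uint8)

# Per-class pixel counts and frequencies
classes, counts = np.unique(mask, return_counts=True)
freq = counts / mask.size
for c, n, f in zip(classes, counts, freq):
    print(f'class {c}: {n} pixels ({f:.1%})')
```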
  • Visualize the semantic segmentation labels
# Visualize an n x n grid
n = 5
# Overlay opacity; smaller is closer to the original image
opacity = 0.65

fig, axes = plt.subplots(nrows=n, ncols=n, sharex=True, figsize=(12,12))

for i, file_name in enumerate(os.listdir(PATH_IMAGE)[:n**2]):
    # Load the image and its mask
    img_path = os.path.join(PATH_IMAGE, file_name)
    mask_path = os.path.join(PATH_MASKS, file_name.split('.')[0]+'.png')
    img = cv2.imread(img_path)
    mask = cv2.imread(mask_path)
    # Plot
    axes[i//n, i%n].imshow(img[:,:,::-1])
    axes[i//n, i%n].imshow(mask[:,:,0], alpha=opacity)
    axes[i//n, i%n].axis('off')  # hide the axes
fig.suptitle('Image and Semantic Label', fontsize=20)
plt.tight_layout()
plt.savefig('outputs/C2-1.jpg')
plt.show()

[Image: grid of glomeruli images with mask overlays]

Splitting train and validation sets

  • Create the train/val folders
# Create train/val folders for images
!mkdir -p data/images/train
!mkdir -p data/images/val
# Create train/val folders for masks
!mkdir -p data/masks/train
!mkdir -p data/masks/val
  • Shuffle the data randomly and split it 90% train / 10% validation
def copy_file(og_images, og_masks, tr_images, tr_masks, thor):
    # List all file names in the source folder
    file_names = os.listdir(og_images)
    # Shuffle them
    random.shuffle(file_names)
    # Compute the split point
    split_index = int(thor * len(file_names))
    # Copy the training files
    for file_name in file_names[:split_index]:
        og_image = os.path.join(og_images, file_name)
        og_mask = os.path.join(og_masks, file_name)
        tr_image = os.path.join(tr_images, 'train', file_name)
        tr_mask = os.path.join(tr_masks, 'train', file_name)
        shutil.copyfile(og_image, tr_image)
        shutil.copyfile(og_mask, tr_mask)
    # Copy the validation files
    for file_name in file_names[split_index:]:
        og_image = os.path.join(og_images, file_name)
        og_mask = os.path.join(og_masks, file_name)
        tr_image = os.path.join(tr_images, 'val', file_name)
        tr_mask = os.path.join(tr_masks, 'val', file_name)
        shutil.copyfile(og_image, tr_image)
        shutil.copyfile(og_mask, tr_mask)

og_images = '/kaggle/input/glomeruli-hubmap-external-1024x1024/images_1024'
og_masks = '/kaggle/input/glomeruli-hubmap-external-1024x1024/masks_1024'
tr_images = 'data/images'
tr_masks = 'data/masks'

copy_file(og_images, og_masks, tr_images, tr_masks, 0.9)

Redefining the Dataset and Pipeline

  • Mainly update the classes and their RGB palette colors
  • As well as the dataloader path settings
custom_dataset = """
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset

@DATASETS.register_module()
class MyCustomDataset(BaseSegDataset):
    # Classes and their RGB palette colors
    METAINFO = {
        'classes': ['normal', 'sclerotic'],
        'palette': [[127,127,127], [251,189,8]]
    }
    # Image suffix, annotation suffix and zero-label handling
    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='.png',   # suffix of the annotation mask images
                 reduce_zero_label=False, # whether to drop class ID 0
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
"""

with io.open('mmseg/datasets/MyCustomDataset.py', 'w', encoding='utf-8') as f:
    f.write(custom_dataset)
custom_init = """
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .ade import ADE20KDataset
from .basesegdataset import BaseSegDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import MultiImageMixDataset
from .decathlon import DecathlonDataset
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .lip import LIPDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .synapse import SynapseDataset
from .MyCustomDataset import MyCustomDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
                         BioMedical3DRandomCrop, BioMedical3DRandomFlip,
                         BioMedicalGaussianBlur, BioMedicalGaussianNoise,
                         BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
                         LoadBiomedicalAnnotation, LoadBiomedicalData,
                         LoadBiomedicalImageFromFile, LoadImageFromNDArray,
                         PackSegInputs, PhotoMetricDistortion, RandomCrop,
                         RandomCutOut, RandomMosaic, RandomRotate,
                         RandomRotFlip, Rerange, ResizeShortestEdge,
                         ResizeToMultiple, RGB2Gray, SegRescale)
from .voc import PascalVOCDataset
# yapf: enable
__all__ = [
    'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip',
    'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
    'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
    'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
    'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
    'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
    'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
    'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
    'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
    'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
    'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
    'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
    'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
    'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
    'SynapseDataset', 'MyCustomDataset'
]
"""

with io.open('mmseg/datasets/__init__.py', 'w', encoding='utf-8') as f:
    f.write(custom_init)
  • Define the data preprocessing pipeline
custom_pipeline = """
# Dataset settings
dataset_type = 'MyCustomDataset' # dataset class name
data_root = 'data/' # dataset path, relative to the mmsegmentation root

# Crop size fed to the model; usually a multiple of 128, smaller uses less GPU memory
crop_size = (640, 640)

# Training pipeline
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='RandomResize',
         scale=(2048, 1024),
         ratio_range=(0.5, 2.0),
         keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]

# Test pipeline
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]

# Test-time augmentation (TTA) pipeline
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(type='TestTimeAug',
         transforms=[
             [dict(type='Resize', scale_factor=r, keep_ratio=True)
              for r in img_ratios],
             [dict(type='RandomFlip', prob=0., direction='horizontal'),
              dict(type='RandomFlip', prob=1., direction='horizontal')],
             [dict(type='LoadAnnotations')],
             [dict(type='PackSegInputs')]
         ])
]

# Training dataloader
train_dataloader = dict(
    batch_size=2,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(img_path='images/train', seg_map_path='masks/train'),
        pipeline=train_pipeline))

# Validation dataloader
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(img_path='images/val', seg_map_path='masks/val'),
        pipeline=test_pipeline))

# Test dataloader
test_dataloader = val_dataloader

# Validation evaluator
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice', 'mFscore'])
# Test evaluator
test_evaluator = val_evaluator
"""

with io.open('configs/_base_/datasets/custom_pipeline.py', 'w', encoding='utf-8') as f:
    f.write(custom_pipeline)

Modifying the config

cfg = Config.fromfile('configs/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024.py')
dataset_cfg = Config.fromfile('configs/_base_/datasets/custom_pipeline.py')
cfg.merge_from_dict(dataset_cfg)
  • Update the config
# Number of classes
NUM_CLASS = 2

# For single-GPU training, change SyncBN to BN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.crop_size = (640, 640)
cfg.model.data_preprocessor.size = cfg.crop_size

# Pretrained model weights
cfg.load_from = 'checkpoint/mask2former_swin-l-in22k-384x384-pre_8xb2-90k_cityscapes-512x1024_20221202_141901-28ad20f1.pth'

# Set the decode head to the number of classes
cfg.model.decode_head.num_classes = NUM_CLASS
cfg.model.decode_head.loss_cls.class_weight = [1.0] * NUM_CLASS + [0.1]
cfg.model.backbone.frozen_stages = 4

# Training batch size
cfg.train_dataloader.batch_size = 2
cfg.test_dataloader = cfg.val_dataloader

# Scale the learning rate for the smaller total batch size
cfg.optimizer.lr = cfg.optimizer.lr / 8

# Output directory
cfg.work_dir = './work_dirs'

cfg.train_cfg.max_iters = 40000   # number of training iterations
cfg.train_cfg.val_interval = 500  # evaluation interval
cfg.default_hooks.logger.interval = 50            # logging interval
cfg.default_hooks.checkpoint.interval = 2500      # checkpoint saving interval
cfg.default_hooks.checkpoint.max_keep_ckpts = 2   # keep at most this many checkpoints
cfg.default_hooks.checkpoint.save_best = 'mIoU'   # also keep the checkpoint with the best mIoU

# Random seed
cfg['randomness'] = dict(seed=0)
cfg.visualizer.vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
  • Save the config file and start training
cfg.dump('custom_mask2former.py')
!python tools/train.py custom_mask2former.py

Visualizing training metrics

[Image: training metric curves]

Evaluating the Model and Measuring Inference Speed

  • Evaluate model accuracy
# Grab the best checkpoint by mIoU
best_pth = glob.glob('work_dirs/best_mIoU*.pth')[0]
# Evaluate accuracy
!python tools/test.py custom_mask2former.py '{best_pth}'
  • Output:
+-----------+-------+-------+-------+--------+-----------+--------+
|   Class   |  IoU  |  Acc  |  Dice | Fscore | Precision | Recall |
+-----------+-------+-------+-------+--------+-----------+--------+
|   normal  | 99.74 | 99.89 | 99.87 | 99.87  |   99.86   | 99.89  |
| sclerotic | 86.41 | 91.87 | 92.71 | 92.71  |   93.57   | 91.87  |
+-----------+-------+-------+-------+--------+-----------+--------+
  • Measure model inference speed
# Measure FPS
!python tools/analysis_tools/benchmark.py custom_mask2former.py '{best_pth}'
  • Output:
Done image [50 / 200], fps: 2.24 img / s
Done image [100/ 200], fps: 2.24 img / s
Done image [150/ 200], fps: 2.24 img / s
Done image [200/ 200], fps: 2.24 img / s
Overall fps: 2.24 img / s
Average fps of 1 evaluations: 2.24
The variance of 1 evaluations: 0.0
