This article shows how to train on your own dataset with mmclassification (the code now lives in the mmpretrain repository).
Contents
- Installing from source
- Preparing the dataset
- The config file
- Training
- Appendix
Installing from source
git clone https://github.com/open-mmlab/mmpretrain.git
cd mmpretrain
pip install -U openmim && mim install -e .
These are the versions I used:
/media/xp/data/pydoc/mmlab/mmpretrain$ pip show mmcv mmpretrain mmengine
Name: mmcv
Version: 2.1.0
Summary: OpenMMLab Computer Vision Foundation
Home-page: https://github.com/open-mmlab/mmcv
Author: MMCV Contributors
Author-email: openmmlab@gmail.com
License: UNKNOWN
Location: /home/xp/anaconda3/envs/py3/lib/python3.8/site-packages
Requires: addict, mmengine, numpy, packaging, Pillow, pyyaml, yapf
Required-by:
---
Name: mmpretrain
Version: 1.2.0
Summary: OpenMMLab Model Pretraining Toolbox and Benchmark
Home-page: https://github.com/open-mmlab/mmpretrain
Author: MMPretrain Contributors
Author-email: openmmlab@gmail.com
License: Apache License 2.0
Location: /media/xp/data/pydoc/mmlab/mmpretrain
Editable project location: /media/xp/data/pydoc/mmlab/mmpretrain
Requires: einops, importlib-metadata, mat4py, matplotlib, modelindex, numpy, rich
Required-by:
---
Name: mmengine
Version: 0.10.3
Summary: Engine of OpenMMLab projects
Home-page: https://github.com/open-mmlab/mmengine
Author: MMEngine Authors
Author-email: openmmlab@gmail.com
License: UNKNOWN
Location: /home/xp/anaconda3/envs/py3/lib/python3.8/site-packages
Requires: addict, matplotlib, numpy, opencv-python, pyyaml, rich, termcolor, yapf
Required-by: mmcv
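If you would rather check the versions from Python than from `pip show`, a minimal sketch using only the standard library could look like this. The helper name `installed_versions` is my own and is not part of any OpenMMLab package.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package: version string, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    found = installed_versions(["mmcv", "mmpretrain", "mmengine"])
    for name, version in found.items():
        print(f"{name}: {version or 'not installed'}")
```

Comparing the result against the versions listed above is a quick way to rule out a version mismatch before debugging anything else.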
Preparing the dataset
I use a cat and dog classification dataset as the example. My dataset is laid out as follows:
/media/xp/data/image/deep_image/mini_cat_and_dog$ tree -L 2
.
├── train
│ ├── cat
│ └── dog
└── val
    ├── cat
    └── dog
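Before training it can help to confirm that each split really has one sub-folder per class and to see how many images each class contains. The sketch below assumes that labels are assigned from the sorted sub-folder names (which is how I understand CustomDataset to behave when no annotation file is given); the helper name `summarize_split` and the extension list are my own choices.

```python
import os

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")  # assumption: extend as needed


def summarize_split(split_dir):
    """Return (class_to_idx, counts) for a folder-per-class split directory."""
    class_names = sorted(
        d for d in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, d))
    )
    # Assumed convention: class index follows the sorted folder names.
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
    counts = {}
    for name in class_names:
        n = 0
        for _root, _dirs, files in os.walk(os.path.join(split_dir, name)):
            n += sum(f.lower().endswith(IMAGE_EXTENSIONS) for f in files)
        counts[name] = n
    return class_to_idx, counts
```

Calling `summarize_split` on the `train` directory of the layout above should report `cat` as class 0 and `dog` as class 1 under that assumption, along with the image count for each.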
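Note: some of my images turned out to be corrupt when I trained. mmcv reads images with OpenCV as its backend, so it is best to filter out the bad images first; otherwise training fails with errors such as a failed cv imencode or an image that cannot be found. The following script removes every image that OpenCV cannot open.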
import os

import cv2 as cv


def find_all_image_files(root_dir):
    """Collect every .jpg/.png file under root_dir."""
    image_files = []
    for root, dirs, files in os.walk(root_dir):
        for file in files:
            if file.endswith('.jpg') or file.endswith('.png'):
                image_files.append(os.path.join(root, file))
    return image_files


def is_bad_image(image_file):
    """Return True if OpenCV cannot decode the file."""
    try:
        img = cv.imread(image_file)
        if img is None:
            return True
        return False
    except Exception:
        return True


def remove_bad_images(root_dir):
    """Delete every image under root_dir that OpenCV cannot open."""
    image_files = find_all_image_files(root_dir)
    for image_file in image_files:
        if is_bad_image(image_file):
            os.remove(image_file)
            print(f"Removed bad image: {image_file}")


remove_bad_images("/media/xp/data/image/deep_image/mini_cat_and_dog")
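The script above deletes files permanently. If you would like to inspect the rejected images first, one alternative is to move them to a separate folder instead. The sketch below is my own variation, not something mmcv or mmpretrain provides: the names `quarantine_bad_images` and `default_is_bad` are hypothetical, the extension check is case-insensitive, and the quarantine folder is assumed to live outside the dataset root.

```python
import os
import shutil


def default_is_bad(path):
    """Return True if OpenCV cannot decode the file."""
    import cv2  # imported lazily so a custom checker can be used without OpenCV
    return cv2.imread(path) is None


def quarantine_bad_images(root_dir, quarantine_dir, is_bad=default_is_bad,
                          extensions=(".jpg", ".jpeg", ".png")):
    """Move undecodable images under root_dir into quarantine_dir.

    quarantine_dir is assumed to be outside root_dir.
    Returns the list of original paths that were moved.
    """
    moved = []
    for root, _dirs, files in os.walk(root_dir):
        for name in files:
            if not name.lower().endswith(extensions):
                continue
            src = os.path.join(root, name)
            if not is_bad(src):
                continue
            # Keep the relative layout (e.g. train/cat/x.jpg) inside quarantine.
            rel = os.path.relpath(src, root_dir)
            dst = os.path.join(quarantine_dir, rel)
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            shutil.move(src, dst)
            moved.append(src)
    return moved
```

Because the files are only moved, a false positive can be restored by copying it back to its original relative path.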
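The config file
All training, testing, and conversion in the mmlab family is driven by config files. A config has three basic building blocks: the dataset, the model, and the runtime. Many models inherit these three components from the _base_ directory and then override some options to train different models and datasets.
During training, mm saves a copy of the full training config to the work_dir directory. You can copy that config and edit it directly, so that everything lives in a single config file, which is easier to manage. If you prefer this approach, copy the config in the appendix, modify it, and train with it.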
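To make the inheritance idea concrete, here is a simplified sketch of how a child config can override a few keys of a base config while keeping the rest. It is only an illustration written in plain Python; the real merging in mmengine has more features (for example the `_delete_` key), and the function name `merge_config` is my own.

```python
import copy


def merge_config(base, override):
    """Recursively merge `override` into a copy of `base`.

    Nested dicts are merged key by key; any other value replaces the base value.
    """
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# A base dataset config and a child that only swaps the pipeline.
base = dict(
    train_dataloader=dict(
        batch_size=32,
        dataset=dict(type="CustomDataset", data_prefix="train", pipeline=["base"]),
    )
)
child = dict(train_dataloader=dict(dataset=dict(pipeline=["custom"])))

merged = merge_config(base, child)
print(merged["train_dataloader"]["batch_size"])           # 32
print(merged["train_dataloader"]["dataset"]["pipeline"])  # ['custom']
```

This is why a child config can write something as short as `train_dataloader = dict(dataset=dict(pipeline=...))`: the batch size, dataset type, and paths still come from the base file.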
Below is the config I modified to train mobilenet v3.
- Add a file named my_mobilenetv3.py under the configs/mobilenet_v3 directory
configs/mobilenet_v3/my_mobilenetv3.py
_base_ = [
    # '../_base_/models/mobilenet_v3/mobilenet_v3_small_075_imagenet.py',
    '../_base_/datasets/my_custom.py',
    '../_base_/default_runtime.py',
]

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='MobileNetV3', arch='small_075'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='StackedLinearClsHead',
        num_classes=2,
        in_channels=432,
        mid_channels=[1024],
        dropout_rate=0.2,
        act_cfg=dict(type='HSwish'),
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        init_cfg=dict(
            type='Normal', layer='Linear', mean=0., std=0.01, bias=0.),
        topk=(1, 1)))
# model = dict(backbone=dict(norm_cfg=dict(type='BN', eps=1e-5, momentum=0.1)))

my_image_size = 128
my_max_epochs = 300
my_batch_size = 128

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=my_image_size,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='AutoAugment',
        policies='imagenet',
        hparams=dict(pad_val=[round(x) for x in [128, 128, 128]])),
    dict(
        type='RandomErasing',
        erase_prob=0.2,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=[128, 128, 128],
        fill_std=[50, 50, 50]),
    dict(type='PackInputs'),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=my_image_size,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=my_image_size),
    dict(type='PackInputs'),
]

train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

# schedule settings
optim_wrapper = dict(
    optimizer=dict(
        type='RMSprop',
        lr=0.064,
        alpha=0.9,
        momentum=0.9,
        eps=0.0316,
        weight_decay=1e-5))

param_scheduler = dict(type='StepLR', by_epoch=True, step_size=2, gamma=0.973)

train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=10)
val_cfg = dict()
test_cfg = dict()

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size = (8 GPUs) x (128 samples per GPU)
auto_scale_lr = dict(base_batch_size=my_batch_size)
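The `auto_scale_lr` comment in the config refers to the linear scaling rule: the learning rate is multiplied by the ratio of the actual total batch size to `base_batch_size`. The sketch below only shows that arithmetic; the function name `scaled_lr` is my own. To my knowledge the scaling is opt-in (via the `--auto-scale-lr` flag of tools/train.py), which would be consistent with the unscaled 6.4000e-02 that appears in the training log of this article.

```python
def scaled_lr(base_lr, base_batch_size, samples_per_gpu, num_gpus=1):
    """Linear scaling rule: scale LR by actual batch size / base batch size."""
    actual_batch_size = samples_per_gpu * num_gpus
    return base_lr * actual_batch_size / base_batch_size


# Values from this article: lr=0.064, base_batch_size=128,
# and a dataloader batch size of 32 on a single device.
print(scaled_lr(0.064, 128, 32))  # 0.016
```

In other words, with a batch size of 32 and a base of 128, enabling the scaling would train with a quarter of the configured learning rate.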
- Create my_custom.py under configs/_base_/datasets/
# dataset settings
dataset_type = 'CustomDataset'
data_preprocessor = dict(
    num_classes=2,
    # RGB format normalization parameters
    mean=[128, 128, 128],
    std=[50, 50, 50],
    # convert image from BGR to RGB
    to_rgb=True,
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeEdge', scale=128, edge='short'),
    dict(type='CenterCrop', crop_size=128),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs'),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeEdge', scale=128, edge='short'),
    dict(type='CenterCrop', crop_size=128),
    dict(type='PackInputs'),
]

train_dataloader = dict(
    batch_size=32,
    num_workers=1,
    dataset=dict(
        type=dataset_type,
        data_root='/media/xp/data/image/deep_image/mini_cat_and_dog',
        data_prefix='train',
        with_label=True,
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = dict(
    batch_size=32,
    num_workers=1,
    dataset=dict(
        type=dataset_type,
        data_root='/media/xp/data/image/deep_image/mini_cat_and_dog',
        data_prefix='val',
        with_label=True,
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 1))

# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
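The `data_preprocessor` settings describe a per-channel normalization: the channel order is flipped from BGR to RGB when `to_rgb=True`, and each channel then becomes `(value - mean) / std`. The following sketch applies that to a single pixel in plain Python. It illustrates the arithmetic only and is not the library's implementation; the function name `preprocess_pixel` is my own.

```python
def preprocess_pixel(bgr_pixel, mean=(128, 128, 128), std=(50, 50, 50),
                     to_rgb=True):
    """Normalize one pixel the way the data_preprocessor settings describe."""
    # Flip the channel order from BGR to RGB if requested.
    channels = bgr_pixel[::-1] if to_rgb else bgr_pixel
    return tuple((c - m) / s for c, m, s in zip(channels, mean, std))


print(preprocess_pixel((28, 128, 228)))  # (2.0, 0.0, -2.0)
```

With a mean of 128 and a std of 50, pixel values in the 0 to 255 range end up roughly between -2.56 and 2.54.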
Training
$ python tools/train.py configs/mobilenet_v3/my_mobilenetv3.py
Output
04/22 10:09:07 - mmengine - INFO -
------------------------------------------------------------
System environment:
    sys.platform: linux
    Python: 3.8.18 (default, Sep 11 2023, 13:40:15) [GCC 11.2.0]
    CUDA available: False
    MUSA available: False
    numpy_random_seed: 1921958984
    GCC: gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
    PyTorch: 2.2.2
    PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2023.1-Product Build 20230303 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.3.2 (Git Hash 2dc95a2ad0841e29db8b22fbccaf3e5da7992b01)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.2.2, USE_CUDA=0, USE_CUDNN=OFF, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,
    TorchVision: 0.17.2
    OpenCV: 4.9.0
    MMEngine: 0.10.3

Runtime environment:
    cudnn_benchmark: False
    mp_cfg: {'mp_start_method': 'fork', 'opencv_num_threads': 0}
    dist_cfg: {'backend': 'nccl'}
    seed: 1921958984
    deterministic: False
    Distributed launcher: none
    Distributed training: False
    GPU number: 1
--------------------------------------
04/22 10:09:08 - mmengine - WARNING - "FileClient" will be deprecated in future. Please use io functions in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io
04/22 10:09:08 - mmengine - WARNING - "HardDiskBackend" is the alias of "LocalBackend" and the former will be deprecated in future.
04/22 10:09:08 - mmengine - INFO - Checkpoints will be saved to /media/xp/data/pydoc/mmlab/mmpretrain/work_dirs/my_mobilenetv3.
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:09:17 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:09:17 - mmengine - INFO - Epoch(train) [1][98/98] lr: 6.4000e-02 eta: 1:31:37 time: 0.0913 data_time: 0.0129 loss: 11.2596
04/22 10:09:17 - mmengine - INFO - Saving checkpoint at 1 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:09:26 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:09:26 - mmengine - INFO - Epoch(train) [2][98/98] lr: 6.4000e-02 eta: 1:30:36 time: 0.0905 data_time: 0.0129 loss: 0.7452
04/22 10:09:26 - mmengine - INFO - Saving checkpoint at 2 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:09:35 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:09:35 - mmengine - INFO - Epoch(train) [3][98/98] lr: 6.2272e-02 eta: 1:29:30 time: 0.0841 data_time: 0.0059 loss: 0.7198
04/22 10:09:35 - mmengine - INFO - Saving checkpoint at 3 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:09:44 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:09:44 - mmengine - INFO - Epoch(train) [4][98/98] lr: 6.2272e-02 eta: 1:29:02 time: 0.0856 data_time: 0.0047 loss: 0.6938
04/22 10:09:44 - mmengine - INFO - Saving checkpoint at 4 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:09:53 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:09:53 - mmengine - INFO - Epoch(train) [5][98/98] lr: 6.0591e-02 eta: 1:28:42 time: 0.0877 data_time: 0.0100 loss: 0.7128
04/22 10:09:53 - mmengine - INFO - Saving checkpoint at 5 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:10:02 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:10:02 - mmengine - INFO - Epoch(train) [6][98/98] lr: 6.0591e-02 eta: 1:28:32 time: 0.0857 data_time: 0.0069 loss: 0.7214
04/22 10:10:02 - mmengine - INFO - Saving checkpoint at 6 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:10:11 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:10:11 - mmengine - INFO - Epoch(train) [7][98/98] lr: 5.8955e-02 eta: 1:28:11 time: 0.0860 data_time: 0.0063 loss: 0.7113
04/22 10:10:11 - mmengine - INFO - Saving checkpoint at 7 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:10:20 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:10:20 - mmengine - INFO - Epoch(train) [8][98/98] lr: 5.8955e-02 eta: 1:28:05 time: 0.0881 data_time: 0.0083 loss: 0.6989
04/22 10:10:20 - mmengine - INFO - Saving checkpoint at 8 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:10:29 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:10:29 - mmengine - INFO - Epoch(train) [9][98/98] lr: 5.7363e-02 eta: 1:28:23 time: 0.0883 data_time: 0.0077 loss: 0.6874
04/22 10:10:29 - mmengine - INFO - Saving checkpoint at 9 epochs
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
04/22 10:10:39 - mmengine - INFO - Exp name: my_mobilenetv3_20240422_100907
04/22 10:10:39 - mmengine - INFO - Epoch(train) [10][98/98] lr: 5.7363e-02 eta: 1:28:28 time: 0.0894 data_time: 0.0068 loss: 0.7028
04/22 10:10:39 - mmengine - INFO - Saving checkpoint at 10 epochs
04/22 10:10:39 - mmengine - INFO - Epoch(val) [10][3/3] accuracy/top1: 60.8696 data_time: 0.0411 time: 0.0650
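The learning rates in this log follow directly from the StepLR settings in the config (lr=0.064, step_size=2, gamma=0.973): the rate is multiplied by gamma once every two epochs. A small sketch of that arithmetic, with a function name `step_lr` of my own choosing, reproduces the logged values for epochs 1, 3, 5, 7, and 9.

```python
def step_lr(epoch, base_lr=0.064, step_size=2, gamma=0.973):
    """Learning rate during a 1-indexed epoch under a step decay schedule."""
    return base_lr * gamma ** ((epoch - 1) // step_size)


for epoch in (1, 3, 5, 7, 9):
    print(f"epoch {epoch}: lr = {step_lr(epoch):.4e}")
```

Checking a computed value against the log, such as 6.2272e-02 at epoch 3, is a cheap way to confirm that the scheduler in your config is the one actually being used.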
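Appendix
- Dataset preparation: official documentation
- Complete training config, with the three modules merged into one file. You can modify it and use it for training directly.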
my_train_batch_size = 64
my_val_batch_size = 16
my_image_size = 128
my_max_epochs = 300

my_checkpoints_interval = 10  # 10 epochs to save a checkpoint

my_train_dataset_root = '/media/xp/data/image/deep_image/mini_cat_and_dog'
my_train_data_prefix = 'train'
my_val_dataset_root = '/media/xp/data/image/deep_image/mini_cat_and_dog'
my_val_data_prefix = 'val'
my_test_dataset_root = '/media/xp/data/image/deep_image/mini_cat_and_dog'
my_test_data_prefix = 'test'

work_dir = './work_dirs/my_mobilenetv3'

my_class_names = ['cat', 'dog']

auto_scale_lr = dict(base_batch_size=128)
data_preprocessor = dict(
    mean=[128, 128, 128], num_classes=2, std=[50, 50, 50], to_rgb=True)
dataset_type = 'CustomDataset'

default_hooks = dict(
    checkpoint=dict(interval=my_checkpoints_interval, type='CheckpointHook'),
    logger=dict(interval=100, type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    timer=dict(type='IterTimerHook'),
    visualization=dict(enable=False, type='VisualizationHook'))
default_scope = 'mmpretrain'
env_cfg = dict(
    cudnn_benchmark=False,
    dist_cfg=dict(backend='nccl'),
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
launcher = 'none'
load_from = None
log_level = 'INFO'
model = dict(
    backbone=dict(arch='small_075', type='MobileNetV3'),
    head=dict(
        act_cfg=dict(type='HSwish'),
        dropout_rate=0.2,
        in_channels=432,
        init_cfg=dict(
            bias=0.0, layer='Linear', mean=0.0, std=0.01, type='Normal'),
        loss=dict(loss_weight=1.0, type='CrossEntropyLoss'),
        mid_channels=[1024],
        num_classes=len(my_class_names),
        topk=(1, 1),
        type='StackedLinearClsHead'),
    neck=dict(type='GlobalAveragePooling'),
    type='ImageClassifier')

optim_wrapper = dict(
    optimizer=dict(
        alpha=0.9,
        eps=0.0316,
        lr=0.064,
        momentum=0.9,
        type='RMSprop',
        weight_decay=1e-05))
param_scheduler = dict(by_epoch=True, gamma=0.973, step_size=2, type='StepLR')
randomness = dict(deterministic=False, seed=None)
resume = False
test_cfg = dict()
test_dataloader = dict(
    batch_size=my_val_batch_size,
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        data_prefix='val',
        data_root=my_val_dataset_root,
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                backend='pillow',
                edge='short',
                interpolation='bicubic',
                scale=my_image_size,
                type='ResizeEdge'),
            dict(crop_size=my_image_size, type='CenterCrop'),
            dict(type='PackInputs'),
        ],
        type='CustomDataset',
        with_label=True),
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
test_evaluator = dict(topk=(1, 1), type='Accuracy')
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        backend='pillow',
        edge='short',
        interpolation='bicubic',
        scale=my_image_size,
        type='ResizeEdge'),
    dict(crop_size=my_image_size, type='CenterCrop'),
    dict(type='PackInputs'),
]
train_cfg = dict(by_epoch=True, max_epochs=my_max_epochs, val_interval=10)
train_dataloader = dict(
    batch_size=my_train_batch_size,
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        data_prefix=my_train_data_prefix,
        data_root=my_train_dataset_root,
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                backend='pillow',
                interpolation='bicubic',
                scale=my_image_size,
                type='RandomResizedCrop'),
            dict(direction='horizontal', prob=0.5, type='RandomFlip'),
            dict(
                hparams=dict(pad_val=[128, 128, 128]),
                policies='imagenet',
                type='AutoAugment'),
            dict(
                erase_prob=0.2,
                fill_color=[128, 128, 128],
                fill_std=[50, 50, 50],
                max_area_ratio=0.3333333333333333,
                min_area_ratio=0.02,
                mode='rand',
                type='RandomErasing'),
            dict(type='PackInputs'),
        ],
        type='CustomDataset',
        with_label=True),
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    sampler=dict(shuffle=True, type='DefaultSampler'))
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        backend='pillow',
        interpolation='bicubic',
        scale=my_image_size,
        type='RandomResizedCrop'),
    dict(direction='horizontal', prob=0.5, type='RandomFlip'),
    dict(
        hparams=dict(pad_val=[128, 128, 128]),
        policies='imagenet',
        type='AutoAugment'),
    dict(
        erase_prob=0.2,
        fill_color=[128, 128, 128],
        fill_std=[50, 50, 50],
        max_area_ratio=0.3333333333333333,
        min_area_ratio=0.02,
        mode='rand',
        type='RandomErasing'),
    dict(type='PackInputs'),
]
val_cfg = dict()
val_dataloader = dict(
    batch_size=my_val_batch_size,
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        data_prefix=my_val_data_prefix,
        data_root=my_val_dataset_root,
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                backend='pillow',
                edge='short',
                interpolation='bicubic',
                scale=my_image_size,
                type='ResizeEdge'),
            dict(crop_size=my_image_size, type='CenterCrop'),
            dict(type='PackInputs'),
        ],
        type='CustomDataset',
        with_label=True),
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    sampler=dict(shuffle=False, type='DefaultSampler'))
val_evaluator = dict(topk=(1, 1), type='Accuracy')
vis_backends = [
    dict(type='LocalVisBackend'),
]
visualizer = dict(
    type='UniversalVisualizer',
    vis_backends=[
        dict(type='LocalVisBackend'),
    ])