图像分割实战-系列教程17:deeplabV3+ VOC分割实战5-------main.py

2024-01-20 20:04

本文主要是介绍图像分割实战-系列教程17:deeplabV3+ VOC分割实战5-------main.py,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在这里插入图片描述

🍁🍁🍁图像分割实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传

deeplab系列算法概述
deeplabV3+ VOC分割实战1
deeplabV3+ VOC分割实战2
deeplabV3+ VOC分割实战3
deeplabV3+ VOC分割实战4
deeplabV3+ VOC分割实战5

10、main.py的main()函数

def main():opts = get_argparser().parse_args()if opts.dataset.lower() == 'voc':opts.num_classes = 21elif opts.dataset.lower() == 'cityscapes':opts.num_classes = 19# Setup visualizationvis = Visualizer(port=opts.vis_port,env=opts.vis_env) if opts.enable_vis else Noneif vis is not None:  # display optionsvis.vis_table("Options", vars(opts))os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_iddevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print("Device: %s" % device)# Setup random seedtorch.manual_seed(opts.random_seed)np.random.seed(opts.random_seed)random.seed(opts.random_seed)# Setup dataloaderif opts.dataset=='voc' and not opts.crop_val:opts.val_batch_size = 1train_dst, val_dst = get_dataset(opts)train_loader = data.DataLoader(train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=0)val_loader = data.DataLoader(val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=0)print("Dataset: %s, Train set: %d, Val set: %d" %(opts.dataset, len(train_dst), len(val_dst)))# Set up modelmodel_map = {'deeplabv3_resnet50': network.deeplabv3_resnet50,'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,'deeplabv3_resnet101': network.deeplabv3_resnet101,'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,'deeplabv3_mobilenet': network.deeplabv3_mobilenet,'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet}model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)if opts.separable_conv and 'plus' in opts.model:network.convert_to_separable_conv(model.classifier)utils.set_bn_momentum(model.backbone, momentum=0.01)# Set up metricsmetrics = StreamSegMetrics(opts.num_classes)# Set up optimizeroptimizer = torch.optim.SGD(params=[{'params': model.backbone.parameters(), 'lr': 0.1*opts.lr},{'params': model.classifier.parameters(), 'lr': opts.lr},], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)#optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)#torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)if opts.lr_policy=='poly':scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)elif opts.lr_policy=='step':scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)# Set up criterion# criterion = utils.get_loss(opts.loss_type)if opts.loss_type == 'focal_loss':criterion = utils.FocalLoss(ignore_index=255, size_average=True)elif opts.loss_type == 'cross_entropy':criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')def save_ckpt(path):""" save current model"""torch.save({"cur_itrs": cur_itrs,"model_state": model.module.state_dict(),"optimizer_state": optimizer.state_dict(),"scheduler_state": scheduler.state_dict(),"best_score": best_score,}, path)print("Model saved as %s" % path)utils.mkdir('checkpoints')# Restorebest_score = 0.0cur_itrs = 0cur_epochs = 0if opts.ckpt is not None and os.path.isfile(opts.ckpt):# https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdancheckpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))model.load_state_dict(checkpoint["model_state"])model = nn.DataParallel(model)model.to(device)if opts.continue_training:optimizer.load_state_dict(checkpoint["optimizer_state"])scheduler.load_state_dict(checkpoint["scheduler_state"])cur_itrs = checkpoint["cur_itrs"]best_score = checkpoint['best_score']print("Training state restored from %s" % opts.ckpt)print("Model restored from %s" % opts.ckpt)del checkpoint  # free memoryelse:print("[!] Retrain")model = nn.DataParallel(model)model.to(device)#==========   Train Loop   ==========#vis_sample_id = np.random.randint(0, len(val_loader), opts.vis_num_samples,np.int32) if opts.enable_vis else None  # sample idxs for visualizationdenorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # denormalization for ori imagesif opts.test_only:model.eval()val_score, ret_samples = validate(opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)print(metrics.to_str(val_score))returninterval_loss = 0while True: #cur_itrs < opts.total_itrs:# =====  Train  =====model.train()cur_epochs += 1for (images, labels) in train_loader:cur_itrs += 1images = images.to(device, dtype=torch.float32)labels = labels.to(device, dtype=torch.long)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()np_loss = loss.detach().cpu().numpy()interval_loss += np_lossif vis is not None:vis.vis_scalar('Loss', cur_itrs, np_loss)if (cur_itrs) % 10 == 0:interval_loss = interval_loss/10print("Epoch %d, Itrs %d/%d, Loss=%f" %(cur_epochs, cur_itrs, opts.total_itrs, interval_loss))interval_loss = 0.0if (cur_itrs) % opts.val_interval == 0:save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %(opts.model, opts.dataset, opts.output_stride))print("validation...")model.eval()val_score, ret_samples = validate(opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)print(metrics.to_str(val_score))if val_score['Mean IoU'] > best_score:  # save best modelbest_score = val_score['Mean IoU']save_ckpt('checkpoints/best_%s_%s_os%d.pth' %(opts.model, opts.dataset,opts.output_stride))if vis is not None:  # visualize validation score and samplesvis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])vis.vis_table("[Val] Class IoU", val_score['Class IoU'])for k, (img, target, lbl) in enumerate(ret_samples):img = (denorm(img) * 255).astype(np.uint8)target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)concat_img = np.concatenate((img, target, lbl), axis=2)  # concat along widthvis.vis_image('Sample %d' % k, concat_img)model.train()scheduler.step()  if cur_itrs >=  opts.total_itrs:returnif __name__ == '__main__':main()

deeplab系列算法概述
deeplabV3+ VOC分割实战1
deeplabV3+ VOC分割实战2
deeplabV3+ VOC分割实战3
deeplabV3+ VOC分割实战4
deeplabV3+ VOC分割实战5

这篇关于图像分割实战-系列教程17:deeplabV3+ VOC分割实战5-------main.py的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/627168

相关文章

springboot使用Scheduling实现动态增删启停定时任务教程

《springboot使用Scheduling实现动态增删启停定时任务教程》:本文主要介绍springboot使用Scheduling实现动态增删启停定时任务教程,具有很好的参考价值,希望对大家有... 目录1、配置定时任务需要的线程池2、创建ScheduledFuture的包装类3、注册定时任务,增加、删

如何为Yarn配置国内源的详细教程

《如何为Yarn配置国内源的详细教程》在使用Yarn进行项目开发时,由于网络原因,直接使用官方源可能会导致下载速度慢或连接失败,配置国内源可以显著提高包的下载速度和稳定性,本文将详细介绍如何为Yarn... 目录一、查询当前使用的镜像源二、设置国内源1. 设置为淘宝镜像源2. 设置为其他国内源三、还原为官方

Python实现图片分割的多种方法总结

《Python实现图片分割的多种方法总结》图片分割是图像处理中的一个重要任务,它的目标是将图像划分为多个区域或者对象,本文为大家整理了一些常用的分割方法,大家可以根据需求自行选择... 目录1. 基于传统图像处理的分割方法(1) 使用固定阈值分割图片(2) 自适应阈值分割(3) 使用图像边缘检测分割(4)

一文带你搞懂Python中__init__.py到底是什么

《一文带你搞懂Python中__init__.py到底是什么》朋友们,今天我们来聊聊Python里一个低调却至关重要的文件——__init__.py,有些人可能听说过它是“包的标志”,也有人觉得它“没... 目录先搞懂 python 模块(module)Python 包(package)是啥?那么 __in

使用Python实现图像LBP特征提取的操作方法

《使用Python实现图像LBP特征提取的操作方法》LBP特征叫做局部二值模式,常用于纹理特征提取,并在纹理分类中具有较强的区分能力,本文给大家介绍了如何使用Python实现图像LBP特征提取的操作方... 目录一、LBP特征介绍二、LBP特征描述三、一些改进版本的LBP1.圆形LBP算子2.旋转不变的LB

Maven的使用和配置国内源的保姆级教程

《Maven的使用和配置国内源的保姆级教程》Maven是⼀个项目管理工具,基于POM(ProjectObjectModel,项目对象模型)的概念,Maven可以通过一小段描述信息来管理项目的构建,报告... 目录1. 什么是Maven?2.创建⼀个Maven项目3.Maven 核心功能4.使用Maven H

IDEA自动生成注释模板的配置教程

《IDEA自动生成注释模板的配置教程》本文介绍了如何在IntelliJIDEA中配置类和方法的注释模板,包括自动生成项目名称、包名、日期和时间等内容,以及如何定制参数和返回值的注释格式,需要的朋友可以... 目录项目场景配置方法类注释模板定义类开头的注释步骤类注释效果方法注释模板定义方法开头的注释步骤方法注

Python列表去重的4种核心方法与实战指南详解

《Python列表去重的4种核心方法与实战指南详解》在Python开发中,处理列表数据时经常需要去除重复元素,本文将详细介绍4种最实用的列表去重方法,有需要的小伙伴可以根据自己的需要进行选择... 目录方法1:集合(set)去重法(最快速)方法2:顺序遍历法(保持顺序)方法3:副本删除法(原地修改)方法4:

在Spring Boot中浅尝内存泄漏的实战记录

《在SpringBoot中浅尝内存泄漏的实战记录》本文给大家分享在SpringBoot中浅尝内存泄漏的实战记录,结合实例代码给大家介绍的非常详细,感兴趣的朋友一起看看吧... 目录使用静态集合持有对象引用,阻止GC回收关键点:可执行代码:验证:1,运行程序(启动时添加JVM参数限制堆大小):2,访问 htt

Python如何将大TXT文件分割成4KB小文件

《Python如何将大TXT文件分割成4KB小文件》处理大文本文件是程序员经常遇到的挑战,特别是当我们需要把一个几百MB甚至几个GB的TXT文件分割成小块时,下面我们来聊聊如何用Python自动完成这... 目录为什么需要分割TXT文件基础版:按行分割进阶版:精确控制文件大小完美解决方案:支持UTF-8编码