图像分割实战-系列教程17:deeplabV3+ VOC分割实战5-------main.py

2024-01-20 20:04

本文主要是介绍图像分割实战-系列教程17:deeplabV3+ VOC分割实战5-------main.py,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在这里插入图片描述

🍁🍁🍁图像分割实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传

deeplab系列算法概述
deeplabV3+ VOC分割实战1
deeplabV3+ VOC分割实战2
deeplabV3+ VOC分割实战3
deeplabV3+ VOC分割实战4
deeplabV3+ VOC分割实战5

10、main.py的main()函数

def main():opts = get_argparser().parse_args()if opts.dataset.lower() == 'voc':opts.num_classes = 21elif opts.dataset.lower() == 'cityscapes':opts.num_classes = 19# Setup visualizationvis = Visualizer(port=opts.vis_port,env=opts.vis_env) if opts.enable_vis else Noneif vis is not None:  # display optionsvis.vis_table("Options", vars(opts))os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_iddevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print("Device: %s" % device)# Setup random seedtorch.manual_seed(opts.random_seed)np.random.seed(opts.random_seed)random.seed(opts.random_seed)# Setup dataloaderif opts.dataset=='voc' and not opts.crop_val:opts.val_batch_size = 1train_dst, val_dst = get_dataset(opts)train_loader = data.DataLoader(train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=0)val_loader = data.DataLoader(val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=0)print("Dataset: %s, Train set: %d, Val set: %d" %(opts.dataset, len(train_dst), len(val_dst)))# Set up modelmodel_map = {'deeplabv3_resnet50': network.deeplabv3_resnet50,'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,'deeplabv3_resnet101': network.deeplabv3_resnet101,'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,'deeplabv3_mobilenet': network.deeplabv3_mobilenet,'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet}model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)if opts.separable_conv and 'plus' in opts.model:network.convert_to_separable_conv(model.classifier)utils.set_bn_momentum(model.backbone, momentum=0.01)# Set up metricsmetrics = StreamSegMetrics(opts.num_classes)# Set up optimizeroptimizer = torch.optim.SGD(params=[{'params': model.backbone.parameters(), 'lr': 0.1*opts.lr},{'params': model.classifier.parameters(), 'lr': opts.lr},], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)#optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)#torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)if opts.lr_policy=='poly':scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)elif opts.lr_policy=='step':scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)# Set up criterion# criterion = utils.get_loss(opts.loss_type)if opts.loss_type == 'focal_loss':criterion = utils.FocalLoss(ignore_index=255, size_average=True)elif opts.loss_type == 'cross_entropy':criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')def save_ckpt(path):""" save current model"""torch.save({"cur_itrs": cur_itrs,"model_state": model.module.state_dict(),"optimizer_state": optimizer.state_dict(),"scheduler_state": scheduler.state_dict(),"best_score": best_score,}, path)print("Model saved as %s" % path)utils.mkdir('checkpoints')# Restorebest_score = 0.0cur_itrs = 0cur_epochs = 0if opts.ckpt is not None and os.path.isfile(opts.ckpt):# https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdancheckpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))model.load_state_dict(checkpoint["model_state"])model = nn.DataParallel(model)model.to(device)if opts.continue_training:optimizer.load_state_dict(checkpoint["optimizer_state"])scheduler.load_state_dict(checkpoint["scheduler_state"])cur_itrs = checkpoint["cur_itrs"]best_score = checkpoint['best_score']print("Training state restored from %s" % opts.ckpt)print("Model restored from %s" % opts.ckpt)del checkpoint  # free memoryelse:print("[!] Retrain")model = nn.DataParallel(model)model.to(device)#==========   Train Loop   ==========#vis_sample_id = np.random.randint(0, len(val_loader), opts.vis_num_samples,np.int32) if opts.enable_vis else None  # sample idxs for visualizationdenorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # denormalization for ori imagesif opts.test_only:model.eval()val_score, ret_samples = validate(opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)print(metrics.to_str(val_score))returninterval_loss = 0while True: #cur_itrs < opts.total_itrs:# =====  Train  =====model.train()cur_epochs += 1for (images, labels) in train_loader:cur_itrs += 1images = images.to(device, dtype=torch.float32)labels = labels.to(device, dtype=torch.long)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()np_loss = loss.detach().cpu().numpy()interval_loss += np_lossif vis is not None:vis.vis_scalar('Loss', cur_itrs, np_loss)if (cur_itrs) % 10 == 0:interval_loss = interval_loss/10print("Epoch %d, Itrs %d/%d, Loss=%f" %(cur_epochs, cur_itrs, opts.total_itrs, interval_loss))interval_loss = 0.0if (cur_itrs) % opts.val_interval == 0:save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %(opts.model, opts.dataset, opts.output_stride))print("validation...")model.eval()val_score, ret_samples = validate(opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)print(metrics.to_str(val_score))if val_score['Mean IoU'] > best_score:  # save best modelbest_score = val_score['Mean IoU']save_ckpt('checkpoints/best_%s_%s_os%d.pth' %(opts.model, opts.dataset,opts.output_stride))if vis is not None:  # visualize validation score and samplesvis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])vis.vis_table("[Val] Class IoU", val_score['Class IoU'])for k, (img, target, lbl) in enumerate(ret_samples):img = (denorm(img) * 255).astype(np.uint8)target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)concat_img = np.concatenate((img, target, lbl), axis=2)  # concat along widthvis.vis_image('Sample %d' % k, concat_img)model.train()scheduler.step()  if cur_itrs >=  opts.total_itrs:returnif __name__ == '__main__':main()

deeplab系列算法概述
deeplabV3+ VOC分割实战1
deeplabV3+ VOC分割实战2
deeplabV3+ VOC分割实战3
deeplabV3+ VOC分割实战4
deeplabV3+ VOC分割实战5

这篇关于图像分割实战-系列教程17:deeplabV3+ VOC分割实战5-------main.py的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/627168

相关文章

Windows环境下解决Matplotlib中文字体显示问题的详细教程

《Windows环境下解决Matplotlib中文字体显示问题的详细教程》本文详细介绍了在Windows下解决Matplotlib中文显示问题的方法,包括安装字体、更新缓存、配置文件设置及编码調整,并... 目录引言问题分析解决方案详解1. 检查系统已安装字体2. 手动添加中文字体(以SimHei为例)步骤

Java JDK1.8 安装和环境配置教程详解

《JavaJDK1.8安装和环境配置教程详解》文章简要介绍了JDK1.8的安装流程,包括官网下载对应系统版本、安装时选择非系统盘路径、配置JAVA_HOME、CLASSPATH和Path环境变量,... 目录1.下载JDK2.安装JDK3.配置环境变量4.检验JDK官网下载地址:Java Downloads

MySQL 多列 IN 查询之语法、性能与实战技巧(最新整理)

《MySQL多列IN查询之语法、性能与实战技巧(最新整理)》本文详解MySQL多列IN查询,对比传统OR写法,强调其简洁高效,适合批量匹配复合键,通过联合索引、分批次优化提升性能,兼容多种数据库... 目录一、基础语法:多列 IN 的两种写法1. 直接值列表2. 子查询二、对比传统 OR 的写法三、性能分析

Python办公自动化实战之打造智能邮件发送工具

《Python办公自动化实战之打造智能邮件发送工具》在数字化办公场景中,邮件自动化是提升工作效率的关键技能,本文将演示如何使用Python的smtplib和email库构建一个支持图文混排,多附件,多... 目录前言一、基础配置:搭建邮件发送框架1.1 邮箱服务准备1.2 核心库导入1.3 基础发送函数二、

PowerShell中15个提升运维效率关键命令实战指南

《PowerShell中15个提升运维效率关键命令实战指南》作为网络安全专业人员的必备技能,PowerShell在系统管理、日志分析、威胁检测和自动化响应方面展现出强大能力,下面我们就来看看15个提升... 目录一、PowerShell在网络安全中的战略价值二、网络安全关键场景命令实战1. 系统安全基线核查

使用Docker构建Python Flask程序的详细教程

《使用Docker构建PythonFlask程序的详细教程》在当今的软件开发领域,容器化技术正变得越来越流行,而Docker无疑是其中的佼佼者,本文我们就来聊聊如何使用Docker构建一个简单的Py... 目录引言一、准备工作二、创建 Flask 应用程序三、创建 dockerfile四、构建 Docker

从原理到实战深入理解Java 断言assert

《从原理到实战深入理解Java断言assert》本文深入解析Java断言机制,涵盖语法、工作原理、启用方式及与异常的区别,推荐用于开发阶段的条件检查与状态验证,并强调生产环境应使用参数验证工具类替代... 目录深入理解 Java 断言(assert):从原理到实战引言:为什么需要断言?一、断言基础1.1 语

Java MQTT实战应用

《JavaMQTT实战应用》本文详解MQTT协议,涵盖其发布/订阅机制、低功耗高效特性、三种服务质量等级(QoS0/1/2),以及客户端、代理、主题的核心概念,最后提供Linux部署教程、Sprin... 目录一、MQTT协议二、MQTT优点三、三种服务质量等级四、客户端、代理、主题1. 客户端(Clien

在Spring Boot中集成RabbitMQ的实战记录

《在SpringBoot中集成RabbitMQ的实战记录》本文介绍SpringBoot集成RabbitMQ的步骤,涵盖配置连接、消息发送与接收,并对比两种定义Exchange与队列的方式:手动声明(... 目录前言准备工作1. 安装 RabbitMQ2. 消息发送者(Producer)配置1. 创建 Spr

深度解析Spring Boot拦截器Interceptor与过滤器Filter的区别与实战指南

《深度解析SpringBoot拦截器Interceptor与过滤器Filter的区别与实战指南》本文深度解析SpringBoot中拦截器与过滤器的区别,涵盖执行顺序、依赖关系、异常处理等核心差异,并... 目录Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现