SSF-CNN:空间光谱融合的卷积光谱图像超分网络

2023-11-27 19:01

本文主要是介绍SSF-CNN:空间光谱融合的卷积光谱图像超分网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

SSF-CNN: SPATIAL AND SPECTRAL FUSION WITH CNN FOR HYPERSPECTRAL IMAGE SUPER-RESOLUTION

文章目录

  • SSF-CNN: SPATIAL AND SPECTRAL FUSION WITH CNN FOR HYPERSPECTRAL IMAGE SUPER-RESOLUTION
    • 简介
    • 解决问题
    • 网络框架
    • 代码实现
    • 训练部分
    • 运行结果

简介

​ 本文提出了一种利用空间和光谱进行高光谱融合图像超分辨率的新型CNN架构,首先是对高光谱图像进行双三次插值,使其空间分辨率大小和多光谱一致,然后进行concat操作。使用类似于SRCNN的网络框架对融合超分的图像进行优化,最后输出高分辨率高光谱超分图像。

​ 对于PDCon,也就是引入了部分密集连接,将输入concat到每一个卷积层后面。
Hyperspectral-Image-Super-Resolution-Benchmark——光谱图像超分基准-CSDN博客
Paper: IEEE
Code:https://github.com/miraclefan777/SSFCNN

2023-11-25_16-06-09

解决问题

  1. 传统方法通过基于优化的方法恢复 HR-HS 图像的质量在很大程度上取决于预定义的约束。此外,由于约束项数量较多,优化过程通常涉及较高的计算成本。
  2. 执行HSI SR的一个直接想法是直接应用这样的网络来放大LR-HS图像的空间维度或HR-RGB图像的光谱维度,我们称之为Spatial-CNN和Spectral-CNN,这两种单图像方法忽略了两种图像特有的信息互补优势。

网络框架

  1. 原始的SRCNN是将图片映射到Ycbcr空间,并只使用其中的 Y 分量作为输入来预测 HR Y 图像,该论文则是将图片的通道信息以及空间信息整个进行输入
  2. 原始SRCNN卷积核大小第1,2修改为3*3,增加上下文信息,同时为了避免高维数据(padding为same,保持和原有特征图大小一致)

代码实现

class SSFCNNnet(nn.Module):def __init__(self, num_spectral=31, scale_factor=8, pdconv=False):super(SSFCNNnet, self).__init__()self.scale_factor = scale_factorself.pdconv = pdconvself.Upsample = nn.Upsample(mode='bicubic', scale_factor=self.scale_factor)self.conv1 = nn.Conv2d(num_spectral + 3, 64, kernel_size=3, padding="same")if pdconv:self.conv2 = nn.Conv2d(64 + 3, 32, kernel_size=3, padding="same")self.conv3 = nn.Conv2d(32 + 3, num_spectral, kernel_size=5, padding="same")else:self.conv2 = nn.Conv2d(64, 32, kernel_size=3, padding="same")self.conv3 = nn.Conv2d(32, num_spectral, kernel_size=5, padding="same")self.relu = nn.ReLU(inplace=True)def forward(self, lr_hs, hr_ms):""":param lr_hs:LR-HSI低分辨率的高光谱图像:param hr_ms:高分辨率的多光谱图像:return:"""# 对LR-HSI低分辨率图像进行上采样,让其分辨率更高lr_hs_up = self.Upsample(lr_hs)# 将上采样后的LR-HSI低分辨率图像与高分辨率的多光谱图像进行拼接x = torch.cat((lr_hs_up, hr_ms), dim=1)x = self.relu(self.conv1(x))if self.pdconv:x = torch.cat((x, hr_ms), dim=1)x = self.relu(self.conv2(x))x = torch.cat((x, hr_ms), dim=1)else:x = self.relu(self.conv2(x))out = self.conv3(x)return out

如果需要使用密集连接,只需要在初始化网络模型时,传参pdconv=True

训练部分

未提供自定义dataset类,根据自己的dateset进行参数的修改即可。

import argparse
from calculate_metrics import Loss_SAM, Loss_RMSE, Loss_PSNR
from models.SSFCNNnet import SSFCNNnet
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from train_dataloader import CAVEHSIDATAprocess
from utils import create_F, fspecial,AverageMeter
import os
import copy
import torch
import torch.nn as nnif __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--model', type=str, default="SSFCNNnet")parser.add_argument('--train-file', type=str, required=True)parser.add_argument('--eval-file', type=str, required=True)parser.add_argument('--outputs-dir', type=str, required=True)parser.add_argument('--scale', type=int, default=2)parser.add_argument('--lr', type=float, default=1e-4)parser.add_argument('--batch-size', type=int, default=32)parser.add_argument('--num-workers', type=int, default=0)parser.add_argument('--num-epochs', type=int, default=400)parser.add_argument('--seed', type=int, default=123)args = parser.parse_args()assert args.model in ['SSFCNNnet', 'PDcon_SSF']outputs_dir = os.path.join(args.outputs_dir, '{}'.format(args.model))if not os.path.exists(outputs_dir):os.makedirs(outputs_dir)device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')torch.manual_seed(args.seed)# 训练参数# loss_func = nn.L1Loss(reduction='mean').cuda()criterion = nn.MSELoss()#################数据集处理#################R = create_F()PSF = fspecial('gaussian', 8, 3)downsample_factor = 8training_size = 64stride = 32stride1 = 32train_dataset = CAVEHSIDATAprocess(args.train_file, R, training_size, stride, downsample_factor, PSF, 20)train_dataloader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True)eval_dataset = CAVEHSIDATAprocess(args.eval_file, R, training_size, stride, downsample_factor, PSF, 12)eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)#################数据集处理################## 模型if args.model == 'SSFCNNnet':model = SSFCNNnet().cuda()else:model = SSFCNNnet(pdconv=True).cuda()best_weights = copy.deepcopy(model.state_dict())best_epoch = 0best_psnr = 0.0# 模型初始化for m in model.modules():if isinstance(m, (nn.Conv2d, nn.Linear)):nn.init.xavier_uniform_(m.weight)elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)optimizer = torch.optim.Adam([{'params': model.conv1.parameters()},{'params': model.conv2.parameters()},{'params': model.conv3.parameters(), 'lr': args.lr * 0.1}], lr=args.lr)start_epoch = 0for epoch in range(start_epoch, args.num_epochs):model.train()epoch_losses = AverageMeter()with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:t.set_description('epoch:{}/{}'.format(epoch, args.num_epochs - 1))for data in train_dataloader:label, lr_hs, hr_ms = datalabel = label.to(device)lr_hs = lr_hs.to(device)hr_ms = hr_ms.to(device)lr = optimizer.param_groups[0]['lr']pred = model(hr_ms, lr_hs)loss = criterion(pred, label)epoch_losses.update(loss.item(), len(label))optimizer.zero_grad()loss.backward()optimizer.step()t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg), lr='{0:1.8f}'.format(lr))t.update(len(label))# torch.save(model.state_dict(), os.path.join(outputs_dir, 'epoch_{}.pth'.format(epoch)))if epoch % 5 == 0:model.eval()val_loss = AverageMeter()SAM = Loss_SAM()RMSE = Loss_RMSE()PSNR = Loss_PSNR()sam = AverageMeter()rmse = AverageMeter()psnr = AverageMeter()for data in eval_dataloader:label, lr_hs, hr_ms = datalr_hs = lr_hs.to(device)hr_ms = hr_ms.to(device)label = label.cpu().numpy()with torch.no_grad():preds = model(hr_ms, lr_hs).cpu().numpy()sam.update(SAM(preds, label), len(label))rmse.update(RMSE(preds, label), len(label))psnr.update(PSNR(preds, label), len(label))if psnr.avg > best_psnr:best_epoch = epochbest_psnr = psnr.avgbest_weights = copy.deepcopy(model.state_dict())print('eval psnr: {:.2f}  RMSE: {:.2f}  SAM: {:.2f} '.format(psnr.avg, rmse.avg, sam.avg))

运行结果

在这里插入图片描述

这篇关于SSF-CNN:空间光谱融合的卷积光谱图像超分网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/428320

相关文章

基于WinForm+Halcon实现图像缩放与交互功能

《基于WinForm+Halcon实现图像缩放与交互功能》本文主要讲述在WinForm中结合Halcon实现图像缩放、平移及实时显示灰度值等交互功能,包括初始化窗口的不同方式,以及通过特定事件添加相应... 目录前言初始化窗口添加图像缩放功能添加图像平移功能添加实时显示灰度值功能示例代码总结最后前言本文将

SSID究竟是什么? WiFi网络名称及工作方式解析

《SSID究竟是什么?WiFi网络名称及工作方式解析》SID可以看作是无线网络的名称,类似于有线网络中的网络名称或者路由器的名称,在无线网络中,设备通过SSID来识别和连接到特定的无线网络... 当提到 Wi-Fi 网络时,就避不开「SSID」这个术语。简单来说,SSID 就是 Wi-Fi 网络的名称。比如

Java实现任务管理器性能网络监控数据的方法详解

《Java实现任务管理器性能网络监控数据的方法详解》在现代操作系统中,任务管理器是一个非常重要的工具,用于监控和管理计算机的运行状态,包括CPU使用率、内存占用等,对于开发者和系统管理员来说,了解这些... 目录引言一、背景知识二、准备工作1. Maven依赖2. Gradle依赖三、代码实现四、代码详解五

基于人工智能的图像分类系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 图像分类是计算机视觉中的一个重要任务,目标是自动识别图像中的对象类别。通过卷积神经网络(CNN)等深度学习技术,我们可以构建高效的图像分类系统,广泛应用于自动驾驶、医疗影像诊断、监控分析等领域。本文将介绍如何构建一个基于人工智能的图像分类系统,包括环境

Linux 网络编程 --- 应用层

一、自定义协议和序列化反序列化 代码: 序列化反序列化实现网络版本计算器 二、HTTP协议 1、谈两个简单的预备知识 https://www.baidu.com/ --- 域名 --- 域名解析 --- IP地址 http的端口号为80端口,https的端口号为443 url为统一资源定位符。CSDNhttps://mp.csdn.net/mp_blog/creation/editor

ASIO网络调试助手之一:简介

多年前,写过几篇《Boost.Asio C++网络编程》的学习文章,一直没机会实践。最近项目中用到了Asio,于是抽空写了个网络调试助手。 开发环境: Win10 Qt5.12.6 + Asio(standalone) + spdlog 支持协议: UDP + TCP Client + TCP Server 独立的Asio(http://www.think-async.com)只包含了头文件,不依

poj 3181 网络流,建图。

题意: 农夫约翰为他的牛准备了F种食物和D种饮料。 每头牛都有各自喜欢的食物和饮料,而每种食物和饮料都只能分配给一头牛。 问最多能有多少头牛可以同时得到喜欢的食物和饮料。 解析: 由于要同时得到喜欢的食物和饮料,所以网络流建图的时候要把牛拆点了。 如下建图: s -> 食物 -> 牛1 -> 牛2 -> 饮料 -> t 所以分配一下点: s  =  0, 牛1= 1~

poj 3068 有流量限制的最小费用网络流

题意: m条有向边连接了n个仓库,每条边都有一定费用。 将两种危险品从0运到n-1,除了起点和终点外,危险品不能放在一起,也不能走相同的路径。 求最小的费用是多少。 解析: 抽象出一个源点s一个汇点t,源点与0相连,费用为0,容量为2。 汇点与n - 1相连,费用为0,容量为2。 每条边之间也相连,费用为每条边的费用,容量为1。 建图完毕之后,求一条流量为2的最小费用流就行了

poj 2112 网络流+二分

题意: k台挤奶机,c头牛,每台挤奶机可以挤m头牛。 现在给出每只牛到挤奶机的距离矩阵,求最小化牛的最大路程。 解析: 最大值最小化,最小值最大化,用二分来做。 先求出两点之间的最短距离。 然后二分匹配牛到挤奶机的最大路程,匹配中的判断是在这个最大路程下,是否牛的数量达到c只。 如何求牛的数量呢,用网络流来做。 从源点到牛引一条容量为1的边,然后挤奶机到汇点引一条容量为m的边

韦季李输入法_输入法和鼠标的深度融合

在数字化输入的新纪元,传统键盘输入方式正悄然进化。以往,面对实体键盘,我们常需目光游离于屏幕与键盘之间,以确认指尖下的精准位置。而屏幕键盘虽直观可见,却常因占据屏幕空间,迫使我们在操作与视野间做出妥协,频繁调整布局以兼顾输入与界面浏览。 幸而,韦季李输入法的横空出世,彻底颠覆了这一现状。它不仅对输入界面进行了革命性的重构,更巧妙地将鼠标这一传统外设融入其中,开创了一种前所未有的交互体验。 想象