计算psnr ssim niqe fid mae lpips等指标的代码

2024-04-10 21:28

本文主要是介绍计算psnr ssim niqe fid mae lpips等指标的代码,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  • 以下代码仅供参考,路径处理最好自己改一下
# Author: Wu
# Created: 2023/8/15
# module containing metrics functions
# using package in https://github.com/chaofengc/IQA-PyTorch
import torch
from PIL import Image
import numpy as np
from piqa import PSNR, SSIM
import pyiqa
import argparse
import os
from collections import defaultdict
first = True
first2 = True
lpips_metric = None
niqe_metric = None
config = None
def read_img(img_path, ref_image=None):img = Image.open(img_path).convert('RGB')# resize gt to size of inputif ref_image is not None: w,h = img.size_,_, h_ref, w_ref = ref_image.shapeif w_ref!=w or h_ref!=h:img = img.resize((w_ref, h_ref), Image.ANTIALIAS)img = (np.asarray(img)/255.0)img = torch.from_numpy(img).float()img = img.permute(2,0,1)img = img.to(torch.device(f'cuda:{config.device}')).unsqueeze(0)return img.contiguous()def get_NIQE(enhanced_image, gt_path=None):niqe_metric = pyiqa.create_metric('niqe', device=enhanced_image.device).to(torch.device(f'cuda:{config.device}'))return  niqe_metric(enhanced_image)
def get_FID(enhanced_image_path, gt_path):fid_metric = pyiqa.create_metric('fid').to(torch.device(f'cuda:{config.device}'))score = fid_metric(enhanced_image_path, gt_path)return score
def get_psnr(enhanced_image, gt_path):gtimg = Image.open(gt_path).convert('RGB')gtimg = gtimg.resize((1200, 900), Image.ANTIALIAS)gtimg = (np.asarray(gtimg)/255.0)gtimg = torch.from_numpy(gtimg).float()gtimg = gtimg.permute(2,0,1)gtimg = gtimg.to(torch.device(f'cuda:{config.device}')).unsqueeze(0).contiguous()criterion = PSNR().to(torch.device(f'cuda:{config.device}'))return criterion(enhanced_image, gtimg).cpu().item()
def get_ssim(enhanced_image, gt_path):gtimg = Image.open(gt_path).convert('RGB')gtimg = gtimg.resize((1200, 900), Image.ANTIALIAS)gtimg = (np.asarray(gtimg)/255.0)gtimg = torch.from_numpy(gtimg).float()gtimg = gtimg.permute(2,0,1)gtimg = gtimg.to(torch.device(f'cuda:{config.device}')).unsqueeze(0).contiguous()criterion = SSIM().to(torch.device(f'cuda:{config.device}'))return criterion(enhanced_image, gtimg).cpu().item()
def get_lpips(enhanced_image, gt_path):gtimg = Image.open(gt_path).convert('RGB')gtimg = gtimg.resize((1200, 900), Image.ANTIALIAS)gtimg = (np.asarray(gtimg)/255.0)gtimg = torch.from_numpy(gtimg).float()gtimg = gtimg.permute(2,0,1)gtimg = gtimg.to(torch.device(f'cuda:{config.device}')).unsqueeze(0).contiguous()iqa_metric = pyiqa.create_metric('lpips', device=enhanced_image.device)return iqa_metric(enhanced_image, gtimg).cpu().item()
def get_MAE(enhanced_image, gt_path):gtimg = Image.open(gt_path).convert('RGB')gtimg = gtimg.resize((1200, 900), Image.ANTIALIAS)gtimg = (np.asarray(gtimg)/255.0)gtimg = torch.from_numpy(gtimg).float()gtimg = gtimg.permute(2,0,1)gtimg = gtimg.to(torch.device(f'cuda:{config.device}')).unsqueeze(0).contiguous()return torch.mean(torch.abs(enhanced_image-gtimg)).cpu().item()def get_metric(enhanced_image, gt_path, metrics):if gt_path is not None:gtimg = read_img(gt_path, enhanced_image)else:gtimg = Noneres = dict()if 'psnr' in metrics:psnr = PSNR().to(torch.device(f'cuda:{config.device}'))res['psnr'] = psnr(enhanced_image, gtimg).cpu().item()if 'ssim' in metrics:ssim = SSIM().to(torch.device(f'cuda:{config.device}'))res['ssim'] = ssim(enhanced_image, gtimg).cpu().item()if 'mae' in metrics:res['mae'] = torch.mean(torch.abs(enhanced_image-gtimg)).cpu().item()if 'niqe' in metrics:global first2global niqe_metricif first2:first2 = Falseniqe_metric = pyiqa.create_metric('niqe', device=enhanced_image.device)res['niqe'] = niqe_metric(enhanced_image).cpu().item()if 'lpips' in metrics:global firstglobal lpips_metricif first:first = Falselpips_metric = pyiqa.create_metric('lpips', device=enhanced_image.device)res['lpips'] = lpips_metric(enhanced_image, gtimg).cpu().item()return resdef get_metrics_dataset(pred_path, gt_path, dataset='lol'):if dataset == 'fivek':input_file_path_list = []gt_file_path_list = []file_list = os.listdir(os.path.join(gt_path))for filename in file_list:input_file_path_list.append(os.path.join(pred_path, filename))gt_file_path_list.append(os.path.join(gt_path,  filename))elif dataset == 'lol':input_file_path_list = []gt_file_path_list = []file_list = os.listdir(os.path.join(gt_path))for filename in file_list:input_file_path_list.append(os.path.join(pred_path, filename.replace('normal', 'low')))gt_file_path_list.append(os.path.join(gt_path,  filename))elif dataset == 'EE':input_file_path_list = []gt_file_path_list = []file_list = os.listdir(os.path.join(pred_path))for filename in file_list:input_file_path_list.append(os.path.join(pred_path, filename))suffix = filename.split('_')[-1]new_filename = filename[:-len(suffix)-1]+'.jpg'gt_file_path_list.append(os.path.join(gt_path,  new_filename))elif dataset == 'upair':input_file_path_list = []gt_file_path_list = []file_list = os.listdir(os.path.join(pred_path))for filename in file_list:input_file_path_list.append(os.path.join(pred_path, filename))gt_file_path_list.append(None)else:print(f'{dataset} not supported')exit()return input_file_path_list, gt_file_path_listif __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--gt', type=str, default="/data1/wjh/LOL_v2/Real_captured/eval/gt")parser.add_argument('--pred', type=str, default="/data1/wjh/ECNet/baseline/gt_referenced/output")parser.add_argument('--dataset', type=str, default="lol")parser.add_argument('--device', type=str, default="0")parser.add_argument('--psnr', action='store_true')parser.add_argument('--ssim', action='store_true')parser.add_argument('--fid', action='store_true')parser.add_argument('--niqe', action='store_true')parser.add_argument('--lpips', action='store_true')parser.add_argument('--mae', action='store_true')config = parser.parse_args()print(config)gt_path = config.gtpred_path = config.pred# os.environ['CUDA_VISIBLE_DEVICES']=config.deviceassert os.path.exists(gt_path), 'gt_path not exits'assert os.path.exists(pred_path), 'pred_path not exits'metrics_names = []for metrics_name in ['psnr', 'ssim', 'niqe', 'lpips', 'mae']:if vars(config)[metrics_name]:metrics_names.append(metrics_name)# compute metricsmetrics_dict = defaultdict(list)metrics = dict()with torch.no_grad():# load img pathinput_file_paths,  gt_file_paths = get_metrics_dataset(pred_path, gt_path, config.dataset)# read img and compute metricsfor input_file_path, gt_file_path in zip(input_file_paths, gt_file_paths):# print(input_file_path)pred = read_img(input_file_path)metrics = get_metric(pred, gt_file_path, metrics_names)for metrics_name in metrics:metrics_dict[metrics_name].append(metrics[metrics_name])for metrics_name in metrics:print(f'{metrics_name}: {np.mean(metrics_dict[metrics_name])}')if config.fid:fid_score = get_FID(pred_path, gt_path)print(F'fid: {fid_score}')

这篇关于计算psnr ssim niqe fid mae lpips等指标的代码的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/892186

相关文章

JAVA项目swing转javafx语法规则以及示例代码

《JAVA项目swing转javafx语法规则以及示例代码》:本文主要介绍JAVA项目swing转javafx语法规则以及示例代码的相关资料,文中详细讲解了主类继承、窗口创建、布局管理、控件替换、... 目录最常用的“一行换一行”速查表(直接全局替换)实际转换示例(JFramejs → JavaFX)迁移建

Go异常处理、泛型和文件操作实例代码

《Go异常处理、泛型和文件操作实例代码》Go语言的异常处理机制与传统的面向对象语言(如Java、C#)所使用的try-catch结构有所不同,它采用了自己独特的设计理念和方法,:本文主要介绍Go异... 目录一:异常处理常见的异常处理向上抛中断程序恢复程序二:泛型泛型函数泛型结构体泛型切片泛型 map三:文

MyBatis中的两种参数传递类型详解(示例代码)

《MyBatis中的两种参数传递类型详解(示例代码)》文章介绍了MyBatis中传递多个参数的两种方式,使用Map和使用@Param注解或封装POJO,Map方式适用于动态、不固定的参数,但可读性和安... 目录✅ android方式一:使用Map<String, Object>✅ 方式二:使用@Param

SpringBoot实现图形验证码的示例代码

《SpringBoot实现图形验证码的示例代码》验证码的实现方式有很多,可以由前端实现,也可以由后端进行实现,也有很多的插件和工具包可以使用,在这里,我们使用Hutool提供的小工具实现,本文介绍Sp... 目录项目创建前端代码实现约定前后端交互接口需求分析接口定义Hutool工具实现服务器端代码引入依赖获

利用Python在万圣节实现比心弹窗告白代码

《利用Python在万圣节实现比心弹窗告白代码》:本文主要介绍关于利用Python在万圣节实现比心弹窗告白代码的相关资料,每个弹窗会显示一条温馨提示,程序通过参数方程绘制爱心形状,并使用多线程技术... 目录前言效果预览要点1. 爱心曲线方程2. 显示温馨弹窗函数(详细拆解)2.1 函数定义和延迟机制2.2

Springmvc常用的注解代码示例

《Springmvc常用的注解代码示例》本文介绍了SpringMVC中常用的控制器和请求映射注解,包括@Controller、@RequestMapping等,以及请求参数绑定注解,如@Request... 目录一、控制器与请求映射注解二、请求参数绑定注解三、其他常用注解(扩展)四、注解使用注意事项一、控制

C++简单日志系统实现代码示例

《C++简单日志系统实现代码示例》日志系统是成熟软件中的一个重要组成部分,其记录软件的使用和运行行为,方便事后进行故障分析、数据统计等,:本文主要介绍C++简单日志系统实现的相关资料,文中通过代码... 目录前言Util.hppLevel.hppLogMsg.hppFormat.hppSink.hppBuf

VS Code中的Python代码格式化插件示例讲解

《VSCode中的Python代码格式化插件示例讲解》在Java开发过程中,代码的规范性和可读性至关重要,一个团队中如果每个开发者的代码风格各异,会给代码的维护、审查和协作带来极大的困难,这篇文章主... 目录前言如何安装与配置使用建议与技巧如何选择总结前言在 VS Code 中,有几款非常出色的 pyt

利用Python将PDF文件转换为PNG图片的代码示例

《利用Python将PDF文件转换为PNG图片的代码示例》在日常工作和开发中,我们经常需要处理各种文档格式,PDF作为一种通用且跨平台的文档格式,被广泛应用于合同、报告、电子书等场景,然而,有时我们需... 目录引言为什么选择 python 进行 PDF 转 PNG?Spire.PDF for Python

Java集合之Iterator迭代器实现代码解析

《Java集合之Iterator迭代器实现代码解析》迭代器Iterator是Java集合框架中的一个核心接口,位于java.util包下,它定义了一种标准的元素访问机制,为各种集合类型提供了一种统一的... 目录一、什么是Iterator二、Iterator的核心方法三、基本使用示例四、Iterator的工