[分布外检测]Entropy Maximization and Meta Classification for Out-of-Distribution Detection...实现记录

本文主要是介绍[分布外检测]Entropy Maximization and Meta Classification for Out-of-Distribution Detection...实现记录,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Aomaly Segmentation 项目记录

该文档记录异常检测在自动驾驶语义分割场景中的应用

主要参考论文Entropy Maximization and Meta Classification for Out-of-Distribution Detection in Semantic Segmentation

摘要:

Deep neural networks (DNNs) for the semantic segmentation of images are usually trained to operate on a predefined closed set of object classes. This is in contrast to the “open world” setting where DNNs are envisioned to be deployed to. From a functional safety point of view, the ability to detect so-called “out-of-distribution” (OoD) samples, i.e., objects outside of a DNN’s semantic space, is crucial for many applications such as automated driving. We present a two-step procedure for OoD detection. Firstly, we utilize samples from the COCO dataset as OoD proxy(代替物) and introduce a second training objective to maximize the softmax entropy on these samples. Starting from pretrained semantic segmentation networks we re-train a number of DNNs on different in-distribution datasets and evaluate on completely disjoint OoD datasets. Secondly, we perform a transparent post-processing step to discard false positive OoD samples by so-called “meta classification”. To this end, we apply linear models to a set of hand-crafted metrics derived from the DNN’s softmax probabilities. Our method contributes to safer DNNs with more reliable overall system performance.

数据处理:

该项目中主要运用了COCO 和 Cityscapes两个数据集

COCO(OoD proxy)
pycocotools

索引需要用到的图片和annotations,并生成需要的mask

论文主要利用了COCO2017的segmentation数据集,处理数据集的过程利用coco的api:pycocotools.coco

官方文档如下:

The COCO API assists in loading, parsing, and visualizing annotations in COCO. The API supports multiple annotation formats (please see the data format page). For additional details see: CocoApi.m, coco.py, and CocoApi.lua for Matlab, Python, and Lua code, respectively, and also the Python API demo.

使用记录:

from pycocotools.coco import COCO as coco_tools

生成tools类对象, annotation_file是coco官方网站下载的数据集中annotation对应的json文件。

tools = coco_tools(annotation_file)
  • getCatIds(catNms)

获取类对应编号,在该方法中,需要构建COCO OoD proxy,因此需要把Cityscapes中包含的相同类的image去掉

exclude_classes = ['person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck', 'traffic light', 'stop sign']
exclude_cat_Ids = tools.getCatIds(catNms = exclude_classes)
# 返回list
# exclude_cat_Ids
# [1, 2, 3, 4, 6, 8, 10, 13]
  • getImgIds(catIds)

获取包含输入类的所有图片的编号

exclude_img_Ids = []
for cat_Id in exclude_cat_Ids:exclude_img_Ids += tools.getImgIds(catIds = cat_Id)
# 返回list
# [262145, 262146, 524291, 262148, 393223, 393224, 524297, 393227, 131084, 393230, 262161, 131089, 524311, 393241, ...]
  • loadImgs(imgid)

读取图片,返回dict

img = tools.loadImgs(img_Id)[0]
'''
'license':1
'file_name':'000000177284.jpg'
'coco_url':'http://images.cocodataset.org/train2017/000000177284.jpg'
'height':480
'width':640
'date_captured':'2013-11-18 02:58:15'
'flickr_url':'http://farm9.staticflickr.com/8036/8074156186_a7331cbd3b_z.jpg'
'id':177284
len():8
'''
  • getAnnIds(imgIds, iscrowd=None)

获取对应编号图片的annotation的编号

  • loadAnns(annids)

读取annotation

annotations = tools.loadAnns(ann_Ids)
'''
'segmentation':[[122.16, 330.27, 194.59, 225.41, 278.92, 195.14, 289.73, 172.43, 316.76, ...]]
'area':46713.55159999999
'iscrowd':0
'image_id':177284
'bbox':[122.16, 140.0, 370.81, 201.08]
'category_id':22
'id':582827
len():7
'''
  • annToMask(annoations)

从annotations读取mask

COCO dataset
class COCO(Dataset):train_id_in = 0train_id_out = 254min_image_size = 480def __init__(self, root, split="train", transform = None, shuffle = True,proxy_size = None)self.root = rootself.coco_year = list(filter(None, self.root.split("/")))[-1]self.split = split + self.coco_yearself.images = []self.targets = []self.transform = transformfor root, _, filenames in os.walk(os.path.join(self.root, "annotations", "ood_seg_" + self.split)):assert self.split in ['train' + self.coco_year, 'val' + self.coco_year]for filename in filenames:if os.path.splitext(filename)[-1] == '.png':self.targets.append(os.path.join(root, filename))self.images.append(os.path.join(self.root, self.split, filename.split(".")[0] + ".jpg"))if shuffle: # 打乱zipped = list(zip(self.images, self.targets))random.shuffle(zipped)self.images, self.targets = zip(*zipped)if proxy_size is not None: # COCO数据集只取一定量作为PROXYself.images = list(self.images[:int(proxy_size)])self.targets = list(self.targets[:int(proxy_size)])def __len__(self):return len(self.images)def __getitem__(self, i):image = Image.open(self.images[i]).convert('RGB')target = Image.open(self.targets[i]).convert('L')if self.transform is not None:image, target = self.transform(image, target)return image, targetdef __repr__(self):fmt_str = 'Number of COCO Images: %d\n' % len(self.images)return fmt_str.strip()
np.array(coco[0][1])
array([[  0,   0,   0, ...,   0,   0,   0],[  0,   0,   0, ...,   0,   0,   0],[  0,   0,   0, ...,   0,   0,   0],...,[254, 254, 254, ...,   0,   0,   0],[254, 254, 254, ...,   0,   0,   0],[  0,   0, 254, ...,   0,   0,   0]], dtype=uint8)

可以看到coco数据集经过处理后 target只包括0和254,0是没有mask的地方,254是有mask的地方

Cityscapes
Cityscapes dataset
class Cityscapes(Dataset):CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id','has_instances', 'ignore_in_eval', 'color'])labels = [#                 name                     id    trainId   category            catId     hasInstances   ignoreInEval   colorCityscapesClass(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),...]mean = (0.485, 0.456, 0.406)std = (0.229, 0.224, 0.225)ignore_in_eval_ids, label_ids, train_ids, train_id2id = [], [], [], []  # empty lists for storing idscolor_palette_train_ids = [(0, 0, 0) for i in range(256)]for i in range(len(labels)):if labels[i].ignore_in_eval and labels[i].train_id not in ignore_in_eval_ids:ignore_in_eval_ids.append(labels[i].train_id) # eval 不要的类别放进去for i in range(len(labels)):label_ids.append(labels[i].id)if labels[i].train_id not in ignore_in_eval_ids:train_ids.append(labels[i].train_id)color_palette_train_ids[labels[i].train_id] = labels[i].colortrain_id2id.append(labels[i].id)num_label_ids = len(set(label_ids)) # 所有的类num_train_ids = len(set(train_ids)) # eval需要用到的类id2label = {label.id: label for label in labels}train_id2label = {label.train_id: label for label in labels}def __init__(self, root = "/home/datasets/cityscapes/", split = "val", mode = "gtFine",target_type = "semantic_id", transform = None,predictions_root = None) -> None:self.root = rootself.split = splitself.mode = 'gtFine' if "fine" in mode.lower() else 'gtCoarse' # fine or coarseself.transform = transformself.images_dir = os.path.join(self.root, 'leftImg8bit', self.split)self.targets_dir = os.path.join(self.root, self.mode, self.split)self.predictions_dir = os.path.join(predictions_root, self.split) if predictions_root is not None else ""self.images = []self.targets = []self.predictions = []for city in os.listdir(self.images_dir):img_dir = os.path.join(self.images_dir, city)target_dir = os.path.join(self.targets_dir, city)pred_dir = os.path.join(self.predictions_dir, city)for file_name in os.listdir(img_dir):target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],self._get_target_suffix(self.mode, target_type))self.images.append(os.path.join(img_dir, file_name))self.targets.append(os.path.join(target_dir, target_name))self.predictions.append(os.path.join(pred_dir, file_name.replace("_leftImg8bit", "")))def __getitem__(self, index):image = Image.open(self.images[index]).convert('RGB')if self.split in ['train', 'val']:target = Image.open(self.targets[index])else:target = Noneif self.transform is not None:image, target = self.transform(image, target)return image, targetdef __len__(self):return len(self.images)
Target encode:
def encode_target(target, pareto_alpha, num_classes, ignore_train_ind, ood_ind=254):"""encode target tensor with all hot encoding for OoD samples:param target: torch tensor:param pareto_alpha: OoD loss weight:param num_classes: number of classes in original task:param ignore_train_ind: void class in original task:param ood_ind: class label corresponding to OoD class:return: one/all hot encoded torch tensor"""npy = target.numpy()npz = npy.copy()npy[np.isin(npy, ood_ind)] = num_classes # 19npy[np.isin(npy, ignore_train_ind)] = num_classes + 1 # 20enc = np.eye(num_classes + 2)[npy][..., :-2]  # one hot encoding with last 2 axis cutoffenc[(npy == num_classes)] = np.full(num_classes, pareto_alpha / num_classes)  # set all hot encoded vectorenc[(enc == 1)] = 1 - pareto_alpha  # convex combination between in and out distribution samplesenc[np.isin(npz, ignore_train_ind)] = np.zeros(num_classes)enc = torch.from_numpy(enc)enc = enc.permute(0, 3, 1, 2).contiguous()return enc

这篇关于[分布外检测]Entropy Maximization and Meta Classification for Out-of-Distribution Detection...实现记录的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/743001

相关文章

C++使用栈实现括号匹配的代码详解

《C++使用栈实现括号匹配的代码详解》在编程中,括号匹配是一个常见问题,尤其是在处理数学表达式、编译器解析等任务时,栈是一种非常适合处理此类问题的数据结构,能够精确地管理括号的匹配问题,本文将通过C+... 目录引言问题描述代码讲解代码解析栈的状态表示测试总结引言在编程中,括号匹配是一个常见问题,尤其是在

Java实现检查多个时间段是否有重合

《Java实现检查多个时间段是否有重合》这篇文章主要为大家详细介绍了如何使用Java实现检查多个时间段是否有重合,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录流程概述步骤详解China编程步骤1:定义时间段类步骤2:添加时间段步骤3:检查时间段是否有重合步骤4:输出结果示例代码结语作

使用C++实现链表元素的反转

《使用C++实现链表元素的反转》反转链表是链表操作中一个经典的问题,也是面试中常见的考题,本文将从思路到实现一步步地讲解如何实现链表的反转,帮助初学者理解这一操作,我们将使用C++代码演示具体实现,同... 目录问题定义思路分析代码实现带头节点的链表代码讲解其他实现方式时间和空间复杂度分析总结问题定义给定

Java覆盖第三方jar包中的某一个类的实现方法

《Java覆盖第三方jar包中的某一个类的实现方法》在我们日常的开发中,经常需要使用第三方的jar包,有时候我们会发现第三方的jar包中的某一个类有问题,或者我们需要定制化修改其中的逻辑,那么应该如何... 目录一、需求描述二、示例描述三、操作步骤四、验证结果五、实现原理一、需求描述需求描述如下:需要在

如何使用Java实现请求deepseek

《如何使用Java实现请求deepseek》这篇文章主要为大家详细介绍了如何使用Java实现请求deepseek功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1.deepseek的api创建2.Java实现请求deepseek2.1 pom文件2.2 json转化文件2.2

python使用fastapi实现多语言国际化的操作指南

《python使用fastapi实现多语言国际化的操作指南》本文介绍了使用Python和FastAPI实现多语言国际化的操作指南,包括多语言架构技术栈、翻译管理、前端本地化、语言切换机制以及常见陷阱和... 目录多语言国际化实现指南项目多语言架构技术栈目录结构翻译工作流1. 翻译数据存储2. 翻译生成脚本

如何通过Python实现一个消息队列

《如何通过Python实现一个消息队列》这篇文章主要为大家详细介绍了如何通过Python实现一个简单的消息队列,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录如何通过 python 实现消息队列如何把 http 请求放在队列中执行1. 使用 queue.Queue 和 reque

Python如何实现PDF隐私信息检测

《Python如何实现PDF隐私信息检测》随着越来越多的个人信息以电子形式存储和传输,确保这些信息的安全至关重要,本文将介绍如何使用Python检测PDF文件中的隐私信息,需要的可以参考下... 目录项目背景技术栈代码解析功能说明运行结php果在当今,数据隐私保护变得尤为重要。随着越来越多的个人信息以电子形

使用 sql-research-assistant进行 SQL 数据库研究的实战指南(代码实现演示)

《使用sql-research-assistant进行SQL数据库研究的实战指南(代码实现演示)》本文介绍了sql-research-assistant工具,该工具基于LangChain框架,集... 目录技术背景介绍核心原理解析代码实现演示安装和配置项目集成LangSmith 配置(可选)启动服务应用场景

使用Python快速实现链接转word文档

《使用Python快速实现链接转word文档》这篇文章主要为大家详细介绍了如何使用Python快速实现链接转word文档功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 演示代码展示from newspaper import Articlefrom docx import