[分布外检测]Entropy Maximization and Meta Classification for Out-of-Distribution Detection...实现记录

本文主要是介绍[分布外检测]Entropy Maximization and Meta Classification for Out-of-Distribution Detection...实现记录,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Aomaly Segmentation 项目记录

该文档记录异常检测在自动驾驶语义分割场景中的应用

主要参考论文Entropy Maximization and Meta Classification for Out-of-Distribution Detection in Semantic Segmentation

摘要:

Deep neural networks (DNNs) for the semantic segmentation of images are usually trained to operate on a predefined closed set of object classes. This is in contrast to the “open world” setting where DNNs are envisioned to be deployed to. From a functional safety point of view, the ability to detect so-called “out-of-distribution” (OoD) samples, i.e., objects outside of a DNN’s semantic space, is crucial for many applications such as automated driving. We present a two-step procedure for OoD detection. Firstly, we utilize samples from the COCO dataset as OoD proxy(代替物) and introduce a second training objective to maximize the softmax entropy on these samples. Starting from pretrained semantic segmentation networks we re-train a number of DNNs on different in-distribution datasets and evaluate on completely disjoint OoD datasets. Secondly, we perform a transparent post-processing step to discard false positive OoD samples by so-called “meta classification”. To this end, we apply linear models to a set of hand-crafted metrics derived from the DNN’s softmax probabilities. Our method contributes to safer DNNs with more reliable overall system performance.

数据处理:

该项目中主要运用了COCO 和 Cityscapes两个数据集

COCO(OoD proxy)
pycocotools

索引需要用到的图片和annotations,并生成需要的mask

论文主要利用了COCO2017的segmentation数据集,处理数据集的过程利用coco的api:pycocotools.coco

官方文档如下:

The COCO API assists in loading, parsing, and visualizing annotations in COCO. The API supports multiple annotation formats (please see the data format page). For additional details see: CocoApi.m, coco.py, and CocoApi.lua for Matlab, Python, and Lua code, respectively, and also the Python API demo.

使用记录:

from pycocotools.coco import COCO as coco_tools

生成tools类对象, annotation_file是coco官方网站下载的数据集中annotation对应的json文件。

tools = coco_tools(annotation_file)
  • getCatIds(catNms)

获取类对应编号,在该方法中,需要构建COCO OoD proxy,因此需要把Cityscapes中包含的相同类的image去掉

exclude_classes = ['person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck', 'traffic light', 'stop sign']
exclude_cat_Ids = tools.getCatIds(catNms = exclude_classes)
# 返回list
# exclude_cat_Ids
# [1, 2, 3, 4, 6, 8, 10, 13]
  • getImgIds(catIds)

获取包含输入类的所有图片的编号

exclude_img_Ids = []
for cat_Id in exclude_cat_Ids:exclude_img_Ids += tools.getImgIds(catIds = cat_Id)
# 返回list
# [262145, 262146, 524291, 262148, 393223, 393224, 524297, 393227, 131084, 393230, 262161, 131089, 524311, 393241, ...]
  • loadImgs(imgid)

读取图片,返回dict

img = tools.loadImgs(img_Id)[0]
'''
'license':1
'file_name':'000000177284.jpg'
'coco_url':'http://images.cocodataset.org/train2017/000000177284.jpg'
'height':480
'width':640
'date_captured':'2013-11-18 02:58:15'
'flickr_url':'http://farm9.staticflickr.com/8036/8074156186_a7331cbd3b_z.jpg'
'id':177284
len():8
'''
  • getAnnIds(imgIds, iscrowd=None)

获取对应编号图片的annotation的编号

  • loadAnns(annids)

读取annotation

annotations = tools.loadAnns(ann_Ids)
'''
'segmentation':[[122.16, 330.27, 194.59, 225.41, 278.92, 195.14, 289.73, 172.43, 316.76, ...]]
'area':46713.55159999999
'iscrowd':0
'image_id':177284
'bbox':[122.16, 140.0, 370.81, 201.08]
'category_id':22
'id':582827
len():7
'''
  • annToMask(annoations)

从annotations读取mask

COCO dataset
class COCO(Dataset):train_id_in = 0train_id_out = 254min_image_size = 480def __init__(self, root, split="train", transform = None, shuffle = True,proxy_size = None)self.root = rootself.coco_year = list(filter(None, self.root.split("/")))[-1]self.split = split + self.coco_yearself.images = []self.targets = []self.transform = transformfor root, _, filenames in os.walk(os.path.join(self.root, "annotations", "ood_seg_" + self.split)):assert self.split in ['train' + self.coco_year, 'val' + self.coco_year]for filename in filenames:if os.path.splitext(filename)[-1] == '.png':self.targets.append(os.path.join(root, filename))self.images.append(os.path.join(self.root, self.split, filename.split(".")[0] + ".jpg"))if shuffle: # 打乱zipped = list(zip(self.images, self.targets))random.shuffle(zipped)self.images, self.targets = zip(*zipped)if proxy_size is not None: # COCO数据集只取一定量作为PROXYself.images = list(self.images[:int(proxy_size)])self.targets = list(self.targets[:int(proxy_size)])def __len__(self):return len(self.images)def __getitem__(self, i):image = Image.open(self.images[i]).convert('RGB')target = Image.open(self.targets[i]).convert('L')if self.transform is not None:image, target = self.transform(image, target)return image, targetdef __repr__(self):fmt_str = 'Number of COCO Images: %d\n' % len(self.images)return fmt_str.strip()
np.array(coco[0][1])
array([[  0,   0,   0, ...,   0,   0,   0],[  0,   0,   0, ...,   0,   0,   0],[  0,   0,   0, ...,   0,   0,   0],...,[254, 254, 254, ...,   0,   0,   0],[254, 254, 254, ...,   0,   0,   0],[  0,   0, 254, ...,   0,   0,   0]], dtype=uint8)

可以看到coco数据集经过处理后 target只包括0和254,0是没有mask的地方,254是有mask的地方

Cityscapes
Cityscapes dataset
class Cityscapes(Dataset):CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id','has_instances', 'ignore_in_eval', 'color'])labels = [#                 name                     id    trainId   category            catId     hasInstances   ignoreInEval   colorCityscapesClass(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),...]mean = (0.485, 0.456, 0.406)std = (0.229, 0.224, 0.225)ignore_in_eval_ids, label_ids, train_ids, train_id2id = [], [], [], []  # empty lists for storing idscolor_palette_train_ids = [(0, 0, 0) for i in range(256)]for i in range(len(labels)):if labels[i].ignore_in_eval and labels[i].train_id not in ignore_in_eval_ids:ignore_in_eval_ids.append(labels[i].train_id) # eval 不要的类别放进去for i in range(len(labels)):label_ids.append(labels[i].id)if labels[i].train_id not in ignore_in_eval_ids:train_ids.append(labels[i].train_id)color_palette_train_ids[labels[i].train_id] = labels[i].colortrain_id2id.append(labels[i].id)num_label_ids = len(set(label_ids)) # 所有的类num_train_ids = len(set(train_ids)) # eval需要用到的类id2label = {label.id: label for label in labels}train_id2label = {label.train_id: label for label in labels}def __init__(self, root = "/home/datasets/cityscapes/", split = "val", mode = "gtFine",target_type = "semantic_id", transform = None,predictions_root = None) -> None:self.root = rootself.split = splitself.mode = 'gtFine' if "fine" in mode.lower() else 'gtCoarse' # fine or coarseself.transform = transformself.images_dir = os.path.join(self.root, 'leftImg8bit', self.split)self.targets_dir = os.path.join(self.root, self.mode, self.split)self.predictions_dir = os.path.join(predictions_root, self.split) if predictions_root is not None else ""self.images = []self.targets = []self.predictions = []for city in os.listdir(self.images_dir):img_dir = os.path.join(self.images_dir, city)target_dir = os.path.join(self.targets_dir, city)pred_dir = os.path.join(self.predictions_dir, city)for file_name in os.listdir(img_dir):target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],self._get_target_suffix(self.mode, target_type))self.images.append(os.path.join(img_dir, file_name))self.targets.append(os.path.join(target_dir, target_name))self.predictions.append(os.path.join(pred_dir, file_name.replace("_leftImg8bit", "")))def __getitem__(self, index):image = Image.open(self.images[index]).convert('RGB')if self.split in ['train', 'val']:target = Image.open(self.targets[index])else:target = Noneif self.transform is not None:image, target = self.transform(image, target)return image, targetdef __len__(self):return len(self.images)
Target encode:
def encode_target(target, pareto_alpha, num_classes, ignore_train_ind, ood_ind=254):"""encode target tensor with all hot encoding for OoD samples:param target: torch tensor:param pareto_alpha: OoD loss weight:param num_classes: number of classes in original task:param ignore_train_ind: void class in original task:param ood_ind: class label corresponding to OoD class:return: one/all hot encoded torch tensor"""npy = target.numpy()npz = npy.copy()npy[np.isin(npy, ood_ind)] = num_classes # 19npy[np.isin(npy, ignore_train_ind)] = num_classes + 1 # 20enc = np.eye(num_classes + 2)[npy][..., :-2]  # one hot encoding with last 2 axis cutoffenc[(npy == num_classes)] = np.full(num_classes, pareto_alpha / num_classes)  # set all hot encoded vectorenc[(enc == 1)] = 1 - pareto_alpha  # convex combination between in and out distribution samplesenc[np.isin(npz, ignore_train_ind)] = np.zeros(num_classes)enc = torch.from_numpy(enc)enc = enc.permute(0, 3, 1, 2).contiguous()return enc

这篇关于[分布外检测]Entropy Maximization and Meta Classification for Out-of-Distribution Detection...实现记录的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/743001

相关文章

idea中创建新类时自动添加注释的实现

《idea中创建新类时自动添加注释的实现》在每次使用idea创建一个新类时,过了一段时间发现看不懂这个类是用来干嘛的,为了解决这个问题,我们可以设置在创建一个新类时自动添加注释,帮助我们理解这个类的用... 目录前言:详细操作:步骤一:点击上方的 文件(File),点击&nbmyHIgsp;设置(Setti

SpringBoot实现MD5加盐算法的示例代码

《SpringBoot实现MD5加盐算法的示例代码》加盐算法是一种用于增强密码安全性的技术,本文主要介绍了SpringBoot实现MD5加盐算法的示例代码,文中通过示例代码介绍的非常详细,对大家的学习... 目录一、什么是加盐算法二、如何实现加盐算法2.1 加盐算法代码实现2.2 注册页面中进行密码加盐2.

MySQL大表数据的分区与分库分表的实现

《MySQL大表数据的分区与分库分表的实现》数据库的分区和分库分表是两种常用的技术方案,本文主要介绍了MySQL大表数据的分区与分库分表的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有... 目录1. mysql大表数据的分区1.1 什么是分区?1.2 分区的类型1.3 分区的优点1.4 分

一文详解如何从零构建Spring Boot Starter并实现整合

《一文详解如何从零构建SpringBootStarter并实现整合》SpringBoot是一个开源的Java基础框架,用于创建独立、生产级的基于Spring框架的应用程序,:本文主要介绍如何从... 目录一、Spring Boot Starter的核心价值二、Starter项目创建全流程2.1 项目初始化(

Mysql删除几亿条数据表中的部分数据的方法实现

《Mysql删除几亿条数据表中的部分数据的方法实现》在MySQL中删除一个大表中的数据时,需要特别注意操作的性能和对系统的影响,本文主要介绍了Mysql删除几亿条数据表中的部分数据的方法实现,具有一定... 目录1、需求2、方案1. 使用 DELETE 语句分批删除2. 使用 INPLACE ALTER T

MySQL INSERT语句实现当记录不存在时插入的几种方法

《MySQLINSERT语句实现当记录不存在时插入的几种方法》MySQL的INSERT语句是用于向数据库表中插入新记录的关键命令,下面:本文主要介绍MySQLINSERT语句实现当记录不存在时... 目录使用 INSERT IGNORE使用 ON DUPLICATE KEY UPDATE使用 REPLACE

mysql数据库重置表主键id的实现

《mysql数据库重置表主键id的实现》在我们的开发过程中,难免在做测试的时候会生成一些杂乱无章的SQL主键数据,本文主要介绍了mysql数据库重置表主键id的实现,具有一定的参考价值,感兴趣的可以了... 目录关键语法演示案例在我们的开发过程中,难免在做测试的时候会生成一些杂乱无章的SQL主键数据,当我们

SpringBoot配置Ollama实现本地部署DeepSeek

《SpringBoot配置Ollama实现本地部署DeepSeek》本文主要介绍了在本地环境中使用Ollama配置DeepSeek模型,并在IntelliJIDEA中创建一个Sprin... 目录前言详细步骤一、本地配置DeepSeek二、SpringBoot项目调用本地DeepSeek前言随着人工智能技

Python 中的异步与同步深度解析(实践记录)

《Python中的异步与同步深度解析(实践记录)》在Python编程世界里,异步和同步的概念是理解程序执行流程和性能优化的关键,这篇文章将带你深入了解它们的差异,以及阻塞和非阻塞的特性,同时通过实际... 目录python中的异步与同步:深度解析与实践异步与同步的定义异步同步阻塞与非阻塞的概念阻塞非阻塞同步

Python Dash框架在数据可视化仪表板中的应用与实践记录

《PythonDash框架在数据可视化仪表板中的应用与实践记录》Python的PlotlyDash库提供了一种简便且强大的方式来构建和展示互动式数据仪表板,本篇文章将深入探讨如何使用Dash设计一... 目录python Dash框架在数据可视化仪表板中的应用与实践1. 什么是Plotly Dash?1.1