[分布外检测]Entropy Maximization and Meta Classification for Out-of-Distribution Detection...实现记录

Anomaly Segmentation 项目记录


主要参考论文Entropy Maximization and Meta Classification for Out-of-Distribution Detection in Semantic Segmentation


Deep neural networks (DNNs) for the semantic segmentation of images are usually trained to operate on a predefined closed set of object classes. This is in contrast to the “open world” setting where DNNs are envisioned to be deployed to. From a functional safety point of view, the ability to detect so-called “out-of-distribution” (OoD) samples, i.e., objects outside of a DNN’s semantic space, is crucial for many applications such as automated driving. We present a two-step procedure for OoD detection. Firstly, we utilize samples from the COCO dataset as OoD proxy(代替物) and introduce a second training objective to maximize the softmax entropy on these samples. Starting from pretrained semantic segmentation networks we re-train a number of DNNs on different in-distribution datasets and evaluate on completely disjoint OoD datasets. Secondly, we perform a transparent post-processing step to discard false positive OoD samples by so-called “meta classification”. To this end, we apply linear models to a set of hand-crafted metrics derived from the DNN’s softmax probabilities. Our method contributes to safer DNNs with more reliable overall system performance.


该项目中主要运用了COCO 和 Cityscapes两个数据集

COCO(OoD proxy)




The COCO API assists in loading, parsing, and visualizing annotations in COCO. The API supports multiple annotation formats (please see the data format page). For additional details see: CocoApi.m, coco.py, and CocoApi.lua for Matlab, Python, and Lua code, respectively, and also the Python API demo.


from pycocotools.coco import COCO as coco_tools

生成tools类对象, annotation_file是coco官方网站下载的数据集中annotation对应的json文件。

tools = coco_tools(annotation_file)
  • getCatIds(catNms)

获取类对应编号,在该方法中,需要构建COCO OoD proxy,因此需要把Cityscapes中包含的相同类的image去掉

# COCO class names that also occur in Cityscapes; images containing them
# must be removed when building the COCO OoD proxy.
exclude_classes = ['person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck', 'traffic light', 'stop sign']
# Look up the COCO category ids that belong to those class names.
exclude_cat_Ids = tools.getCatIds(catNms = exclude_classes)
# Returns a list
# exclude_cat_Ids
# [1, 2, 3, 4, 6, 8, 10, 13]
  • getImgIds(catIds)


# Gather the ids of all images that contain any of the excluded categories,
# querying one category id at a time.
exclude_img_Ids = []
for cat_Id in exclude_cat_Ids:
    exclude_img_Ids.extend(tools.getImgIds(catIds=cat_Id))
# Returns a list
# [262145, 262146, 524291, 262148, 393223, 393224, 524297, 393227, 131084, 393230, 262161, 131089, 524311, 393241, ...]
  • loadImgs(imgid)


img = tools.loadImgs(img_Id)[0]
'date_captured':'2013-11-18 02:58:15'
  • getAnnIds(imgIds, iscrowd=None)


  • loadAnns(annids)


annotations = tools.loadAnns(ann_Ids)
'segmentation':[[122.16, 330.27, 194.59, 225.41, 278.92, 195.14, 289.73, 172.43, 316.76, ...]]
'bbox':[122.16, 140.0, 370.81, 201.08]
  • annToMask(annotations)


COCO dataset
class COCO(Dataset):
    """COCO images paired with pre-generated OoD segmentation masks.

    The masks are the ``.png`` files found below
    ``<root>/annotations/ood_seg_<split><year>``; the matching image is
    expected at ``<root>/<split><year>/<mask stem>.jpg``. The year is taken
    from the last non-empty ``/``-separated component of ``root``.

    Args:
        root: Dataset root directory whose last path component is the COCO
            year (e.g. ``".../coco/2017"``).
        split: ``"train"`` or ``"val"``.
        transform: Optional callable invoked as ``transform(image, target)``
            and returning the transformed ``(image, target)`` pair.
        shuffle: Shuffle the (image, target) pairs with ``random.shuffle``.
        proxy_size: If given, keep only the first ``int(proxy_size)`` pairs
            (after shuffling), i.e. use only part of COCO as OoD proxy.

    Raises:
        ValueError: If ``split`` is neither ``"train"`` nor ``"val"``.
    """

    # The processed targets only contain the values 0 (no mask) and 254 (mask).
    train_id_in = 0
    train_id_out = 254
    min_image_size = 480  # not used inside this class

    def __init__(self, root, split="train", transform=None, shuffle=True, proxy_size=None):
        self.root = root
        self.coco_year = list(filter(None, self.root.split("/")))[-1]
        self.split = split + self.coco_year
        self.images = []
        self.targets = []
        self.transform = transform

        # Validate once, up front. The original asserted this inside the
        # directory walk, so it never fired for a non-existent directory and
        # would be stripped entirely under ``python -O``.
        valid_splits = ['train' + self.coco_year, 'val' + self.coco_year]
        if self.split not in valid_splits:
            raise ValueError(
                "split must be 'train' or 'val', got %r" % (split,))

        target_root = os.path.join(self.root, "annotations", "ood_seg_" + self.split)
        for dirpath, _, filenames in os.walk(target_root):
            for filename in filenames:
                if os.path.splitext(filename)[-1] == '.png':
                    self.targets.append(os.path.join(dirpath, filename))
                    self.images.append(
                        os.path.join(self.root, self.split, filename.split(".")[0] + ".jpg"))

        if shuffle:
            # Shuffle images and targets together so the pairs stay aligned.
            pairs = list(zip(self.images, self.targets))
            random.shuffle(pairs)
            # Rebuild as lists; ``zip(*pairs)`` would fail on an empty
            # dataset and would turn the attributes into tuples.
            self.images = [image_path for image_path, _ in pairs]
            self.targets = [target_path for _, target_path in pairs]

        if proxy_size is not None:
            # Only a limited number of COCO samples is used as OoD proxy.
            self.images = list(self.images[:int(proxy_size)])
            self.targets = list(self.targets[:int(proxy_size)])

    def __len__(self):
        """Return the number of (image, target) pairs."""
        return len(self.images)

    def __getitem__(self, i):
        """Load pair ``i`` as an RGB image and a single-channel ('L') target."""
        # ``convert`` returns a fully loaded copy, so the file handles can be
        # closed right away instead of lingering until garbage collection.
        with Image.open(self.images[i]) as img:
            image = img.convert('RGB')
        with Image.open(self.targets[i]) as tgt:
            target = tgt.convert('L')
        if self.transform is not None:
            image, target = self.transform(image, target)
        return image, target

    def __repr__(self):
        fmt_str = 'Number of COCO Images: %d\n' % len(self.images)
        return fmt_str.strip()
array([[  0,   0,   0, ...,   0,   0,   0],[  0,   0,   0, ...,   0,   0,   0],[  0,   0,   0, ...,   0,   0,   0],...,[254, 254, 254, ...,   0,   0,   0],[254, 254, 254, ...,   0,   0,   0],[  0,   0, 254, ...,   0,   0,   0]], dtype=uint8)

可以看到coco数据集经过处理后 target只包括0和254,0是没有mask的地方,254是有mask的地方

Cityscapes dataset
class Cityscapes(Dataset):
    """Cityscapes dataset that collects image, target and prediction paths.

    Images are read from ``<root>/leftImg8bit/<split>/<city>/`` and targets
    from ``<root>/<mode>/<split>/<city>/``. Several id lookup tables are
    derived from ``labels`` once, at class-definition time.

    Args:
        root: Cityscapes root directory.
        split: Sub-directory to load; targets are only opened for
            ``"train"`` and ``"val"``.
        mode: Any string containing "fine" (case-insensitive) selects
            ``gtFine``; everything else selects ``gtCoarse``.
        target_type: Forwarded to ``_get_target_suffix`` to build the target
            file name.
        transform: Optional callable invoked as ``transform(image, target)``.
        predictions_root: Optional root directory of stored predictions;
            ``<predictions_root>/<split>`` is used as prediction directory.
    """

    CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
                                                     'has_instances', 'ignore_in_eval', 'color'])
    labels = [
        #                 name                     id    trainId   category            catId     hasInstances   ignoreInEval   color
        CityscapesClass(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
        # ... the remaining label definitions are omitted in this write-up.
        # (Kept as a comment: a literal `...` list element would be an
        # Ellipsis object and break the loops below.)
    ]

    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    ignore_in_eval_ids, label_ids, train_ids, train_id2id = [], [], [], []  # empty lists for storing ids
    color_palette_train_ids = [(0, 0, 0) for i in range(256)]

    # NOTE: these stay plain loops on purpose. Comprehensions in a class body
    # cannot see sibling class attributes such as ``ignore_in_eval_ids``.
    for label in labels:
        # Collect the train ids of classes that are excluded from evaluation.
        if label.ignore_in_eval and label.train_id not in ignore_in_eval_ids:
            ignore_in_eval_ids.append(label.train_id)
    for label in labels:
        label_ids.append(label.id)
        if label.train_id not in ignore_in_eval_ids:
            train_ids.append(label.train_id)
            color_palette_train_ids[label.train_id] = label.color
            train_id2id.append(label.id)

    num_label_ids = len(set(label_ids))  # number of distinct label ids
    num_train_ids = len(set(train_ids))  # number of distinct train ids used in evaluation
    id2label = {label.id: label for label in labels}
    # Labels sharing a train_id overwrite each other; the last one wins.
    train_id2label = {label.train_id: label for label in labels}

    def __init__(self, root = "/home/datasets/cityscapes/", split = "val", mode = "gtFine",
                 target_type = "semantic_id", transform = None,
                 predictions_root = None) -> None:
        self.root = root
        self.split = split
        self.mode = 'gtFine' if "fine" in mode.lower() else 'gtCoarse'  # fine or coarse
        self.transform = transform
        self.images_dir = os.path.join(self.root, 'leftImg8bit', self.split)
        self.targets_dir = os.path.join(self.root, self.mode, self.split)
        self.predictions_dir = os.path.join(predictions_root, self.split) if predictions_root is not None else ""
        self.images = []
        self.targets = []
        self.predictions = []

        # NOTE(review): ``_get_target_suffix`` is not part of the code shown
        # in this write-up; it has to be provided by the full class.
        target_suffix = self._get_target_suffix(self.mode, target_type)

        for city in os.listdir(self.images_dir):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            pred_dir = os.path.join(self.predictions_dir, city)
            for file_name in os.listdir(img_dir):
                target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0], target_suffix)
                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(os.path.join(target_dir, target_name))
                self.predictions.append(os.path.join(pred_dir, file_name.replace("_leftImg8bit", "")))

    def __getitem__(self, index):
        """Return ``(image, target)``; ``target`` is None outside train/val."""
        image = Image.open(self.images[index]).convert('RGB')
        if self.split in ['train', 'val']:
            # Left un-converted: the returned image object is handed on as-is.
            target = Image.open(self.targets[index])
        else:
            target = None
        if self.transform is not None:
            image, target = self.transform(image, target)
        return image, target

    def __len__(self):
        """Return the number of images found below ``images_dir``."""
        return len(self.images)
Target encode:
def encode_target(target, pareto_alpha, num_classes, ignore_train_ind, ood_ind=254):
    """Encode a label tensor one-hot, with an "all-hot" vector for OoD pixels.

    In-distribution pixels get a one-hot vector scaled to ``1 - pareto_alpha``,
    OoD pixels get the uniform vector ``pareto_alpha / num_classes`` in every
    class channel, and ignored pixels get an all-zero vector.

    The input tensor is left untouched: ``Tensor.numpy()`` shares memory with
    the tensor, so all relabelling is done on a copy.

    :param target: integer label tensor with three dimensions (presumably
        batch x height x width -- confirm against the caller)
    :param pareto_alpha: OoD loss weight
    :param num_classes: number of classes in original task
    :param ignore_train_ind: void class label(s) in original task
    :param ood_ind: class label corresponding to OoD class
    :return: float64 tensor with the class axis moved to position 1
    """
    original = target.numpy()
    # Copy before relabelling so the caller's tensor is not modified in place.
    labels = original.copy()
    labels[np.isin(labels, ood_ind)] = num_classes  # OoD -> extra index num_classes
    labels[np.isin(labels, ignore_train_ind)] = num_classes + 1  # void -> extra index num_classes + 1

    # One-hot over num_classes + 2 entries, then drop the two helper columns,
    # which leaves all-zero rows for OoD and void pixels.
    enc = np.eye(num_classes + 2)[labels][..., :-2]
    # OoD pixels: uniform "all-hot" vector.
    enc[labels == num_classes] = np.full(num_classes, pareto_alpha / num_classes)
    # In-distribution pixels: convex combination between in- and out-distribution weight.
    enc[enc == 1] = 1 - pareto_alpha
    # Pixels that were void in the original labels contribute nothing.
    enc[np.isin(original, ignore_train_ind)] = np.zeros(num_classes)

    enc = torch.from_numpy(enc)
    return enc.permute(0, 3, 1, 2).contiguous()

