本文主要是介绍【代码】CenterNet使用(续)(对五六七部分详解)(六),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
接上面部分,对五六七部分进行详解,这篇介绍第六部分。
一、回顾
第六部分对得到得dets进行后处理:
dets = self.post_process(dets, meta, scale)torch.cuda.synchronize()post_process_time = time.time()post_time += post_process_time - decode_timedetections.append(dets)
post_process在ctdet.py中出现:
def post_process(self, dets, meta, scale=1):dets = dets.detach().cpu().numpy()dets = dets.reshape(1, -1, dets.shape[2])dets = ctdet_post_process(dets.copy(), [meta['c']], [meta['s']],meta['out_height'], meta['out_width'], self.opt.num_classes)for j in range(1, self.num_classes + 1):dets[0][j] = np.array(dets[0][j], dtype=np.float32).reshape(-1, 5)dets[0][j][:, :4] /= scalereturn dets[0]
使用ctdet_post_process进行后处理,
最后得到得dets是一个len=80的张量(列表?)
其中每个元素是一个N * 5的ndarray。猜测是每个类别中的dets。detections是一个包含dets的列表。
二、详解
主要是ctdet_post_process部分
输入是这些:
ctdet_post_process来源于utils/post_process,代码和解释如下:
def ctdet_post_process(dets, c, s, h, w, num_classes):# dets: batch x max_dets x dim# return 1-based class det dictret = []# 对于每个batchfor i in range(dets.shape[0]):top_preds = {}# 输入dets的左上角,中心点,最长边,(128, 128):heatmap的长宽# 得到变换后的dets,怎么变换的还未知dets[i, :, :2] = transform_preds(dets[i, :, 0:2], c[i], s[i], (w, h))# 输入dets的右上角,。。。同理dets[i, :, 2:4] = transform_preds(dets[i, :, 2:4], c[i], s[i], (w, h))classes = dets[i, :, -1]# 将第j个类的结果放到top_preds[j+1]中,top_preds是一个dictfor j in range(num_classes):inds = (classes == j)top_preds[j + 1] = np.concatenate([dets[i, inds, :4].astype(np.float32),dets[i, inds, 4:5].astype(np.float32)], axis=1).tolist()# 将一张图片的结果放到ret中ret.append(top_preds)return ret
这篇关于【代码】CenterNet使用(续)(对五六七部分详解)(六)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!