本文主要是介绍Faster RCNN 推理 从头写 java (四) Classifier 网络预测,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
- 1. 图片预处理
- 2. RPN网络预测
- 3. RPN to ROIs
4. Classifier 网络预测
- 5. Classifier网络输出对 ROIs过滤与修正
- 6. NMS (非最大值抑制)
- 7. 坐标转换为原始图片维度
一: 输入输出
输入:
- ROIs: RPN to ROI 后 没32个为一组的ROIs, shape为 [1, 32, 4]
- feature: RPN 层的输出, 也就是VGG16的feature map, shape 为 [1, 37, 50, 512]
输出:
- P_cls: 每个ROI的概率 shape为 [1, 32, 2]
- P_regr: 每个ROI的回归值, shape 为 [1, 37, 50, 4]
二: 流程
- 预测
三: code by code
ROIs, feature 转换为tensorflow 的 Tensor
if (featureMap.dataType() != DataType.FLOAT) featureMap = featureMap.castTo(DataType.FLOAT);
Tensor<Float> feature_input = TypeConvertor.ndarrayToTensor(featureMap);if (ROIs.dataType() != DataType.FLOAT) ROIs = ROIs.castTo(DataType.FLOAT);
Tensor<Float> ROIs_input = TypeConvertor.ndarrayToTensor(ROIs);
Classifier 网络模型预测
List<Tensor<?>> output = this.session.runner().feed(INPUT_FEATURE_NAME, feature_input).feed(INPUT_ROI_NAME, ROIs_input).fetch(OUTPUT_CLS_NAME).fetch(OUTPUT_REG_NAME).run();
构建输出
0: P_cls
1: P_regr
return new FasterRCnnClassifier_Output(output.get(0), output.get(1));
这篇关于Faster RCNN 推理 从头写 java (四) Classifier 网络预测的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!