本文主要是介绍Faster RCNN 推理 从头写 java (二) RPN网络预测,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
- 1. 图片预处理
2. RPN网络预测
- 3. RPN to ROIs
- 4. Classifier 网络预测
- 5. Classifier网络输出对 ROIs过滤与修正
- 6. NMS (非最大值抑制)
- 7. 坐标转换为原始图片维度
一: 输入输出
输入:
- omg: 经过预处理过的图像, shape为 [1, 600, 800, 3].
输出:
- cls: 每个anchor在pixel上的概率, shape为 [1, 37, 50, 49].
- reg: 每个anchor在pixel上的回归值, shape 为 [1, 37, 50, 196].
- feature: 经过VGG16后的feature map, shape 为 [1, 37, 50, 512].
二: 流程
- 图片BGR 格式转换为 RGB 格式。
- 图片缩放。
- 图片均值中值化。
三: code by code
img 转换为tensorflow 的 Tensor
Tensor<Float> input = TypeConvertor.ndarrayToTensor(img);
预测
List<Tensor<?>> output = this.session.runner().feed(INPUT_NAME, input).fetch(OUTPUT_CLS_NAME).fetch(OUTPUT_REG_NAME).fetch(OUTPUT_FEATURE_MAP_NAME).run();
构建输出
0: cls
1: reg
3: feature
return new FasterRCnnRPN_Output(output.get(0), output.get(1), output.get(2));
这篇关于Faster RCNN 推理 从头写 java (二) RPN网络预测的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!