本文主要是介绍如何用tensorflow中的object detection API来训练自己的数据集上的目标检测器,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
这仍然是一个学习记录博~欢迎讨论
安装tensorflow object detection API
首先当然是要先安装tensorflow object detection API啦,网上有很多教程可以参考,比如:Ubuntu 16.04下安装TensorFlow Object Detection API(对象检测API)
P.S. 安装时需要注意的就是tensorflow-gpu的版本问题,你的tensorflow和tensorflow-gpu以及tensorflow API的版本都要一致,如果tensorflow-gpu版本是1.x的,那就不能安装2.x的API啦,有很多版本问题,使用的时候会报错的,比如:
ModuleNotFoundError: No module named 'tensorflow.contrib'; 'tensorflow' is not a package
我的tensorflow-gpu是1.13.1,我的tensorflow也是1.13.1,而我安装的API的版本是1.13.0~
基于浣熊的object detection
这个部分可以参考:用TensorFlow训练一个物体检测器(手把手教学版),写的真的蛮详细的。
不过文章里还是有点小问题的,所以在这里补充说明一下:
- train.txt和test.txt文件的生成代码,博客里没有给全代码,下面给大家一段直接可以运行的代码来参考(路径自己改一下,运行一次就好了,不要重复运行,会一直追加写入的)
import os
import randomtrain_txt = open("xx/xx/dataset/train.txt", 'a+')
test_txt = open("xx/xx/dataset/test.txt", 'a+')pt = "xx/xx/xx/raccoon_dataset/images"
image_name = os.listdir(pt)count = 0
for temp in image_name:if temp.endswith(".jpg"):print(temp.replace('.jpg', ''))if count < 160:train_txt.write(temp.replace('.jpg', '')+'\n')else:test_txt.write(temp.replace('.jpg', '') + '\n')count += 1
- 训练之前不要忘记添加环境变量
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
这里的pwd就是当前路径,一般我们都是切换到models/research/object_detection目录下来运行的,所以这个pwd就是object_detection所在的路径。
- 训练所用的文件train.py,eval.py不在object_detection目录下,而在object_detection/legacy目录下,可能是版本变了,所以把train.py什么的挪了位置。
使用model_train.py代替train.py
如果不想使用老版本的train.py的话,可以用object_detection下的model_train.py进行训练,它把老版本的train.py和eval.py集合到了一起,model_train.py的源码解释可以参考一下:Tensorflow Object Detection API 源码分析之 model_main.py。
在终端输入:
python model_main.py \--pipeline_config_path=config文件所在路径 \--model_dir=模型数据输出的保存路径 \--num_train_steps=60000 \--num_eval_steps=20 \--alsologtostderr
不出意外的话就可以正常训练了,训练之后可以在model_dir文件夹中看到保存下来的模型。
注:接下来冻结模型的导出以及模型的测试都还是和用TensorFlow训练一个物体检测器(手把手教学版)中介绍的一样。
这篇关于如何用tensorflow中的object detection API来训练自己的数据集上的目标检测器的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!