本文主要是介绍用tensorflow中slim下的分类网络训练自己的数据集以及fine-tuning(可以直接实战使用),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
目录
前期准备
训练flower数据集(包括fine-tuning)
训练自己的数据集(包括fine-tuning)
前期准备
前期了解
tensorflow models
在tensorflow models中有官方维护和非官方维护的models,official models就是官方维护的models,里面使用的接口都是一些官方的接口,比如tf.layers.conv2d之类。而research models是tensorflow的研究人员自己实现的一些流行网络,不受官方支持,里面会用到一些slim之类的非官方接口。但是因为research models实现的网络非常多,而且提供了完整的训练和评价方案,所以我们现在基于research models中的实现来部署网络。
环境配置
首先要保证tf.contrib.slim在你的tensorflow环境中是存在的,运行下面的脚本保证没有错误发生。
python -c "import tensorflow.contrib.slim as slim; eval = slim.evaluation.evaluate_once"
base代码准备
TF的库里面没有TF-slim的内容,所以我们需要将代码clone到本地
cd $HOME/workspace
git clone https://github.com/tensorflow/models/
运行以下脚本确定是否可用
cd $HOME/workspace/models/research/slim
python -c "from nets import cifarnet; mynet = cifarnet.cifarnet"
其实我们只需要使用research中的slim的代码,所以我是直接拷贝了slim的代码到本地,基于slim代码进行修改。
训练flower数据集
下载数据并创建tfrecord
官网提供了下载并且转换数据集的方法,运行如下脚本即可,脚本会直接下载flower数据集并且存储为TFRecord的格式。
$ python download_and_convert_data.py \--dataset_name=flowers \--dataset_dir=./tmp/data/flowers
为何官网要使用TFRecord呢?因为TFRecord和tensorflow内部有一个加速机制。实际读取tfrecord数据时,先以相应的tfrecord文件为参数,创建一个输入队列,这个队列有一定的容量,在一部分数据出队列时,tfrecord中的其他数据就可以通过预取进入队列,这个过程和网络的计算是独立进行的。也就是说,网络每一个iteration的训练不必等待数据队列准备好再开始,队列中的数据始终是充足的,而往队列中填充数据时,也可以使用多线程加速。
下载pre-trained checkpoint
每个网络对应的checkpoint可以从官网上找到,官网也提供了下载inception v3的checkpoint的例子
$ mkdir ./tmp/checkpoints
$ wget http
这篇关于用tensorflow中slim下的分类网络训练自己的数据集以及fine-tuning(可以直接实战使用)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!