本文主要是介绍Tensorflow nmt的数据预处理过程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
tensorflow nmt的数据预处理过程
在tensorflow/nmt项目中,训练数据和推断数据的输入使用了新的Dataset API,应该是tensorflow 1.2之后引入的API,方便数据的操作。如果你还在使用老的Queue和Coordinator的方式,建议升级高版本的tensorflow并且使用Dataset API。
本教程将从训练数据和推断数据两个方面,详解解析数据的具体处理过程,你将看到文本数据如何转化为模型所需要的实数,以及中间的张量的维度是怎么样的,batch_size和其他超参数又是如何作用的。
训练数据的处理
先来看看训练数据的处理。训练数据的处理比推断数据的处理稍微复杂一些,弄懂了训练数据的处理过程,就可以很轻松地理解推断数据的处理。
训练数据的处理代码位于nmt/utils/iterator_utils.py文件内的get_iterator
函数。我们先来看看这个函数所需要的参数是什么意思:
参数 | 解释 |
---|---|
src_dataset | 源数据集 |
tgt_dataset | 目标数据集 |
src_vocab_table | 源数据单词查找表,就是个单词和int类型数据的对应表 |
tgt_vocab_table | 目标数据单词查找表,就是个单词和int类型数据的对应表 |
batch_size | 批大小 |
sos | 句子开始标记 |
eos | 句子结尾标记 |
random_seed | 随机种子,用来打乱数据集的 |
num_buckets | 桶数量 |
src_max_len | 源数据最大长度 |
tgt_max_len | 目标数据最大长度 |
这篇关于Tensorflow nmt的数据预处理过程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!