本文主要是介绍seq2seq架构略解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
用于序列翻译任务(下图来自d2l)
训练时输入输出格式:
若数据集为{ <(a1,a2,a3,a4,a5),(b1,b2,b3,b4,b5)> }(AB语言对应的句子组)
输入
A语言的单词序列+结束符(a1,a2,a3,a4,a5,<eos>)
开始符+B语言的单词序列(<bos>,b1,b2,b3,b4,b5)
输出
B语言的单词序列(b1,b2,b3,b4,b5,<eos>)
预测时输入格式:
A语言的单词序列+结束符(a1,a2,a3,a4,a5,<eos>)
开始符+空(<bos>,,,,,)
编码器、解码器两部分使用的暂时还是RNN
编码器RNN
隐藏层使用普通初始化
输入:
A语言序列单词的独热向量+<eos>的独热向量
输出:
特征向量序列(但后续并没有使用)+隐藏层参数H1
解码器RNN
使用编码器计算结束之后的隐藏层H1,初始化隐藏层
输入:
训练时,输入<bos>+B语言序列单词的独热向量 + H1(直接拼接)
预测时,输入<bos>,然后将当前RNN预测结果作为下一次预测的输入。
输出:
B语言单词序列的独热向量+<eos>的独热向量
训练时seq2seq的计算图:
预测时seq2seq的计算图:
代码d2l官网十分详细了,就不再赘述了。
另外一个小插曲
实际上d2l官方给出的代码实现的解码器架构有一些小问题
在预测时,在第一步预测完毕之后,使用的dec_state会继承解码器RNN的隐藏层状态,而不是保持编码器所获取的隐藏层H1的信息
在评论区里面已经有大佬给出了正确的代码实现,可以围观。
这篇关于seq2seq架构略解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!