本文主要是介绍attention-is-all-you-need-pytorch 源码阅读,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
文章目录
- 训练数据流
- train.train_epoch
- Transformer
- Encoder
- EncoderLayer
- MultiHeadAttention
- ScaledDotProductAttention
- PositionwiseFeedForward
- Decoder
- DecoderLayer
- Predict
- TODO
训练数据流
train.train_epoch
对training_data
进行迭代, 产生batch
, 其中有src_seq
, trg_seq
src_seq.shape
Out[11]: torch.Size([256, 32])
src_seq
Out[12]:
tensor([[ 2, 4567, 4578, ..., 1, 1, 1],[ 2, 4558, 4565, ..., 1, 1, 1],[ 2, 4558, 4565, ..., 1, 1, 1],...,[ 2, 4558, 64, ..., 1, 1, 1],[ 2, 4564, 5051, ..., 1, 1, 1],[ 2, 4567, 4578, ..., 1, 1, 1]])
2
是开始, 1
是结束, 32
是句子长度, 256
是batch数
Transformer
transformer.Models.Transformer.forward
现在数据被丢进了Transformer
这个模型
src_mask = get_pad_mask(src_seq, self.src_pad_idx)
trg_mask = get_pad_mask(trg_seq, self.trg_pad_idx) & get_subsequent_mask(trg_seq)enc_output, *_ = self.encoder(src_seq, src_mask)
src_mask
会在ScaledDotProductAttention用到
编码了一波, enc_output
其实与输入数据的size一样
enc_output.shape
Out[9]: torch.Size([256, 36, 512])
去看Decoder
Encoder
transformer.Models.Encoder
# 9521 512 1
self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx)
what-does-padding-idx-do-in-nn-embeddings
transformer.Models.Encoder.forward
这篇关于attention-is-all-you-need-pytorch 源码阅读的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!