第11篇 Fast AI深度学习课程——机器翻译

2024-02-27 00:32

本文主要是介绍第11篇 Fast AI深度学习课程——机器翻译,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在上节课程中,我们使用语言模型对IMDB影评进行了情感分析。对于语言模型而言,使用的神经网络是一个seq2seq的网络,即输入和输出均为序列;每输入一个单词,就需输出一个单词,因此输入输出的序列长度是一致的。对于影评分析,是一个由字词序列得到单一分类结果的网络,即为seq2one的网络。本节将介绍由法语到英语的机器翻译,该类型网络也是seq2seq,但与语言模型不同之处在于,其在读入整个字符序列后,再输出另一个字符序列,两个序列长度可不一致,而序列之间的字词也没有一一对应的关系。(关于RNN的分类情况,可参见CS231n的相关内容。)

本节秉承了本系列课程自顶而下的学习思路,在前一节的基础上(前一节主要还是在FastAI的代码基础上进行网络的构建),将从底层开始实现用于机器学习的网络。

一、数据

1. 构建词库

本课所用数据为某网站的法语版和英语版的文章,运行如下命令进行下载:

wget http://www.statmt.org/wmt10/training-giga-fren.tar
tar -xvf training-giga-fren.tar
gunzip giga-fren.release2.fixed.en.gz
gunzip giga-fren.release2.fixed.fr.gz

数据压缩包为2.5G,可能得下个把小时。为简化问题,我们将在数据集的问句集合上进行讨论,具体而言,就是英文句库中以whatwherewhichwhen开头的语句。筛选后所得语句大致为52000条。

对语句进行分词。在对法语进行分词前,可能需要下载spacy的法语支持数据包:

python -m spacy download fr

分词代码如下:

en_tok = Tokenizer.proc_all_mp(partition_by_cores(en_qs))
fr_tok = Tokenizer.proc_all_mp(partition_by_cores(fr_qs), 'fr')

分词后,所得英文序列的90%23个词以内,发文序列90%在28个词以内。接下来构建词库。注意按照词频进行筛选,并补充特殊字词:_bos__pad__unk__eos_

得到词库后,将字词数字化。这一步的转化是用fasttext包的词向量实现的。转化后,每个字词使用300维的向量标识。词向量下载:

wget https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.fr.zip
wget https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.en.zip

得下好长时间。

2. 构建数据模型

首先构建适合seq2seq网络的Dataset。由上节课已知,Dataset实际为一个索引类,其只需实现__getitem__()__len__()函数:

class Seq2SeqDataset(Dataset):def __init__(self, x, y): self.x,self.y = x,ydef __getitem__(self, idx): return A(self.x[idx], self.y[idx])def __len__(self): return len(self.x)
np.random.seed(42)
trn_keep = np.random.rand(len(en_ids_tr))>0.1
en_trn,fr_trn = en_ids_tr[trn_keep],fr_ids_tr[trn_keep]
en_val,fr_val = en_ids_tr[~trn_keep],fr_ids_tr[~trn_keep]
trn_ds = Seq2SeqDataset(fr_trn,en_trn)
val_ds = Seq2SeqDataset(fr_val,en_val)

然后由Dataset构建数据加载器Dataloader,这一部分与上一节大致相同。不同之处在于对文本序列进行补齐时,本例中是在序列末尾补齐,而分类网络是在序列开头补齐。这里的直观理解是:在分类网络里,对于一个批次中的最长文本,那么在读完文本后再做判定是合适的;而对短文本,如果在末端补齐,则填充的无意义字符会极大影响分类结果。在翻译网络中,我们只关心句子结束符之前的内容,这一部分要尽量减少填充字符的影响,因此在句子末尾补齐是合适的。

bs=125
trn_samp = SortishSampler(en_trn, key=lambda x: len(en_trn[x]), bs=bs)
val_samp = SortSampler(en_val, key=lambda x: len(en_val[x]))
trn_dl = DataLoader(trn_ds, bs, transpose=True, transpose_y=True, num_workers=1, pad_idx=1, pre_pad=False, sampler=trn_samp)
val_dl = DataLoader(val_ds, int(bs*1.6), transpose=True, transpose_y=True, num_workers=1, pad_idx=1, pre_pad=False, sampler=val_samp)

由数据加载器构建ModelData。事实上,Model Data就是整合训练集、验证集、可选的测试集,并提供可用于临时存储的路径。

md = ModelData(PATH, trn_dl, val_dl)

二、网络架构

翻译网络的结构如下图所示。整个流程为:将一种语言的语句通过一个Encoder网络,获得最终的一个表征语句句法结构等特征的隐藏状态向量,以之为下一个Decoder网络的初始隐藏状态,并以_bos_为初始输入,按照训练语言模型时的方式,一步一词地生成另一语言的完整语句。

图 1. 翻译网络的架构
class Seq2SeqRNN(nn.Module):def __init__(self, vecs_enc, itos_enc, em_sz_enc, vecs_dec, itos_dec, em_sz_dec, nh, out_sl, nl=2):super().__init__()self.nl,self.nh,self.out_sl = nl,nh,out_slself.emb_enc = create_emb(vecs_enc, itos_enc, em_sz_enc)self.emb_enc_drop = nn.Dropout(0.15)self.gru_enc = nn.GRU(em_sz_enc, nh, num_layers=nl, dropout=0.25)self.out_enc = nn.Linear(nh, em_sz_dec, bias=False)self.emb_dec = create_emb(vecs_dec, itos_dec, em_sz_dec)self.gru_dec = nn.GRU(em_sz_dec, em_sz_dec, num_layers=nl, dropout=0.1)self.out_drop = nn.Dropout(0.35)self.out = nn.Linear(em_sz_dec, len(itos_dec))self.out.weight.data = self.emb_dec.weight.datadef forward(self, inp):sl,bs = inp.size()h = self.initHidden(bs)emb = self.emb_enc_drop(self.emb_enc(inp))enc_out, h = self.gru_enc(emb, h)h = self.out_enc(h)dec_inp = V(torch.zeros(bs).long())res = []for i in range(self.out_sl):emb = self.emb_dec(dec_inp).unsqueeze(0)outp, h = self.gru_dec(emb, h)outp = self.out(self.out_drop(outp[0]))res.append(outp)dec_inp = V(outp.data.max(1)[1])if (dec_inp==1).all(): breakreturn torch.stack(res)def initHidden(self, bs): return V(torch.zeros(self.nl, bs, self.nh))

注意forward()函数。其中``Decoder的输入dec_inp初始化为0,即bos的索引值;Decoder的初始隐藏状态为Encoder的输出;outp`表示在词库中所有词上的概率。

值得说明的要点如下:

1. Encoder的内嵌矩阵

使用fast.text提供的词向量矩阵作为内嵌矩阵。由于fast.text的词向量矩阵的标准差为0.3,为得到大致满足高斯分布的内嵌矩阵,需要乘以系数3

2. 如何确定目标语句完结

首先统计一个目标语言的最长语句长度。然后以这个长度为终值做循环,直至结束或输出_pad_

三、损失函数

损失函数使用的是交互熵函数。由于生成的翻译语句可能和目标语句长度不一致,所以可能需要做填充。所使用的Pytorchpad函数,其需要六个参数,分别指明了在次序列方向、批索引方向的填充的头尾起始位置以及长度。

def seq2seq_loss(input, target):sl,bs = target.size()sl_in,bs_in,nc = input.size()if sl>sl_in: input = F.pad(input, (0,0,0,0,0,sl-sl_in))input = input[:sl]return F.cross_entropy(input.view(-1,nc), target.view(-1))#, ignore_index=1)

四、一些技巧

1. 双向训练设置

一般设置Encoderbidirectional=True,而不对Decoder做双向设置。这样,网络会同时在输入序列的倒序序列上训练得到相应的隐藏状态。

2. 初始阶段的强制校正

考虑训练初始时,网络对两种语言还未学习到有效信息,此时Decoder的每一步输出的单词都是随机的,从而导致后续输出远偏离于真值。而如果此时强制以正确的目标语句进行Decoder状态的推进,可有效提高网络收敛的速度。(这实际和GAN的策略很接近。)实际应用中,设置pr_force参数,当预测出的词的概率低于pr_force时,就采取强制校正措施。然后逐渐缩小pr_force,减弱强制校正的力度。

在前向传播中加入强制校正还是挺直观的,修改Seq2SeqRNNforward()函数:

    def forward(self, inp, y=None):sl,bs = inp.size()h = self.initHidden(bs)emb = self.emb_enc_drop(self.emb_enc(inp))enc_out, h = self.gru_enc(emb, h)h = self.out_enc(h)dec_inp = V(torch.zeros(bs).long())res = []for i in range(self.out_sl):emb = self.emb_dec(dec_inp).unsqueeze(0)outp, h = self.gru_dec(emb, h)outp = self.out(self.out_drop(outp[0]))res.append(outp)dec_inp = V(outp.data.max(1)[1])if (dec_inp==1).all(): breakif (y is not None) and (random.random()<self.pr_force):if i>=len(y): breakdec_inp = y[i]return torch.stack(res)

注意和pr_force相关的那一行。

那么如何加入使得pr_force逐步减小的机制呢?实际上控制epoch之间的循环的是fit()函数,在其定义中,调用了stepper.step(),该函数实现了模型的前向传播、损失函数的计算、梯度的反向传播等。因此只需定义一个新的stepper,重写其step()函数,实现pr_force的逐步减小即可。

class Seq2SeqStepper(Stepper):def step(self, xs, y, epoch):self.m.pr_force = (10-epoch)*0.1 if epoch<10 else 0return super.step(xs, y, epoch)

在调用learner.fit()时,指明stepper=Seq2SeqStepper

3. 注意力模型

Encoder不仅输出了最后一步的隐藏状态,还保存了前面步骤的隐藏状态。如果能够在输出目标语言的某个字词时,在源语言的语句中找到与之最相关的部分,然后对该相关部分的隐藏状态进行加权求和,并传递到Decoder中,那么Decoder所获取的信息就更全面,应当能够改善翻译效果。而这种加权信息,可以通过一个小型网络得到。

def forward(self, inp, y=None, ret_attn=False):sl,bs = inp.size()h = self.initHidden(bs)emb = self.emb_enc_drop(self.emb_enc(inp))enc_out, h = self.gru_enc(emb, h)h = self.out_enc(h)dec_inp = V(torch.zeros(bs).long())res,attns = [],[]w1e = enc_out @ self.W1for i in range(self.out_sl):w2h = self.l2(h[-1])u = F.tanh(w1e + w2h)a = F.softmax(u @ self.V, 0)attns.append(a)Xa = (a.unsqueeze(2) * enc_out).sum(0)emb = self.emb_dec(dec_inp)wgt_enc = self.l3(torch.cat([emb, Xa], 1))outp, h = self.gru_dec(wgt_enc.unsqueeze(0), h)outp = self.out(self.out_drop(outp[0]))res.append(outp)dec_inp = V(outp.data.max(1)[1])if (dec_inp==1).all(): breakif (y is not None) and (random.random()<self.pr_force):if i>=len(y): breakdec_inp = y[i]res = torch.stack(res)if ret_attn: res = res,torch.stack(attns)return res

五、更广泛的应用实例

附注

  • 若一个python代码包的git库中,包含setup.pyrequirements.txt,那么可通过如下命令进行安装:pip install git+https://github.com/facebookresearch/fastText.git
  • 一个小技巧:对于网络,可以使用to_gpu()函数替代model.cuda()方法,这样在没有GPU时,会自动使用CPU进行计算。在调试时,可通过设置fastai.core.GPUFalse,以提供方便。

一些有用的链接

  • 课程wiki: 本节课程的一些相关资源,包括课程笔记、课上提到的博客地址等。

  • 注意力模型在机器翻译中的应用: 首次引入注意力模型的论文。

  • 注意力模型的博客: 博客很有意思,还支持用户交互。

这篇关于第11篇 Fast AI深度学习课程——机器翻译的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/750660

相关文章

SpringBoot整合DeepSeek实现AI对话功能

《SpringBoot整合DeepSeek实现AI对话功能》本文介绍了如何在SpringBoot项目中整合DeepSeekAPI和本地私有化部署DeepSeekR1模型,通过SpringAI框架简化了... 目录Spring AI版本依赖整合DeepSeek API key整合本地化部署的DeepSeek

PyCharm接入DeepSeek实现AI编程的操作流程

《PyCharm接入DeepSeek实现AI编程的操作流程》DeepSeek是一家专注于人工智能技术研发的公司,致力于开发高性能、低成本的AI模型,接下来,我们把DeepSeek接入到PyCharm中... 目录引言效果演示创建API key在PyCharm中下载Continue插件配置Continue引言

Go中sync.Once源码的深度讲解

《Go中sync.Once源码的深度讲解》sync.Once是Go语言标准库中的一个同步原语,用于确保某个操作只执行一次,本文将从源码出发为大家详细介绍一下sync.Once的具体使用,x希望对大家有... 目录概念简单示例源码解读总结概念sync.Once是Go语言标准库中的一个同步原语,用于确保某个操

Ubuntu系统怎么安装Warp? 新一代AI 终端神器安装使用方法

《Ubuntu系统怎么安装Warp?新一代AI终端神器安装使用方法》Warp是一款使用Rust开发的现代化AI终端工具,该怎么再Ubuntu系统中安装使用呢?下面我们就来看看详细教程... Warp Terminal 是一款使用 Rust 开发的现代化「AI 终端」工具。最初它只支持 MACOS,但在 20

五大特性引领创新! 深度操作系统 deepin 25 Preview预览版发布

《五大特性引领创新!深度操作系统deepin25Preview预览版发布》今日,深度操作系统正式推出deepin25Preview版本,该版本集成了五大核心特性:磐石系统、全新DDE、Tr... 深度操作系统今日发布了 deepin 25 Preview,新版本囊括五大特性:磐石系统、全新 DDE、Tree

Node.js 中 http 模块的深度剖析与实战应用小结

《Node.js中http模块的深度剖析与实战应用小结》本文详细介绍了Node.js中的http模块,从创建HTTP服务器、处理请求与响应,到获取请求参数,每个环节都通过代码示例进行解析,旨在帮... 目录Node.js 中 http 模块的深度剖析与实战应用一、引言二、创建 HTTP 服务器:基石搭建(一

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

AI绘图怎么变现?想做点副业的小白必看!

在科技飞速发展的今天,AI绘图作为一种新兴技术,不仅改变了艺术创作的方式,也为创作者提供了多种变现途径。本文将详细探讨几种常见的AI绘图变现方式,帮助创作者更好地利用这一技术实现经济收益。 更多实操教程和AI绘画工具,可以扫描下方,免费获取 定制服务:个性化的创意商机 个性化定制 AI绘图技术能够根据用户需求生成个性化的头像、壁纸、插画等作品。例如,姓氏头像在电商平台上非常受欢迎,

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06