自然语言处理(扩展学习1):Scheduled Sampling(计划采样)与Teacher forcing(教师强制)

本文主要是介绍自然语言处理(扩展学习1):Scheduled Sampling(计划采样)与Teacher forcing(教师强制),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

自然语言处理(扩展学习1):Scheduled Sampling(计划采样)与2. Teacher forcing(教师强制)


作者:安静到无声 个人主页

作者简介:人工智能和硬件设计博士生、CSDN与阿里云开发者博客专家,多项比赛获奖者,发表SCI论文多篇。

Thanks♪(・ω・)ノ 如果觉得文章不错或能帮助到你学习,可以点赞👍收藏📁评论📒+关注哦! o( ̄▽ ̄)d

欢迎大家来到安静到无声的 《基于pytorch的自然语言处理入门与实践》,如果对所写内容感兴趣请看《基于pytorch的自然语言处理入门与实践》系列讲解 - 总目录,同时这也可以作为大家学习的参考。欢迎订阅,请多多支持!

目录标题

  • 自然语言处理(扩展学习1):Scheduled Sampling(计划采样)与2. Teacher forcing(教师强制)
  • 1. Scheduled Sampling(计划采样)
    • 1.1 概念解释
    • 1.2 代码实现
  • 2. Teacher forcing(教师强制)
    • 2.1 概念解释
    • 2.1 代码实现
  • 参考

1. Scheduled Sampling(计划采样)

1.1 概念解释

Scheduled Sampling是一种用于训练序列生成模型的策略,旨在缓解曝光偏差(Exposure Bias)问题。曝光偏差是指模型在训练时接触到的数据分布与测试时的数据分布不一致,导致性能下降。

在Scheduled Sampling中,模型在每个时间步骤都有一定的概率选择使用真实目标序列中的单词作为输入,而不是使用前一个时间步骤生成的单词。这样可以使模型更好地适应真实数据分布,减少曝光偏差问题。

具体来说,Scheduled Sampling使用以下公式计算每个时间步骤生成当前单词的概率:

P ( y t ∣ y 1 , . . . , y t − 1 ) = ( 1 − ϵ ) ∗ P model ( y t ∣ y 1 , . . . , y t − 1 ) + ϵ ∗ P data ( y t ∣ y 1 , . . . , y t − 1 ) P(y_t|y_1, ..., y_{t-1}) = (1 - \epsilon) * P_{\text{model}}(y_t|y_1, ..., y_{t-1}) + \epsilon * P_{\text{data}}(y_t|y_1, ..., y_{t-1}) P(yty1,...,yt1)=(1ϵ)Pmodel(yty1,...,yt1)+ϵPdata(yty1,...,yt1)其中, P ( y t ∣ y 1 , . . . , y t − 1 ) P(y_t|y_1, ..., y_{t-1}) P(yty1,...,yt1)表示在给定前面的生成序列条件下生成当前单词 y t y_t yt的概率, P model ( y t ∣ y 1 , . . . , y t − 1 ) P_{\text{model}}(y_t|y_1, ..., y_{t-1}) Pmodel(yty1,...,yt1)表示模型生成该单词的概率, P data ( y t ∣ y 1 , . . . , y t − 1 ) P_{\text{data}}(y_t|y_1, ..., y_{t-1}) Pdata(yty1,...,yt1)表示真实目标序列中该单词的概率。参数 ϵ \epsilon ϵ用于控制采样策略,可以随着训练的进行而逐渐增加。

1.2 代码实现

下面是一个使用Python实现Scheduled Sampling的示例代码:

在这里插入图片描述

其中, P ( y t ∣ y < t , x ) P(y_t | y_{<t}, x) P(yty<t,x)表示在给定前文和输入的条件下,生成当前时间步的输出的概率。 P model P_{\text{model}} Pmodel表示由模型生成的概率分布, P prev P_{\text{prev}} Pprev表示根据上一个时间步的真实输出计算得到的概率分布。sample是从均匀分布中采样得到的一个随机数,threshold是一个控制Scheduled Sampling引入程度的超参数。

2. Teacher forcing(教师强制)

2.1 概念解释

Teacher forcing(教师强制)是一种在序列生成模型中使用的训练技术。具体来说,当使用RNN(循环神经网络)或类似架构的模型进行序列生成时,每个时间步都会根据前一个时间步的输入和隐藏状态生成输出。在训练期间,如果使用teacher forcing,那么每个时间步的输入将是真实的目标序列(而不是模型自身生成的序列)。这意味着模型在每个时间步都能够观察到正确的答案,从而更容易地学习到正确的模式和规律。

2.1 代码实现

import torch
import torch.nn as nn# 定义序列到序列模型
class Seq2SeqModel(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(Seq2SeqModel, self).__init__()self.hidden_dim = hidden_dim# 定义编码器self.encoder = nn.RNN(input_dim, hidden_dim)# 定义解码器self.decoder = nn.RNN(output_dim, hidden_dim)# 定义全连接层,将解码器的输出映射为目标序列self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, input_seq, target_seq):# 编码器计算输入序列的隐藏状态_, hidden_state = self.encoder(input_seq)# 解码器初始化隐藏状态decoder_hidden_state = hidden_state# 用真实目标序列作为输入来指导解码器的生成过程decoder_outputs, _ = self.decoder(target_seq, decoder_hidden_state)# 对解码器的输出应用全连接层进行映射output_seq = self.fc(decoder_outputs)return output_seq# 创建模型实例
input_dim = 10
hidden_dim = 20
output_dim = 10
model = Seq2SeqModel(input_dim, hidden_dim, output_dim)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10
for epoch in range(num_epochs):# 步骤1:将模型设为训练模式model.train()# 步骤2:清零梯度optimizer.zero_grad()# 步骤3:前向传播input_seq = torch.randn(5, 3, input_dim)  # 输入序列target_seq = torch.randn(5, 3, output_dim)  # 目标序列output_seq = model(input_seq, target_seq)# 步骤4:计算损失loss = criterion(output_seq, target_seq)# 步骤5:反向传播和优化loss.backward()optimizer.step()print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

上述代码中,我们定义了一个简单的序列到序列模型Seq2SeqModel,其中包括一个RNN编码器、一个RNN解码器和一个全连接层。在forward方法中,我们首先使用编码器计算输入序列的隐藏状态,然后将隐藏状态作为解码器的初始隐藏状态。接下来,我们使用真实目标序列来指导解码器的生成过程,并将解码器的输出映射为目标序列。在训练阶段,我们使用真实目标序列作为输入来指导模型的生成过程。最后,我们定义了损失函数和优化器,并进行训练。

需要注意的是,在实际应用中,模型的推理阶段并不会使用真实目标序列来指导生成过程。在推理阶段,可以将前一个时间步的模型输出作为下一个时间步的输入,从而进行序列的自我生成。

--------推荐专栏--------
🔥 手把手实现Image captioning
💯CNN模型压缩
💖模式识别与人工智能(程序与算法)
🔥FPGA—Verilog与Hls学习与实践
💯基于Pytorch的自然语言处理入门与实践

参考

Scheduled Sampling的搜索结果_百度图片搜索 (baidu.com)
Teacher forcing RNN的搜索结果_百度图片搜索 (baidu.com)

在这里插入图片描述

这篇关于自然语言处理(扩展学习1):Scheduled Sampling(计划采样)与Teacher forcing(教师强制)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/235192

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

无人叉车3d激光slam多房间建图定位异常处理方案-墙体画线地图切分方案

墙体画线地图切分方案 针对问题:墙体两侧特征混淆误匹配,导致建图和定位偏差,表现为过门跳变、外月台走歪等 ·解决思路:预期的根治方案IGICP需要较长时间完成上线,先使用切分地图的工程化方案,即墙体两侧切分为不同地图,在某一侧只使用该侧地图进行定位 方案思路 切分原理:切分地图基于关键帧位置,而非点云。 理论基础:光照是直线的,一帧点云必定只能照射到墙的一侧,无法同时照到两侧实践考虑:关

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

csu 1446 Problem J Modified LCS (扩展欧几里得算法的简单应用)

这是一道扩展欧几里得算法的简单应用题,这题是在湖南多校训练赛中队友ac的一道题,在比赛之后请教了队友,然后自己把它a掉 这也是自己独自做扩展欧几里得算法的题目 题意:把题意转变下就变成了:求d1*x - d2*y = f2 - f1的解,很明显用exgcd来解 下面介绍一下exgcd的一些知识点:求ax + by = c的解 一、首先求ax + by = gcd(a,b)的解 这个

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]

科研绘图系列:R语言扩展物种堆积图(Extended Stacked Barplot)

介绍 R语言的扩展物种堆积图是一种数据可视化工具,它不仅展示了物种的堆积结果,还整合了不同样本分组之间的差异性分析结果。这种图形表示方法能够直观地比较不同物种在各个分组中的显著性差异,为研究者提供了一种有效的数据解读方式。 加载R包 knitr::opts_chunk$set(warning = F, message = F)library(tidyverse)library(phyl

【机器学习】高斯过程的基本概念和应用领域以及在python中的实例

引言 高斯过程(Gaussian Process,简称GP)是一种概率模型,用于描述一组随机变量的联合概率分布,其中任何一个有限维度的子集都具有高斯分布 文章目录 引言一、高斯过程1.1 基本定义1.1.1 随机过程1.1.2 高斯分布 1.2 高斯过程的特性1.2.1 联合高斯性1.2.2 均值函数1.2.3 协方差函数(或核函数) 1.3 核函数1.4 高斯过程回归(Gauss

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言