关键词短语生成的无监督方法12——Train.py

2023-11-07 23:50

本文主要是介绍关键词短语生成的无监督方法12——Train.py,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

2021SC@SDUSC

文章目录

    • 一、model.train()与model.eval()
    • 二、训练模型
    • 三、总结

一、model.train()与model.eval()

由于训练模型和评估模型由model.train()和model.eval()两个函数支撑,故我首先对它们展开学习与分析。

model.train()和model.eval()的区别主要在于Batch Normalization和Dropout两层。

model.train()
在这里插入图片描述
启用Batch Normalization和Dropout。
如果模型中有BN层(Batch Normalization和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。

model.eval()
在这里插入图片描述
不启用 Batch Normalization和Dropout。
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。

训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。
在做one classification的时候,训练集和测试集的样本分布是不一样的,尤其需要注意这一点。

model.train()/modeal.eval():

    def train(self, mode=True):self.training = modefor module in self.children():module.train(mode)return selfdef eval(self):return self.train(False)

需要记住当前的self.training的值是True还是False。

以Dropout为例,进入其对应的源代码,下方对应的self.training就是第一步中的self.training,原因在于Dropout继承了 _DropoutNd类,而 _DropoutNd由继承了Module类,Module类中自带变量self.training,通过这种方法,来控制train/eval模型下是否进行Dropout。

class Dropout(_DropoutNd):@weak_script_methoddef forward(self, input):return F.dropout(input, self.p, self.training, self.inplace)

PyTorch会关注是训练还是评估模型的原因是dropout和BN层。这项技术在训练中随机去除神经元。
在这里插入图片描述
如果右边被删除的神经元是唯一促成正确结果的神经元。一旦移除了被删除的神经元,它就迫使其他神经元训练和学习如何在没有被删除神经元的情况下保持准确。这种dropout提高了最终测试的性能,但它对训练期间的性能产生了负面影响,因为网络是不全的。

二、训练模型

在分析了model.train()和model.eval()后,对train()函数展开分析。

#训练模型
def train(iterator):#启用batch normalization和drop outmodel.train()#定义损失值epoch_loss = 0#声明计数器cnt=0m = 0for i,(src,trg) in enumerate(iterator):#for i,(x,cls) in enumerate(iterator):src = src.long().permute(1,0).to(device)trg = trg.long().permute(1,0).to(device)

代码分段解析: 此段代码启用batch normalization和drop out并定义函数损失值epoch_loss,通过枚举迭代器,初始化src、trg,存储数据集。

		#置零梯度optimizer.zero_grad()#前向传播求出预测值prediction和隐藏值hiddenoutput = model.forward(src, trg)output_dim = output.shape[-1]    output = output[1:].view(-1, output_dim)trg = trg[1:].reshape(5*trg.shape[1])#求lossloss = criterion(output, trg)  #反向传播求梯度loss.backward()#更新所有参数optimizer.step()

代码分段解析: 此段代码为训练神经网络的基本过程,即zero_grad+forward+loss+backward+step

  1. zaro_grad:算一个batch计算一次梯度,然后进行一次梯度更新。进行下一次batch梯度计算的时候,前一个batch的梯度计算结果没有保留的必要。所以在下一次梯度更新的时候,先使用optimizer.zero_grad把梯度信息设置为0。
  2. loss:我们使用loss来定义损失函数,是要确定优化的目标是什么,然后以目标为头,才可以进行链式法则和反向传播。
  3. backward:调用loss.backward方法时候,Pytorch的autograd就会自动沿着计算图反向传播,计算每一个叶子节点的梯度(如果某一个变量是由用户创建的,则它为叶子节点)。使用该方法,可以计算链式法则求导之后计算的结果值。
  4. optimizer.step:用来更新参数。

总的来说,此段代码对每个batch执行梯度下降的操作。首先将梯度初始化为零,其次通过model.forward()前向传播求出预测值,接着求loss值,反向传播求梯度,最后通过优化器optimizer更新所有参数。

    #求loss的平均值,用item取出唯一的元素epoch_loss += loss.item()#释放显存torch.cuda.empty_cache()#调整学习率scheduler.step()return epoch_loss / len(iterator)

代码分段解析:
scheduler.step()每调用step_size次,对应的学习率就会按照策略调整一次。
在此,需注意optimizer.step()和scheduler.step()的区别。
optimizer.step()通常用在每个mini-batch之中,而scheduler.step()通常用在epoch里面,但是不绝对,可以根据具体的需求来做。只有用了optimizer.step(),模型才会更新,而scheduler.step()是对lr进行调整。通常有

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size = 100, gamma = 0.1)
model = net.train(model, loss_function, optimizer, scheduler, num_epochs = 100)

在scheduler的step_size表示scheduler.step()每调用step_size次,对应的学习率就会按照策略调整一次。所以如果scheduler.step()是放在mini-batch里面,那么step_size指的是经过这么多次迭代,学习率改变一次。

三、总结

本周对训练模型实现和评估模型实现中的关键函数model.train()、model.eval()展开了学习与分析,并分析了训练模型函数train()。下周将对评估模型函数eval()展开分析。

这篇关于关键词短语生成的无监督方法12——Train.py的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/366825

相关文章

Linux换行符的使用方法详解

《Linux换行符的使用方法详解》本文介绍了Linux中常用的换行符LF及其在文件中的表示,展示了如何使用sed命令替换换行符,并列举了与换行符处理相关的Linux命令,通过代码讲解的非常详细,需要的... 目录简介检测文件中的换行符使用 cat -A 查看换行符使用 od -c 检查字符换行符格式转换将

Java编译生成多个.class文件的原理和作用

《Java编译生成多个.class文件的原理和作用》作为一名经验丰富的开发者,在Java项目中执行编译后,可能会发现一个.java源文件有时会产生多个.class文件,从技术实现层面详细剖析这一现象... 目录一、内部类机制与.class文件生成成员内部类(常规内部类)局部内部类(方法内部类)匿名内部类二、

SpringBoot实现数据库读写分离的3种方法小结

《SpringBoot实现数据库读写分离的3种方法小结》为了提高系统的读写性能和可用性,读写分离是一种经典的数据库架构模式,在SpringBoot应用中,有多种方式可以实现数据库读写分离,本文将介绍三... 目录一、数据库读写分离概述二、方案一:基于AbstractRoutingDataSource实现动态

使用Jackson进行JSON生成与解析的新手指南

《使用Jackson进行JSON生成与解析的新手指南》这篇文章主要为大家详细介绍了如何使用Jackson进行JSON生成与解析处理,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. 核心依赖2. 基础用法2.1 对象转 jsON(序列化)2.2 JSON 转对象(反序列化)3.

Java中的String.valueOf()和toString()方法区别小结

《Java中的String.valueOf()和toString()方法区别小结》字符串操作是开发者日常编程任务中不可或缺的一部分,转换为字符串是一种常见需求,其中最常见的就是String.value... 目录String.valueOf()方法方法定义方法实现使用示例使用场景toString()方法方法

Java中List的contains()方法的使用小结

《Java中List的contains()方法的使用小结》List的contains()方法用于检查列表中是否包含指定的元素,借助equals()方法进行判断,下面就来介绍Java中List的c... 目录详细展开1. 方法签名2. 工作原理3. 使用示例4. 注意事项总结结论:List 的 contain

macOS无效Launchpad图标轻松删除的4 种实用方法

《macOS无效Launchpad图标轻松删除的4种实用方法》mac中不在appstore上下载的应用经常在删除后它的图标还残留在launchpad中,并且长按图标也不会出现删除符号,下面解决这个问... 在 MACOS 上,Launchpad(也就是「启动台」)是一个便捷的 App 启动工具。但有时候,应

SpringBoot日志配置SLF4J和Logback的方法实现

《SpringBoot日志配置SLF4J和Logback的方法实现》日志记录是不可或缺的一部分,本文主要介绍了SpringBoot日志配置SLF4J和Logback的方法实现,文中通过示例代码介绍的非... 目录一、前言二、案例一:初识日志三、案例二:使用Lombok输出日志四、案例三:配置Logback一

Python实现无痛修改第三方库源码的方法详解

《Python实现无痛修改第三方库源码的方法详解》很多时候,我们下载的第三方库是不会有需求不满足的情况,但也有极少的情况,第三方库没有兼顾到需求,本文将介绍几个修改源码的操作,大家可以根据需求进行选择... 目录需求不符合模拟示例 1. 修改源文件2. 继承修改3. 猴子补丁4. 追踪局部变量需求不符合很

java中使用POI生成Excel并导出过程

《java中使用POI生成Excel并导出过程》:本文主要介绍java中使用POI生成Excel并导出过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录需求说明及实现方式需求完成通用代码版本1版本2结果展示type参数为atype参数为b总结注:本文章中代码均为