使用early stopping解决神经网络过拟合问题

2024-05-26 08:48

本文主要是介绍使用early stopping解决神经网络过拟合问题,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

神经网络训练多少轮是一个很关键的问题,训练轮数少了欠拟合(underfit),训练轮数多了过拟合(overfit),那如何选择训练轮数呢?

Early stopping可以帮助我们解决这个问题,它的作用就是当模型在验证集上的性能不再增加的时候就停止训练,从而达到充分训练的作用,又避免过拟合。

一、在Keras中使用early stopping

完整代码

Keras中有EarlyStopping类,可以直接拿来使用,非常方便

from keras.callbacks import EarlyStoppingearlystop = EarlyStopping(monitor = 'val_loss',mode='min',min_delta = 0,patience = 3,verbose = 1,)
  1. monitor。想要监控的指标,比如在这里我们主要看的是验证集上的loss,当loss不再降低的时候就停止
  2. mode。想要最大值还是最小值,在这里我们使用的min,当时loss越小越好
  3. min_delta。指标的变化超过min_delta才认为产生了变化,否则都认为不再上升或下降
  4. patience。多少轮不发生变化才停止
  5. verbose。设置为1的时候,训练结束会打印出epoch的情况

二、保存最佳模型

完整代码

在early stopping结束后得到模型不一定是最佳模型,所以我们需要把训练过程中表现最好的模型保存下来,以便使用。在这里我们可以使用Keras提供的另一callback来实现:

from keras.callbacks import ModelCheckpointmc = ModelCheckpoint(file_path='./best_model.h5',monitor='val_accuracy',mode='max',verbose=1,save_best_only=True)
  1. filepath,模型存储的路径
  2. monitor,监控的指标
  3. mode,最大还是最小模式
  4. verbose,日志显示控制
  5. save_best_only,是否只存储最好的模型

通过使用这个方法我们就可以把最好的模型存储下来,在使用的时候直接load就可以了。

三、在IMDB数据集上使用Early Stopping

完整代码​​​​​​​

IMDB是一个情感分析数据集,我们首先在这个数据集上使用一个简单的CNN看看效果,然后再使用Early Stopping作为对比。首先看看CNN代码。先对句子embedding, 然后使用一层Conv1D+Maxpooling。

# Build model
sentence = Input(batch_shape=(None, max_words), dtype='int32', name='sentence')
embedding_layer = Embedding(top_words, embedding_dims, input_length=max_words)
sent_embed = embedding_layer(sentence)
conv_layer = Conv1D(filters, kernel_size, padding='valid', activation='relu')
sent_conv = conv_layer(sent_embed)
sent_pooling = GlobalMaxPooling1D()(sent_conv)
sent_repre = Dense(250)(sent_pooling)
sent_repre = Activation('relu')(sent_repre)
sent_repre = Dense(1)(sent_repre)
pred = Activation('sigmoid')(sent_repre)
model = Model(inputs=sentence, outputs=pred)
rmsprop = optimizers.rmsprop(lr=0.0003)
model.compile(loss='binary_crossentropy', optimizer=rmsprop, metrics=['accuracy'])

最终在数据集上的结果如下,在训练集上基本达到了100,而在测试集上还不到90,看起来有点过拟合了

Training Accuracy: 100%
Test Accuracy: 88.50%

我们再看Loss曲线,大约在第8轮的时候,验证集上的Loss达到最低,但是在往后Loss开始升高,这就更加确定发生了过拟合,我们需要提前停止训练,最好在第8轮之后就停下来。

在IMDB数据集上使用Early Stopping

我们再训练过程中加上一个patience=10的earlystop,监控验证集loss。当验证集的loss在近10轮都没有下降的话就停止。

#early stopping
earlystop = EarlyStopping(monitor='val_loss',min_delta=0,patience=10,verbose=1)# fit the model
history = model.fit(x_train, y_train, batch_size=batch_size,epochs=epochs, verbose=1, validation_data=(x_test, y_test), callbacks[earlystop])

结果如下,我们可以看到训练最终在第16轮停止了,停止时在测试集上的准确率为88.40%,并没有高于不使用Early Stopping的情况,但是在训练的第12轮模型的准确达到了89.30%,超过了Baseline。所以我们需要加上存储最好模型的callback。

Epoch 2/50
5000/5000 [==============================] - 5s 951us/step - loss: 0.4851 - acc: 0.7986 - val_loss: 0.4320 - val_acc: 0.8170
Epoch 3/50
5000/5000 [==============================] - 5s 918us/step - loss: 0.3193 - acc: 0.8802 - val_loss: 0.3599 - val_acc: 0.8370
Epoch 4/50
5000/5000 [==============================] - 4s 882us/step - loss: 0.2093 - acc: 0.9322 - val_loss: 0.3392 - val_acc: 0.8530
Epoch 5/50
5000/5000 [==============================] - 4s 880us/step - loss: 0.1209 - acc: 0.9702 - val_loss: 0.4001 - val_acc: 0.8260
Epoch 6/50
5000/5000 [==============================] - 4s 887us/step - loss: 0.0600 - acc: 0.9884 - val_loss: 0.2900 - val_acc: 0.8710
Epoch 7/50
5000/5000 [==============================] - 4s 865us/step - loss: 0.0208 - acc: 0.9986 - val_loss: 0.2978 - val_acc: 0.8840
Epoch 8/50
5000/5000 [==============================] - 4s 883us/step - loss: 0.0053 - acc: 1.0000 - val_loss: 0.3180 - val_acc: 0.8840
Epoch 9/50
5000/5000 [==============================] - 4s 856us/step - loss: 0.0011 - acc: 1.0000 - val_loss: 0.3570 - val_acc: 0.8830
Epoch 10/50
5000/5000 [==============================] - 4s 845us/step - loss: 1.7574e-04 - acc: 1.0000 - val_loss: 0.4035 - val_acc: 0.8800
Epoch 11/50
5000/5000 [==============================] - 4s 869us/step - loss: 2.0190e-05 - acc: 1.0000 - val_loss: 0.4490 - val_acc: 0.8820
Epoch 12/50
5000/5000 [==============================] - 4s 846us/step - loss: 1.6874e-06 - acc: 1.0000 - val_loss: 0.5164 - val_acc: 0.8930
Epoch 13/50
5000/5000 [==============================] - 4s 860us/step - loss: 2.6231e-07 - acc: 1.0000 - val_loss: 0.5429 - val_acc: 0.8840
Epoch 14/50
5000/5000 [==============================] - 4s 870us/step - loss: 1.4614e-07 - acc: 1.0000 - val_loss: 0.5754 - val_acc: 0.8810
Epoch 15/50
5000/5000 [==============================] - 4s 888us/step - loss: 1.2477e-07 - acc: 1.0000 - val_loss: 0.5744 - val_acc: 0.8850
Epoch 16/50
5000/5000 [==============================] - 4s 876us/step - loss: 1.1823e-07 - acc: 1.0000 - val_loss: 0.5909 - val_acc: 0.8840
Epoch 00016: early stopping
Accuracy: 88.40%

存储最好模型

我们使用ModelCheckPoint存储最好的模型,具体如下,通过监控验证集上的准确率,我们把准确率最高的模型存储下来

from keras.callbacks import EarlyStopping, ModelCheckpointmc = ModelCheckpoint(filepath='best_model.h5',monitor='val_acc',mode='max',verbose=1,save_best_only=True)

然后在使用的时候进行load,然后就可以进行预测了

from keras.models import load_model
saved_model = load_model('best_model.h5')
# evaluate the model
_, train_acc = saved_model.evaluate(x_train, y_train, verbose=0)
_, test_acc = saved_model.evaluate(x_test, y_test, verbose=0)
print('Train: %.3f, Test: %.3f' % (train_acc, test_acc))

最终的结果如下

Train: 1.000, Test: 0.893

正确使用Early Stopping加上存储最佳模型可以帮助我们减轻过拟合,从而训练出表现更好的模型。

完整代码​​​​​​​​​​​​​​

这篇关于使用early stopping解决神经网络过拟合问题的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1003992

相关文章

C#中checked关键字的使用小结

《C#中checked关键字的使用小结》本文主要介绍了C#中checked关键字的使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学... 目录✅ 为什么需要checked? 问题:整数溢出是“静默China编程”的(默认)checked的三种用

C#中预处理器指令的使用小结

《C#中预处理器指令的使用小结》本文主要介绍了C#中预处理器指令的使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录 第 1 名:#if/#else/#elif/#endif✅用途:条件编译(绝对最常用!) 典型场景: 示例

JAVA Calendar设置上个月时,日期不存在或错误提示问题及解决

《JAVACalendar设置上个月时,日期不存在或错误提示问题及解决》在使用Java的Calendar类设置上个月的日期时,如果遇到不存在的日期(如4月31日),默认会自动调整到下个月的相应日期(... 目录Java Calendar设置上个月时,日期不存在或错误提示java进行日期计算时如果出现不存在的

Mybatis对MySQL if 函数的不支持问题解读

《Mybatis对MySQLif函数的不支持问题解读》接手项目后,为了实现多租户功能,引入了Mybatis-plus,发现之前运行正常的SQL语句报错,原因是Mybatis不支持MySQL的if函... 目录MyBATis对mysql if 函数的不支持问题描述经过查询网上搜索资料找到原因解决方案总结Myb

Nginx错误拦截转发 error_page的问题解决

《Nginx错误拦截转发error_page的问题解决》Nginx通过配置错误页面和请求处理机制,可以在请求失败时展示自定义错误页面,提升用户体验,下面就来介绍一下Nginx错误拦截转发error_... 目录1. 准备自定义错误页面2. 配置 Nginx 错误页面基础配置示例:3. 关键配置说明4. 生效

Mysql中RelayLog中继日志的使用

《Mysql中RelayLog中继日志的使用》MySQLRelayLog中继日志是主从复制架构中的核心组件,负责将从主库获取的Binlog事件暂存并应用到从库,本文就来详细的介绍一下RelayLog中... 目录一、什么是 Relay Log(中继日志)二、Relay Log 的工作流程三、Relay Lo

使用Redis实现会话管理的示例代码

《使用Redis实现会话管理的示例代码》文章介绍了如何使用Redis实现会话管理,包括会话的创建、读取、更新和删除操作,通过设置会话超时时间并重置,可以确保会话在用户持续活动期间不会过期,此外,展示了... 目录1. 会话管理的基本概念2. 使用Redis实现会话管理2.1 引入依赖2.2 会话管理基本操作

Springboot请求和响应相关注解及使用场景分析

《Springboot请求和响应相关注解及使用场景分析》本文介绍了SpringBoot中用于处理HTTP请求和构建HTTP响应的常用注解,包括@RequestMapping、@RequestParam... 目录1. 请求处理注解@RequestMapping@GetMapping, @PostMappin

Java调用DeepSeek API的8个高频坑与解决方法

《Java调用DeepSeekAPI的8个高频坑与解决方法》现在大模型开发特别火,DeepSeek因为中文理解好、反应快、还便宜,不少Java开发者都用它,本文整理了最常踩的8个坑,希望对... 目录引言一、坑 1:Token 过期未处理,鉴权异常引发服务中断问题本质典型错误代码解决方案:实现 Token

springboot3.x使用@NacosValue无法获取配置信息的解决过程

《springboot3.x使用@NacosValue无法获取配置信息的解决过程》在SpringBoot3.x中升级Nacos依赖后,使用@NacosValue无法动态获取配置,通过引入SpringC... 目录一、python问题描述二、解决方案总结一、问题描述springboot从2android.x