动手学深度学习(pytorch)学习记录15-正则化、权重衰减[学习记录]

2024-08-25 19:12

本文主要是介绍动手学深度学习(pytorch)学习记录15-正则化、权重衰减[学习记录],希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

我们可以通过收集更多的训练数据来缓解过拟合,但这可能成本很高,耗时很多或完全失去控制,在短期内难以做到。 假设已经有了足够多的数据,接下来将重点放在正则化技术上。

权重衰减是使用最广泛的正则化技术之一,它通常也被称为L2正则化
技术方法:通过函数与零之间的距离来度量函数的复杂度;
如何精确测量这种‘距离’?
一个简单的方法是通过线性函数f(x)=w^(T) x 中权重向量的某个范数(如||w||^2)来度量复杂度
最常用的方法是将范数作为惩罚项添加到最小化损失中。
那么原来的训练目标“最小化训练标签上的预测损失”调整为“最小化训练标签上的预测损失+惩罚项”
如果权重向量增长过大,学习算法可能会更集中于最小化权重范数。
在线性回归损失是:
在这里插入图片描述
在损失函数中添加范数后:

在这里插入图片描述

通过正则化常数λ(非负超参数)来权衡这个额外的惩罚。这里除以2,当取一个2次函数的倒数时,2和1/2会抵消让更新表达式简单美观

用简单的例子实现权重衰减

由于沐神的d2l包太好用了,部分函数就直接调用包里的,使代码更简洁

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l

用下面的公式生成数据

在这里插入图片描述

标签是线性函数,被噪声破坏,为了突出过拟合,增加线性回归维度为200,但只提供包含20个样本的数据集训练。

n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

定义一个函数来随机初始化模型参数

def init_params():w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)return [w, b]

定义L2范数惩罚
实现惩罚的方法-对所有项求平方后再将他们求和。

def l2_penalty(w):return torch.sum(w.pow(2)) / 2

定义训练代码

def train(lambd):w, b = init_params()net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_lossnum_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:# 增加了L2范数惩罚项,# 广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(X), y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w, b], lr, batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())

忽略正则化直接训练的
用lambd = 0禁用权重衰减后运行这个代码。 注意,这里训练误差有了减少,但测试误差没有减少, 这意味着出现了严重的过拟合。

train(lambd=0)

在这里插入图片描述
使用正则化

train(lambd=3)

在这里插入图片描述
通过pytorch框架简洁实现

def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1))for param in net.parameters():param.data.normal_()loss = nn.MSELoss(reduction='none')num_epochs, lr = 100, 0.003# 偏置参数没有衰减trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:trainer.zero_grad()l = loss(net(X), y)l.mean().backward()trainer.step()if (epoch + 1) % 5 == 0:animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())
train_concise(0)

在这里插入图片描述

train_concise(3)

在这里插入图片描述

封面图片来源

欢迎点击我的主页查看更多文章。
本人学习地址https://zh-v2.d2l.ai/
恳请大佬批评指正。

这篇关于动手学深度学习(pytorch)学习记录15-正则化、权重衰减[学习记录]的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1106411

相关文章

Oracle查询优化之高效实现仅查询前10条记录的方法与实践

《Oracle查询优化之高效实现仅查询前10条记录的方法与实践》:本文主要介绍Oracle查询优化之高效实现仅查询前10条记录的相关资料,包括使用ROWNUM、ROW_NUMBER()函数、FET... 目录1. 使用 ROWNUM 查询2. 使用 ROW_NUMBER() 函数3. 使用 FETCH FI

Python MySQL如何通过Binlog获取变更记录恢复数据

《PythonMySQL如何通过Binlog获取变更记录恢复数据》本文介绍了如何使用Python和pymysqlreplication库通过MySQL的二进制日志(Binlog)获取数据库的变更记录... 目录python mysql通过Binlog获取变更记录恢复数据1.安装pymysqlreplicat

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

五大特性引领创新! 深度操作系统 deepin 25 Preview预览版发布

《五大特性引领创新!深度操作系统deepin25Preview预览版发布》今日,深度操作系统正式推出deepin25Preview版本,该版本集成了五大核心特性:磐石系统、全新DDE、Tr... 深度操作系统今日发布了 deepin 25 Preview,新版本囊括五大特性:磐石系统、全新 DDE、Tree

Springboot的ThreadPoolTaskScheduler线程池轻松搞定15分钟不操作自动取消订单

《Springboot的ThreadPoolTaskScheduler线程池轻松搞定15分钟不操作自动取消订单》:本文主要介绍Springboot的ThreadPoolTaskScheduler线... 目录ThreadPoolTaskScheduler线程池实现15分钟不操作自动取消订单概要1,创建订单后

Node.js 中 http 模块的深度剖析与实战应用小结

《Node.js中http模块的深度剖析与实战应用小结》本文详细介绍了Node.js中的http模块,从创建HTTP服务器、处理请求与响应,到获取请求参数,每个环节都通过代码示例进行解析,旨在帮... 目录Node.js 中 http 模块的深度剖析与实战应用一、引言二、创建 HTTP 服务器:基石搭建(一

Servlet中配置和使用过滤器的步骤记录

《Servlet中配置和使用过滤器的步骤记录》:本文主要介绍在Servlet中配置和使用过滤器的方法,包括创建过滤器类、配置过滤器以及在Web应用中使用过滤器等步骤,文中通过代码介绍的非常详细,需... 目录创建过滤器类配置过滤器使用过滤器总结在Servlet中配置和使用过滤器主要包括创建过滤器类、配置过滤

正则表达式高级应用与性能优化记录

《正则表达式高级应用与性能优化记录》本文介绍了正则表达式的高级应用和性能优化技巧,包括文本拆分、合并、XML/HTML解析、数据分析、以及性能优化方法,通过这些技巧,可以更高效地利用正则表达式进行复杂... 目录第6章:正则表达式的高级应用6.1 模式匹配与文本处理6.1.1 文本拆分6.1.2 文本合并6

python与QT联合的详细步骤记录

《python与QT联合的详细步骤记录》:本文主要介绍python与QT联合的详细步骤,文章还展示了如何在Python中调用QT的.ui文件来实现GUI界面,并介绍了多窗口的应用,文中通过代码介绍... 目录一、文章简介二、安装pyqt5三、GUI页面设计四、python的使用python文件创建pytho

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert