梯度下降(Gradient Descent)原理以及Python代码

2024-05-26 08:48

本文主要是介绍梯度下降(Gradient Descent)原理以及Python代码,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

给定一个函数f(x),我们想知道当x是值是多少的时候使这个函数达到最小值。为了实现这个目标,我们可以使用梯度下降(Gradient Descent)进行近似求解。

梯度下降是一个迭代算法,具体地,下一次迭代令

x_{n+1} = x_{n} - \eta {f}'(x_{n})

{f}'(x)是梯度,其中\eta是学习率(learning rate),代表这一轮迭代使用多少负梯度进行更新。梯度下降非常简单有效,但是其中的原理是怎么样呢?

原理

为什么每次使用负梯度进行更新呢?这要从泰勒公式(Taylor's formula)说起:

f(x) = f(x_{0}) + \frac{​{f}'(x_{0})}{1!}(x-x_{0}) + \frac{f{}''(x_{0})}{2!}(x-x_{0}) + ...

泰勒公式的目的是使用x-x_{0}的多项式去逼近函数f(x),这里可以理解泰勒公式在x-x_{0}的展开是原函数的一个近似函数。

那泰勒公式跟梯度下降有什么关系呢?

我们的目标是使f(x_{n+1})\leq f(x_{n}),我们对f(x_{n+1})x_{n}处进行一阶泰勒展开:

f(x_{n+1}) \approx f(x_{n}) + {f}'(x_{n})(x_{n+1}-x_{n})

由此可知,我们只需令x_{n+1}-x_{n} = -{f}'(x_{n}),就会使f(x_{n+1})\leq f(x_{n})

所以迭代公式可以为x_{n+1}= x_{n} -\eta {f}'(x_{n})

案例

下面我们看具体例子,假设我们有以下函数

f(x) = \frac{1}{2}\left \| Ax-b \right \|^2

矩阵和A向量b已知,我们想知道当x取值为多少的时候,函数f(x)的值最小。

根据梯度下降法,我们只需计算出负梯度,然给定一个初始值x_{0},不断迭代就能找到一个近似解了。负梯度计算如下:

{f}'(x) =A ^{T}(Ax-b)=A ^{T}Ax-A ^{T}b

接下来让我写一段代码解决这个问题

定义梯度下降函数

首先,定义cal_gradient函数用来计算梯度,然后使用gradient_decent进行迭代,其中learning_rate就是公式中的\eta,这个值需要合理设置,过大的话会导致震荡,过下的话又会导致迭代时间过长。step代表迭代的次数,理想情况下找到满意的解就停止。

我们会在代码中调整这两个参数查看它们对求解过程的影响。

import numpy as np
import time#calculate gradient
def cal_gradient(A, b, x):left = np.dot(np.dot(A.T, A), x)right = np.dot(A.T, b)gradient = left - rightreturn gradient# iteration
def gradient_decent(x, A, b, learning_rate, step):start = time.time()for i in range(step):gradient = cal_gradient(A, b, x)delta = learning_rate * gradientx = x - deltaend = time.time()time_cost = round(end - start, 4)print('done! x = {a}, time cost = {b}s'.format(a=x, b=time_cost))

求解过程

我们给了矩阵和A向量b的值以及标准答案 [29, 16, 3],然后我们随机初始化一个x_{0},让学习率\eta =0.01,迭代次数step=1000000

A = np.array([[1.0, -2.0, 1.0], [0.0, 2.0, -8.0], [-4.0, 5.0, 9.0]])
b = np.array([0.0, 8.0, -9.0])
# Giveb A and b,the solution x is [29, 16, 3]x0 = np.array([1.0, 1.0, 1.0])
learning_rate = 0.01
step = 1000000gradient_decent(x0, A, b, learning_rate, step)

结果

以下为结果,可以看出求得的近似解和标准答案 [29, 16, 3]还是非常接近的。

done! x = [28.98272933 15.99042465  2.99763054], time cost = 4.6037s

调整学习率

其他参数都一样,我们让学习率变小,运行相同的步数,从以下结果看到求得的近似解跟标准答案还有一定差距。这意味着小的过小学习率需要学习更久的时间。

learning_rate = 0.001# result
# done! x = [15.8048349   8.68422815  1.18968306], time cost = 4.5997s

调整初始值

我们只调整初始值,学习相同的步数,发现求得的近似解尽管与标准答案相似,但是不如第一个方法求得解。这说明梯度下降方法也会受到初始值得影响。

x0 = np.array([1000, 1000, 1000])# result
# done! x = [29.78036839 16.43265826  3.10706301], time cost = 4.5528s

总结

梯度下降方法是一种非常有效的优化方法,它的效果会受到初始值、学习率、步数的影响。如果要说缺点的话,就是它容易找到局部最优解,有时候会发生震荡现象。

 

参考

https://sm1les.com/2019/03/01/gradient-descent-and-newton-method/

这篇关于梯度下降(Gradient Descent)原理以及Python代码的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1003996

相关文章

Python调用Orator ORM进行数据库操作

《Python调用OratorORM进行数据库操作》OratorORM是一个功能丰富且灵活的PythonORM库,旨在简化数据库操作,它支持多种数据库并提供了简洁且直观的API,下面我们就... 目录Orator ORM 主要特点安装使用示例总结Orator ORM 是一个功能丰富且灵活的 python O

Python使用国内镜像加速pip安装的方法讲解

《Python使用国内镜像加速pip安装的方法讲解》在Python开发中,pip是一个非常重要的工具,用于安装和管理Python的第三方库,然而,在国内使用pip安装依赖时,往往会因为网络问题而导致速... 目录一、pip 工具简介1. 什么是 pip?2. 什么是 -i 参数?二、国内镜像源的选择三、如何

Java调用DeepSeek API的最佳实践及详细代码示例

《Java调用DeepSeekAPI的最佳实践及详细代码示例》:本文主要介绍如何使用Java调用DeepSeekAPI,包括获取API密钥、添加HTTP客户端依赖、创建HTTP请求、处理响应、... 目录1. 获取API密钥2. 添加HTTP客户端依赖3. 创建HTTP请求4. 处理响应5. 错误处理6.

python使用fastapi实现多语言国际化的操作指南

《python使用fastapi实现多语言国际化的操作指南》本文介绍了使用Python和FastAPI实现多语言国际化的操作指南,包括多语言架构技术栈、翻译管理、前端本地化、语言切换机制以及常见陷阱和... 目录多语言国际化实现指南项目多语言架构技术栈目录结构翻译工作流1. 翻译数据存储2. 翻译生成脚本

如何通过Python实现一个消息队列

《如何通过Python实现一个消息队列》这篇文章主要为大家详细介绍了如何通过Python实现一个简单的消息队列,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录如何通过 python 实现消息队列如何把 http 请求放在队列中执行1. 使用 queue.Queue 和 reque

Python如何实现PDF隐私信息检测

《Python如何实现PDF隐私信息检测》随着越来越多的个人信息以电子形式存储和传输,确保这些信息的安全至关重要,本文将介绍如何使用Python检测PDF文件中的隐私信息,需要的可以参考下... 目录项目背景技术栈代码解析功能说明运行结php果在当今,数据隐私保护变得尤为重要。随着越来越多的个人信息以电子形

使用 sql-research-assistant进行 SQL 数据库研究的实战指南(代码实现演示)

《使用sql-research-assistant进行SQL数据库研究的实战指南(代码实现演示)》本文介绍了sql-research-assistant工具,该工具基于LangChain框架,集... 目录技术背景介绍核心原理解析代码实现演示安装和配置项目集成LangSmith 配置(可选)启动服务应用场景

使用Python快速实现链接转word文档

《使用Python快速实现链接转word文档》这篇文章主要为大家详细介绍了如何使用Python快速实现链接转word文档功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 演示代码展示from newspaper import Articlefrom docx import

Python Jupyter Notebook导包报错问题及解决

《PythonJupyterNotebook导包报错问题及解决》在conda环境中安装包后,JupyterNotebook导入时出现ImportError,可能是由于包版本不对应或版本太高,解决方... 目录问题解决方法重新安装Jupyter NoteBook 更改Kernel总结问题在conda上安装了

Python如何计算两个不同类型列表的相似度

《Python如何计算两个不同类型列表的相似度》在编程中,经常需要比较两个列表的相似度,尤其是当这两个列表包含不同类型的元素时,下面小编就来讲讲如何使用Python计算两个不同类型列表的相似度吧... 目录摘要引言数字类型相似度欧几里得距离曼哈顿距离字符串类型相似度Levenshtein距离Jaccard相