【机器学习】逻辑斯谛回归模型实现

2023-10-31 08:30

本文主要是介绍【机器学习】逻辑斯谛回归模型实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 数据准备
  • 逻辑斯谛回归模型
  • 模型参数估计
  • 总结
  • 参考


数据准备

本文实现的是二项逻辑斯谛回归模型,因此使用的是处理过后的两类别数据 mnist_binary.csv,表中对原手写数据中0~4取作负类 -1,将5~9取作正类 +1。

另根据逻辑斯谛回归模型按条件概率分布定义:
P ( Y = 1 ∣ x ) = e x p ( w ⋅ x ) 1 + e x p ( w ⋅ x ) P(Y=1|x)=\frac{exp(w\cdot x)}{1 + exp(w\cdot x)} P(Y=1∣x)=1+exp(wx)exp(wx)
P ( Y = 0 ∣ x ) = 1 1 + e x p ( w ⋅ x ) P(Y=0|x)=\frac{1}{1 + exp(w\cdot x)} P(Y=0∣x)=1+exp(wx)1

Y的取值应为0,1,因此需要将表中的-1类转换为0后再进行训练;此外由于要计算指数函数,特征取值过多会导致指数函数计算过程中的溢出,因此还需要将图像数据进行二值化操作。此部分直接在代码中完成,就不生成相应的数据集了。


逻辑斯谛回归模型

上面提到的逻辑斯谛回归模型的条件概率分布定义,可以看作是模型将线性函数 w ⋅ x w\cdot x wx通过其定义式转换为概率表现形式:
P ( Y = 1 ∣ x ) = e x p ( w ⋅ x ) 1 + e x p ( w ⋅ x ) P(Y=1|x)=\frac{exp(w\cdot x)}{1 + exp(w\cdot x)} P(Y=1∣x)=1+exp(wx)exp(wx)

上式中表示事情发生的概率,在线性函数趋近于无穷大时,概率值越接近于1;线性函数趋近于负无穷时,概率值就接近于0;函数图像如下所示,模型的临界点在线性函数为零时,条件概率值为0.5。
逻辑斯谛

逻辑斯谛回归模型也可以推广至多分类,见总结部分。


模型参数估计

设上述逻辑斯谛回归模型可改写为如下格式:

P ( Y = 1 ∣ x ) = π ( x ) , P ( Y = 0 ∣ x ) = 1 − π ( x ) P(Y=1|x)=\pi(x),P(Y=0|x)=1-\pi(x) P(Y=1∣x)=π(x)P(Y=0∣x)=1π(x)

其似然函数为:

∏ i = 1 N [ π ( x i ) ] y i [ 1 − π ( x i ) ] 1 − y i \prod_{i=1}^{N}[\pi(x_i)]^{y_i}[1-\pi(x_i)]^{1-y_i} i=1N[π(xi)]yi[1π(xi)]1yi

对数似然函数:
L ( w ) = ∑ i = 1 N [ y i l o g π ( x i ) + ( 1 − y i ) l o g ( 1 − π ( x i ) ) ] = ∑ i = 1 N [ y i l o g π ( x i ) 1 − π ( x i ) + l o g ( 1 − π ( x i ) ) ] = ∑ i = 1 N [ y i ( w ⋅ x i ) − l o g ( 1 + e x p ( w ⋅ x i ) ) ] = y i ( w ⋅ x ) − l o g ( 1 + e x p ( w ⋅ x ) ) \begin{aligned} L(w) &= \sum_{i=1}^N[y_ilog\pi(x_i) + (1 - y_i)log(1 - \pi(x_i))] \\ &=\sum_{i=1}^N[y_ilog\frac{\pi(x_i)}{1 - \pi(x_i)} + log(1 - \pi(x_i))] \\ &=\sum_{i=1}^N[y_i(w\cdot x_i) - log(1 + exp(w\cdot x_i))] \\ &=y_i(w\cdot x) - log(1+exp(w\cdot x)) \end{aligned} L(w)=i=1N[yilogπ(xi)+(1yi)log(1π(xi))]=i=1N[yilog1π(xi)π(xi)+log(1π(xi))]=i=1N[yi(wxi)log(1+exp(wxi))]=yi(wx)log(1+exp(wx))

利用随机梯度下降方法优化算法,以向量形式对权重进行求导:
∂ L ( w ) ∂ w = y i x − x ⋅ e x p ( w ⋅ x ) 1 + e x p ( w ⋅ x ) = x [ y i − e x p ( w ⋅ x ) 1 + e x p ( w ⋅ x ) ] \begin{aligned} \frac{\partial L(w)}{\partial w} &= y_ix - \frac{x\cdot exp(w\cdot x)}{1+exp(w\cdot x)} \\ &=x[y_i - \frac{exp(w\cdot x)}{1 + exp(w\cdot x)}] \end{aligned} wL(w)=yix1+exp(wx)xexp(wx)=x[yi1+exp(wx)exp(wx)]

每次迭代过程中更新权重参数:

w = w + α ∂ L ( w ) ∂ w w = w + \alpha\frac{\partial L(w)}{\partial w} w=w+αwL(w)


根据上述算法步骤,可以发现基于随机梯度下降法的二项逻辑斯谛回归和基于梯度下降法的感知机模型学习算法流程基本一致,区别在于参数步骤的更新方式。另外在判别过程中:感知机采用符号函数Sgin,逻辑斯谛回归采用逻辑斯谛分布Sigmoid进行计算,可参考感知机模型学习原始算法。

具体实现代码如下:

# @Author: phd
# @Date: 2019-08-18
# @Site: github.com/phdsky
# @Description: NULLimport time
import logging
import numpy as np
import pandas as pdfrom sklearn.model_selection import train_test_split
from sklearn.preprocessing import Binarizerdef log(func):def wrapper(*args, **kwargs):start_time = time.time()ret = func(*args, **kwargs)end_time = time.time()logging.debug('%s() cost %s seconds' % (func.__name__, end_time - start_time))return retreturn wrapperdef calc_accuracy(y_pred, y_truth):assert len(y_pred) == len(y_truth)n = len(y_pred)hit_count = 0for i in range(0, n):if y_pred[i] == y_truth[i]:hit_count += 1print("Predicting accuracy %f" % (hit_count / n))class LogisticRegression(object):def __init__(self, w, b, learning_rate, max_epoch, learning_period, learning_ratio):self.weight = wself.bias = bself.lr_rate = learning_rateself.max_epoch = max_epochself.lr_period = learning_periodself.lr_ratio = learning_ratiodef calculate(self, feature):# wx = sum([self.weight[j] * feature[j] for j in range(len(self.weight))])wx = np.dot(self.weight.transpose(), feature)exp_wx = np.exp(wx)predicted = 0 if (1 / (1 + exp_wx)) > 0.5 else 1return predicted, exp_wx@logdef train(self, X_train, y_train):# Fuse weight with biasself.weight = np.full((len(X_train[0]), 1), self.weight, dtype=float)self.weight = np.row_stack((self.weight, self.bias))epoch = 0while epoch < self.max_epoch:hit_count = 0data_count = len(X_train)for i in range(data_count):feature = X_train[i].reshape([len(X_train[i]), 1])feature = np.row_stack((feature, 1))label = y_train[i]predicted, exp_wx = self.calculate(feature)if predicted == label:hit_count += 1continue# for k in range(len(self.weight)):#     self.weight[k] += self.lr_rate * (label*feature[k] - ((feature[k] * exp_wx) / (1 + exp_wx)))self.weight += self.lr_rate * feature * (label - (exp_wx / (1 + exp_wx)))epoch += 1print("\rEpoch %d, lr_rate=%f, Acc = %f" % (epoch, self.lr_rate, hit_count / data_count), end='')# Decay learning rateif epoch % self.lr_period == 0:self.lr_rate /= self.lr_ratio# Stop trainingif self.lr_rate <= 1e-6:print("\nLearning rate is too low, Early stopping...\n")break@logdef predict(self, X_test):n = len(X_test)predict_label = np.full(n, -1)for i in range(0, n):to_predict = X_test[i].reshape([len(X_test[i]), 1])vec_predict = np.row_stack((to_predict, 1))predict_label[i], _ = self.calculate(vec_predict)return predict_labelif __name__ == "__main__":logger = logging.getLogger()logger.setLevel(logging.DEBUG)mnist_data = pd.read_csv("../data/mnist_binary.csv")mnist_values = mnist_data.valuesimages = mnist_values[::, 1::]labels = mnist_values[::, 0]X_train, X_test, y_train, y_test = train_test_split(images, labels, test_size=0.33, random_state=42)# Handle all -1 in y_train to 0y_train = y_train * (y_train == 1)y_test = y_test * (y_test == 1)# Binary the image to avoid predict_probability gets 0binarizer_train = Binarizer(threshold=127).fit(X_train)X_train_binary = binarizer_train.transform(X_train)binarizer_test = Binarizer(threshold=127).fit(X_test)X_test_binary = binarizer_test.transform(X_test)lr = LogisticRegression(w=0, b=1, learning_rate=0.001, max_epoch=100,learning_period=10, learning_ratio=3)print("Logistic regression training...")lr.train(X_train=X_train_binary, y_train=y_train)print("\nTraining done...")print("Testing on %d samples..." % len(X_test))y_predicted = lr.predict(X_test=X_test_binary)calc_accuracy(y_pred=y_predicted, y_truth=y_test)

代码输出

/Users/phd/Softwares/anaconda3/bin/python /Users/phd/Desktop/ML/logistic_regression/logistic_regression.py
Logistic regression training...
Epoch 70, lr_rate=0.000001, Acc = 0.818479
Learning rate is too low, Early stopping...Training done...
Testing on 13860 samples...
DEBUG:root:train() cost 38.08758902549744 seconds
Predicting accuracy 0.831097
DEBUG:root:predict() cost 0.2131938934326172 secondsProcess finished with exit code 0

从结果可以看出,在图像二值化后逻辑斯谛算法的训练和测试精度都在80%+,算法效果较好;预测结果优于直接使用原始数据的感知机模型。


总结

  1. 逻辑斯谛回归模型是一种分类模型
  2. 逻辑斯谛回归是由输入线性函数表示的输出对数几率模型;其模型定义由如下条件概率分布表示:(将二项推广为多项模型)

{ P ( Y = k ∣ x ) = e x p ( w k ⋅ x ) 1 + ∑ k = 1 K − 1 e x p ( w k ⋅ x ) , k = 1 , 2 , . . . , K − 1 P ( Y = K ∣ x ) = 1 1 + ∑ k = 1 K − 1 e x p ( w k ⋅ x ) \left\{ \begin{aligned} P(Y=k|x) &= \frac{exp(w_k\cdot x)}{1 + \sum\limits_{k=1}^{K-1}exp(w_k\cdot x)}, k=1,2,...,K-1 \\ P(Y=K|x) &= \frac{1}{1 + \sum\limits_{k=1}^{K-1}exp(w_k\cdot x)} \end{aligned} \right. P(Y=kx)P(Y=Kx)=1+k=1K1exp(wkx)exp(wkx),k=1,2,...,K1=1+k=1K1exp(wkx)1


参考

  1. 《统计学习方法》

这篇关于【机器学习】逻辑斯谛回归模型实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/314073

相关文章

使用Sentinel自定义返回和实现区分来源方式

《使用Sentinel自定义返回和实现区分来源方式》:本文主要介绍使用Sentinel自定义返回和实现区分来源方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Sentinel自定义返回和实现区分来源1. 自定义错误返回2. 实现区分来源总结Sentinel自定

Java实现时间与字符串互相转换详解

《Java实现时间与字符串互相转换详解》这篇文章主要为大家详细介绍了Java中实现时间与字符串互相转换的相关方法,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录一、日期格式化为字符串(一)使用预定义格式(二)自定义格式二、字符串解析为日期(一)解析ISO格式字符串(二)解析自定义

opencv图像处理之指纹验证的实现

《opencv图像处理之指纹验证的实现》本文主要介绍了opencv图像处理之指纹验证的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学... 目录一、简介二、具体案例实现1. 图像显示函数2. 指纹验证函数3. 主函数4、运行结果三、总结一、

Springboot处理跨域的实现方式(附Demo)

《Springboot处理跨域的实现方式(附Demo)》:本文主要介绍Springboot处理跨域的实现方式(附Demo),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不... 目录Springboot处理跨域的方式1. 基本知识2. @CrossOrigin3. 全局跨域设置4.

Spring Boot 3.4.3 基于 Spring WebFlux 实现 SSE 功能(代码示例)

《SpringBoot3.4.3基于SpringWebFlux实现SSE功能(代码示例)》SpringBoot3.4.3结合SpringWebFlux实现SSE功能,为实时数据推送提供... 目录1. SSE 简介1.1 什么是 SSE?1.2 SSE 的优点1.3 适用场景2. Spring WebFlu

基于SpringBoot实现文件秒传功能

《基于SpringBoot实现文件秒传功能》在开发Web应用时,文件上传是一个常见需求,然而,当用户需要上传大文件或相同文件多次时,会造成带宽浪费和服务器存储冗余,此时可以使用文件秒传技术通过识别重复... 目录前言文件秒传原理代码实现1. 创建项目基础结构2. 创建上传存储代码3. 创建Result类4.

SpringBoot日志配置SLF4J和Logback的方法实现

《SpringBoot日志配置SLF4J和Logback的方法实现》日志记录是不可或缺的一部分,本文主要介绍了SpringBoot日志配置SLF4J和Logback的方法实现,文中通过示例代码介绍的非... 目录一、前言二、案例一:初识日志三、案例二:使用Lombok输出日志四、案例三:配置Logback一

Python如何使用__slots__实现节省内存和性能优化

《Python如何使用__slots__实现节省内存和性能优化》你有想过,一个小小的__slots__能让你的Python类内存消耗直接减半吗,没错,今天咱们要聊的就是这个让人眼前一亮的技巧,感兴趣的... 目录背景:内存吃得满满的类__slots__:你的内存管理小助手举个大概的例子:看看效果如何?1.

Python+PyQt5实现多屏幕协同播放功能

《Python+PyQt5实现多屏幕协同播放功能》在现代会议展示、数字广告、展览展示等场景中,多屏幕协同播放已成为刚需,下面我们就来看看如何利用Python和PyQt5开发一套功能强大的跨屏播控系统吧... 目录一、项目概述:突破传统播放限制二、核心技术解析2.1 多屏管理机制2.2 播放引擎设计2.3 专

Python实现无痛修改第三方库源码的方法详解

《Python实现无痛修改第三方库源码的方法详解》很多时候,我们下载的第三方库是不会有需求不满足的情况,但也有极少的情况,第三方库没有兼顾到需求,本文将介绍几个修改源码的操作,大家可以根据需求进行选择... 目录需求不符合模拟示例 1. 修改源文件2. 继承修改3. 猴子补丁4. 追踪局部变量需求不符合很