Keras入门教程 ——1.线性回归建模(快速入门)

2024-06-19 16:04

本文主要是介绍Keras入门教程 ——1.线性回归建模(快速入门),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Keras入门教程

  • 1.线性回归建模(快速入门)
  • 2.线性模型的优化
  • 3.波士顿房价回归(MPL)
  • 4.卷积神经网络(CNN)
  • 5.使用LSTM RNN进行时间序列预测
  • 6.Keras 预训练模型应用

线性回归建模(快速入门)


前言


Keras 是何物?Keras 是一个用 Python 编写的高级神经网络 API。其是以TesorFlow作为后端运行的。我们安装深度学习框架tensorflow时自动安装的,并非单独安装,作为tesorflow的API存在,使用起来非常方便。本文先用sklearn 线性回归模型,引入深入学习的keras进行建模。为了更好的快速入门深度学习的keras,仅从程序(代码)实现方面对深度学习有个感性的认识。看完这系列文章,你对机器学习中深度学习有一个更深的理解。

导入包
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import  mean_squared_error
%matplotlib inline

%matplotlib inline的作用:在Jupyter Notebook中使用Matplotlib绘图时,如果不使用%matplotlib inline,图形将会在一个新的浏览器窗口中打开。这对于用户来说可能并不方便,因为这意味着每次执行绘图代码时都需要手动打开新窗口。使用%matplotlib inline可以将Matplotlib图形直接嵌入到Notebook单元格的输出中,这样用户就不需要手动打开新窗口了。

数据导入

income数据集为例,由于数据比较小,所以直接写入代码。

data=pd.DataFrame(columns=['Education','Income'],data=[[10.00000,26.65884],
[10.40134,27.30644],
[10.84281,22.13241],
[11.24415,21.16984],
[11.64548,15.19263],
[12.08696,26.39895],
[12.48829,17.43531],
[12.88963,25.50789],
[13.29097,36.88459],
[13.73244,39.66611],
[14.13378,34.39628],
[14.53512,41.49799],
[14.97659,44.98157],
[15.37793,47.03960],
[15.77926,48.25258],
[16.22074,57.03425],
[16.62207,51.49092],
[17.02341,61.33662],
[17.46488,57.58199],
[17.86622,68.55371],
[18.26756,64.31093],
[18.70903,68.95901],
[19.11037,74.61464],
[19.51171,71.86720],
[19.91304,76.09814],
[20.35452,75.77522],
[20.75585,72.48606],
[21.15719,77.35502],
[21.59866,72.11879],
[22.00000,80.26057]])
# 可以查看data内容
data
 数据可视化
plt.scatter(data.Education,data.Income);

 分离数据
X=data.Education.values.reshape(-1,1)
y=data.Income
Sklearn 建模
model_lr=LinearRegression()
model_lr.fit(X,y)
 查看线性相关属性
print( "斜率:", model_lr.coef_[0] ," 截距:",model_lr.intercept_)
print("R^2=",model_lr.score(X,y))
MSE=mean_squared_error(y,y_pred)
print("MSE:",MSE)
斜率: 5.599483656931067 截距: -39.44626851089707
R^2= 0.9309626013230593
MSE: 29.828741902209323
进行预测
y_pred=model_lr.predict(X)
画回归曲线
plt.scatter(X,y)
plt.plot(X,y_pred,"r")

keras 建模
from keras.models import Sequential
from keras.layers import Densemodel_kr = Sequential()
model_kr.add(Dense(1,input_shape=(1,),activation='linear'))
查看模型
model_kr.summary()
Model: “sequential”

Layer (type) Output Shape Param #
dense (Dense) (None, 1) 2

=================================================================

Total params: 2
Trainable params: 2
Non-trainable params: 0
选择损失函数和优化方法
model_kr.compile(optimizer='adam' , loss='mse')
model_kr.fit(X , y , epochs=200 , verbose=1)


进行200次的结果如下

Output exceeds the size limit. Open the full output data in a text editor
Epoch 1/200
1/1 [==============================] - 0s 394ms/step - loss: 835.3273
Epoch 2/200
1/1 [==============================] - 0s 6ms/step - loss: 834.3919
Epoch 3/200
1/1 [==============================] - 0s 9ms/step - loss: 833.4571
Epoch 4/200
1/1 [==============================] - 0s 5ms/step - loss: 832.5229
Epoch 5/200
1/1 [==============================] - 0s 10ms/step - loss: 831.5893
Epoch 6/200
1/1 [==============================] - 0s 9ms/step - loss: 830.6564
Epoch 7/200
1/1 [==============================] - 0s 14ms/step - loss: 829.7242
Epoch 8/200
1/1 [==============================] - 0s 4ms/step - loss: 828.7927
Epoch 9/200
1/1 [==============================] - 0s 7ms/step - loss: 827.8617
Epoch 10/200
1/1 [==============================] - 0s 8ms/step - loss: 826.9315
Epoch 11/200
1/1 [==============================] - 0s 7ms/step - loss: 826.0019
Epoch 12/200
1/1 [==============================] - 0s 7ms/step - loss: 825.0732
Epoch 13/200
...
Epoch 199/200
1/1 [==============================] - 0s 6ms/step - loss: 666.1816
Epoch 200/200
1/1 [==============================] - 0s 9ms/step - loss: 665.4104

查看线性相关属性
W , b = model_kr.layers[0].get_weights()
print('线性回归的斜率和截距: %.2f, b: %.2f' % (W, b))
线性回归的斜率和截距: 1.82, b: 0.19
yks_pred=model_kr.predict(X)
MSE=model_kr.evaluate(y,yks_pred)
print("MSE1:",MSE)
1/1 [==============================] - 0s 99ms/step - loss: 4833.7842
MSE1: 4833.7841796875

惊奇的发现与上面做的结果相差很大

 查看回归曲线
plt.scatter(X,y)
plt.plot(X,yks_pred,"b",label='200 epochs')
plt.legend()

画一下MSE曲线
plt.plot(history.epoch,history.history.get('loss'),label="loss")
plt.xlabel("epoch")
plt.ylabel("MSE")
plt.legend()

200 次远远没有达到理想结果

增加epoch次数

再多次运行fit结果如下:

再次提醒千万别把epoch调得很大。

结论

这个结果远没有原始线性回归好。
但是,有没有发现这里根本没有“深度”,只加了一层,根本没有发挥深度学习的优势。因此在实际的模型中,会加入多层(深度),进行建模。
下一篇,在此基础上,增加层和激活函数,优化此模型,进入下一篇。

这篇关于Keras入门教程 ——1.线性回归建模(快速入门)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1075487

相关文章

从入门到精通C++11 <chrono> 库特性

《从入门到精通C++11<chrono>库特性》chrono库是C++11中一个非常强大和实用的库,它为时间处理提供了丰富的功能和类型安全的接口,通过本文的介绍,我们了解了chrono库的基本概念... 目录一、引言1.1 为什么需要<chrono>库1.2<chrono>库的基本概念二、时间段(Durat

解析C++11 static_assert及与Boost库的关联从入门到精通

《解析C++11static_assert及与Boost库的关联从入门到精通》static_assert是C++中强大的编译时验证工具,它能够在编译阶段拦截不符合预期的类型或值,增强代码的健壮性,通... 目录一、背景知识:传统断言方法的局限性1.1 assert宏1.2 #error指令1.3 第三方解决

Linux如何快速检查服务器的硬件配置和性能指标

《Linux如何快速检查服务器的硬件配置和性能指标》在运维和开发工作中,我们经常需要快速检查Linux服务器的硬件配置和性能指标,本文将以CentOS为例,介绍如何通过命令行快速获取这些关键信息,... 目录引言一、查询CPU核心数编程(几C?)1. 使用 nproc(最简单)2. 使用 lscpu(详细信

从入门到精通MySQL 数据库索引(实战案例)

《从入门到精通MySQL数据库索引(实战案例)》索引是数据库的目录,提升查询速度,主要类型包括BTree、Hash、全文、空间索引,需根据场景选择,建议用于高频查询、关联字段、排序等,避免重复率高或... 目录一、索引是什么?能干嘛?核心作用:二、索引的 4 种主要类型(附通俗例子)1. BTree 索引(

Redis 配置文件使用建议redis.conf 从入门到实战

《Redis配置文件使用建议redis.conf从入门到实战》Redis配置方式包括配置文件、命令行参数、运行时CONFIG命令,支持动态修改参数及持久化,常用项涉及端口、绑定、内存策略等,版本8... 目录一、Redis.conf 是什么?二、命令行方式传参(适用于测试)三、运行时动态修改配置(不重启服务

MySQL DQL从入门到精通

《MySQLDQL从入门到精通》通过DQL,我们可以从数据库中检索出所需的数据,进行各种复杂的数据分析和处理,本文将深入探讨MySQLDQL的各个方面,帮助你全面掌握这一重要技能,感兴趣的朋友跟随小... 目录一、DQL 基础:SELECT 语句入门二、数据过滤:WHERE 子句的使用三、结果排序:ORDE

一文详解如何在idea中快速搭建一个Spring Boot项目

《一文详解如何在idea中快速搭建一个SpringBoot项目》IntelliJIDEA作为Java开发者的‌首选IDE‌,深度集成SpringBoot支持,可一键生成项目骨架、智能配置依赖,这篇文... 目录前言1、创建项目名称2、勾选需要的依赖3、在setting中检查maven4、编写数据源5、开启热

Python中OpenCV与Matplotlib的图像操作入门指南

《Python中OpenCV与Matplotlib的图像操作入门指南》:本文主要介绍Python中OpenCV与Matplotlib的图像操作指南,本文通过实例代码给大家介绍的非常详细,对大家的学... 目录一、环境准备二、图像的基本操作1. 图像读取、显示与保存 使用OpenCV操作2. 像素级操作3.

MybatisX快速生成增删改查的方法示例

《MybatisX快速生成增删改查的方法示例》MybatisX是基于IDEA的MyBatis/MyBatis-Plus开发插件,本文主要介绍了MybatisX快速生成增删改查的方法示例,文中通过示例代... 目录1 安装2 基本功能2.1 XML跳转2.2 代码生成2.2.1 生成.xml中的sql语句头2

8种快速易用的Python Matplotlib数据可视化方法汇总(附源码)

《8种快速易用的PythonMatplotlib数据可视化方法汇总(附源码)》你是否曾经面对一堆复杂的数据,却不知道如何让它们变得直观易懂?别慌,Python的Matplotlib库是你数据可视化的... 目录引言1. 折线图(Line Plot)——趋势分析2. 柱状图(Bar Chart)——对比分析3