tensorflow example 入门例子(线型回归与逻辑回归)

2023-12-29 06:30

本文主要是介绍tensorflow example 入门例子(线型回归与逻辑回归),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. 前言–线性回归与逻辑回归介绍

tensorflow一般入门都至少会讲两种例子,一个是线型回归,一个是逻辑回归。(或者也可以说,回归算法 & 分类算法
线性回归用来做回归预测,逻辑回归用于做二分类,一个是解决回归问题,一个用于解决分类问题。两者区别:

拟合函数不同:

线性回归: f(x)=θTX=θ1x1+θ2x2++θnxn f ( x ) = θ T X = θ 1 x 1 + θ 2 x 2 + ⋯ + θ n x n
逻辑回归: f(x)=p(y=1x;θ)=g(θTX)g(z)=11+ez f ( x ) = p ( y = 1 ∣ x ; θ ) = g ( θ T X ) , 其 中 , g ( z ) = 1 1 + e − z 也就是第二个例子提的sigmod函数。

取值范围不同:
线性回归的样本的输出,都是连续值, y(+,)y(+,) y ∈ ( + ∞ , − ∞ ) y ∈ ( + ∞ , − ∞ ) 而,逻辑回归中 y0,1y0,1 y ∈ 0 , 1 y ∈ 0 , 1 ,只能取0和1。

在线性回归中 θTX θ T X 为预测值的拟合函数;而在逻辑回归中 θTX=0 θ T X = 0 为决策边界(<0则y < 0.5, >0则>0.5,正负无穷,则是1或0)。

2. 线型回归example

模拟样本,不去下载,免得懵逼。
样本的输入是x_vals,样本的输出是y_vals。然后,模型就是个线型函数:
y = w * x

上代码:

import tensorflow as tf
import numpy as np# 样本,输入列表,正太分布(Normal Destribution),均值为1, 均方误差为0.1, 数据量为100个
x_vals = np.random.normal(1, 0.1, 100)
# 样本输出列表, 100个值为10.0的列表
y_vals = np.repeat(10.0, 100)x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype= tf.float32)A = tf.Variable(tf.random_normal(shape=[1]))# 我们定义的模型,是一个线型函数,即 y = w * x, 也就是my_output = A * x_data
# x_data将用样本x_vals。我们的目标是,算出A的值。
# 其实已经能猜出,y都是10.0的话,x均值为1, 那么A应该是10。哈哈
my_output = tf.multiply(x_data, A)# 损失函数, 用的是模型算的值,减去实际值, 的平方。y_target就是上面的y_vals。
loss = tf.square(my_output - y_target)sess = tf.Session()
init = tf.global_variables_initializer()#初始化变量
sess.run(init)# 梯度下降算法, 学习率0.02, 可以认为每次迭代修改A,修改一次0.02。比如A初始化为20, 发现不好,于是猜测下一个A为20-0.02
my_opt = tf.train.GradientDescentOptimizer(learning_rate=0.02)
train_step = my_opt.minimize(loss)#目标,使得损失函数达到最小值for i in range(100):#0到100,不包括100# 随机从样本中取值rand_index = np.random.choice(100)rand_x = [x_vals[rand_index]]rand_y = [y_vals[rand_index]]#损失函数引用的placeholder(直接或间接用的都算), x_data使用样本rand_x, y_target用样本rand_ysess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})#打印if i%5==0:print('step: ' + str(i) + ' A = ' + str(sess.run(A)))print('loss: ' + str(sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})))

输出结果:

step: 0 A = [-0.29324722]
loss: [107.103676]
step: 5 A = [1.6392573]
loss: [69.70741]
step: 10 A = [3.1867485]
loss: [43.80286]
step: 15 A = [4.4426436]
loss: [32.31665]
step: 20 A = [5.454427]
loss: [28.393408]
step: 25 A = [6.2705126]
loss: [16.252668]
step: 30 A = [6.9050083]
loss: [6.105043]
step: 35 A = [7.4409676]
loss: [8.232471]
step: 40 A = [7.9324813]
loss: [6.8031826]
step: 45 A = [8.33953]
loss: [2.0388503]
step: 50 A = [8.59281]
loss: [0.87392443]
step: 55 A = [8.817221]
loss: [1.0634136]
step: 60 A = [9.096114]
loss: [3.1473236]
step: 65 A = [9.231487]
loss: [1.9277898]
step: 70 A = [9.415066]
loss: [0.22827132]
step: 75 A = [9.4759245]
loss: [0.3650696]
step: 80 A = [9.474044]
loss: [3.6430228]
step: 85 A = [9.543233]
loss: [0.00908985]
step: 90 A = [9.6931]
loss: [0.03487607]
step: 95 A = [9.785975]
loss: [0.1980833]

3. 逻辑回归example

也就是分类器的例子。
tensorflow很多数据源,都是从网上获取。比如经典的iris数据集。
鸢(拼音yuan)尾花的英文名就是iris。Iris数据集就是鸢尾花卉数据集,是一类多重变量分析的数据集。通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。

来个花的样子,形象一点。^^
这里写图片描述
这是数据集的详细信息。
这里写图片描述

以下是获取该数据的代码。


#sklearn是机器学习套件,有很多数据集
#安装:pip install -U scikit-learn
#      sklearn依赖python>=2.7, numpy(python擅长数组处理的数学库), scipy(python算法库和数据工具包)
from sklearn import datasets
iris = datasets.load_iris()
print('sample feature: feature_names: ' + str(iris.feature_names) + " data length: " + str(len(iris.data)))
print('sample target: target_names: ' + str(iris.target_names) + " target length: " + str(len(iris.target)))
#样本数据,一个150x4的二维列表
print(iris.data)
#样本标签,一个长度为150的一维列表
print(iris.target)

结果:

sample feature: feature_names: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'] data length: 150
sample target: feature_names: ['setosa' 'versicolor' 'virginica'] target length: 150
[[5.1 3.5 1.4 0.2][4.9 3.  1.4 0.2][4.7 3.2 1.3 0.2][4.6 3.1 1.5 0.2]#....省略[6.  3.  4.8 1.8][6.9 3.1 5.4 2.1][6.7 3.1 5.6 2.4][6.9 3.1 5.1 2.3][5.8 2.7 5.1 1.9][6.8 3.2 5.9 2.3][6.7 3.3 5.7 2.5][6.7 3.  5.2 2.3][6.3 2.5 5.  1.9][6.5 3.  5.2 2. ][6.2 3.4 5.4 2.3][5.9 3.  5.1 1.8]]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 00 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 11 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 22 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 22 2]

因我们只要实现一个简单的二值分类器来预测一朵花是否为Setosa。所以数据需要稍微做一下转换。

代码如下:

#import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
#sklearn是机器学习套件,有很多数据集
#安装:pip install -U scikit-learn
#      sklearn依赖python>=2.7, numpy(python擅长数组处理的数学库), scipy(python算法库和数据工具包)
from sklearn import datasetsiris = datasets.load_iris()
print('sample feature: feature_names: ' + str(iris.feature_names) + " data length: " + str(len(iris.data)))
print('sample target: target_names: ' + str(iris.target_names) + " target length: " + str(len(iris.target)))
#样本数据,一个150x4的二维列表
#print(iris.data)
#样本标签,一个长度为150的一维列表
#print(iris.target)#抽取的样本标签, 只要第一种,是第一种,则为1,否则为0
temp = []
for x in iris.target:temp.append(1 if x== 0 else 0)
binary_target = np.array(temp)#列表转数组,以上几行,也可以写成:binary_target = np.array([1 if x== 0 else 0 for x in iris.target])
print('binary_target: ')
print(binary_target)#抽取的样本输入,只用两个参数,也就是花瓣长度和宽度
iris_2d = np.array([[x[2], x[3]] for x in iris.data])
print('iris_2d: ')
print(iris_2d)#批量训练大小为20
batch_size = 20
x1_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
x2_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))#定义模型
my_mult = tf.matmul(x2_data, A)
my_add = tf.add(my_mult, b)
my_output = tf.subtract(x1_data, my_add)#损失函数
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=my_output, logits=y_target)
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)#初始化变量
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)#开始迭代,更新模型,也就是计算出A和b
for i in range(1000):#从np.arange(len(iris_2d))生成大小为20的均匀随机样本,如:[ 66  42  96 115  45 127  31  70 148  57  60 127  56  96   7  63  75 127 110 144]rand_index = np.random.choice(len(iris_2d), size=batch_size)print('rand_index ' + str(rand_index))#rand_x为20x2的数组,类似酱紫 [[4.5 1.5] 。。。。[1.3 0.2]]rand_x = iris_2d[rand_index]#print(' rand_x ' + str(rand_x))print(' rand_x shape: ' + str(rand_x.shape))rand_x1 = np.array([[x[0]] for x in rand_x])rand_x2 = np.array([[x[1]] for x in rand_x])#print(' rand_x1 ' + str(rand_x1))#print(' rand_x2 ' + str(rand_x2))#rand_y如果直接使用binary_target[rand_index],则得到的是[0 0 0 0 0 0 1 0 0 0 0 1 1 1 0 1 1 0 0 0],是一维的, shape是(20,)也就是一维数组,数组有20个元素#但是我们想要多维数组,也就是shape为(20, 1),因为placeholder也是这样的维度#[[y] for y in binary_target[rand_index]] 得到的是[[0], [0], [0], [0], [0], [0], [1], [0], [1], [0], [0], [0], [0], [1], [1], [1], [1], [1], [0], [0]]#然后,转化为数组,维度是(20, 1)rand_y = np.array([[y] for y in binary_target[rand_index]])print('rand_y shape ' + str(rand_y.shape))print('rand_y  ' + str(rand_y))sess.run(train_step, feed_dict={x1_data: rand_x1, x2_data: rand_x2, y_target: rand_y})if(i+1)%200==0:print('step: ' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))

总结

  1. 生成样本数据
  2. 初始化占位符(用于喂样本数据)和变量(用于迭代自变量A,也就是我们要的模型数据)
  3. 创建损失函数
  4. 定义一个优化器算法
  5. 通过随机样本数据进行迭代,更新变量。

ps: 本文例子为了简单容易理解,所以没加上预测逻辑。要想看预测逻辑,请看下文分解。

这篇关于tensorflow example 入门例子(线型回归与逻辑回归)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/548622

相关文章

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

数论入门整理(updating)

一、gcd lcm 基础中的基础,一般用来处理计算第一步什么的,分数化简之类。 LL gcd(LL a, LL b) { return b ? gcd(b, a % b) : a; } <pre name="code" class="cpp">LL lcm(LL a, LL b){LL c = gcd(a, b);return a / c * b;} 例题:

Java 创建图形用户界面(GUI)入门指南(Swing库 JFrame 类)概述

概述 基本概念 Java Swing 的架构 Java Swing 是一个为 Java 设计的 GUI 工具包,是 JAVA 基础类的一部分,基于 Java AWT 构建,提供了一系列轻量级、可定制的图形用户界面(GUI)组件。 与 AWT 相比,Swing 提供了许多比 AWT 更好的屏幕显示元素,更加灵活和可定制,具有更好的跨平台性能。 组件和容器 Java Swing 提供了许多

【IPV6从入门到起飞】5-1 IPV6+Home Assistant(搭建基本环境)

【IPV6从入门到起飞】5-1 IPV6+Home Assistant #搭建基本环境 1 背景2 docker下载 hass3 创建容器4 浏览器访问 hass5 手机APP远程访问hass6 更多玩法 1 背景 既然电脑可以IPV6入站,手机流量可以访问IPV6网络的服务,为什么不在电脑搭建Home Assistant(hass),来控制你的设备呢?@智能家居 @万物互联

poj 2104 and hdu 2665 划分树模板入门题

题意: 给一个数组n(1e5)个数,给一个范围(fr, to, k),求这个范围中第k大的数。 解析: 划分树入门。 bing神的模板。 坑爹的地方是把-l 看成了-1........ 一直re。 代码: poj 2104: #include <iostream>#include <cstdio>#include <cstdlib>#include <al

MySQL-CRUD入门1

文章目录 认识配置文件client节点mysql节点mysqld节点 数据的添加(Create)添加一行数据添加多行数据两种添加数据的效率对比 数据的查询(Retrieve)全列查询指定列查询查询中带有表达式关于字面量关于as重命名 临时表引入distinct去重order by 排序关于NULL 认识配置文件 在我们的MySQL服务安装好了之后, 会有一个配置文件, 也就

音视频入门基础:WAV专题(10)——FFmpeg源码中计算WAV音频文件每个packet的pts、dts的实现

一、引言 从文章《音视频入门基础:WAV专题(6)——通过FFprobe显示WAV音频文件每个数据包的信息》中我们可以知道,通过FFprobe命令可以打印WAV音频文件每个packet(也称为数据包或多媒体包)的信息,这些信息包含该packet的pts、dts: 打印出来的“pts”实际是AVPacket结构体中的成员变量pts,是以AVStream->time_base为单位的显

C语言指针入门 《C语言非常道》

C语言指针入门 《C语言非常道》 作为一个程序员,我接触 C 语言有十年了。有的朋友让我推荐 C 语言的参考书,我不敢乱推荐,尤其是国内作者写的书,往往七拼八凑,漏洞百出。 但是,李忠老师的《C语言非常道》值得一读。对了,李老师有个官网,网址是: 李忠老师官网 最棒的是,有配套的教学视频,可以试看。 试看点这里 接下来言归正传,讲解指针。以下内容很多都参考了李忠老师的《C语言非

MySQL入门到精通

一、创建数据库 CREATE DATABASE 数据库名称; 如果数据库存在,则会提示报错。 二、选择数据库 USE 数据库名称; 三、创建数据表 CREATE TABLE 数据表名称; 四、MySQL数据类型 MySQL支持多种类型,大致可以分为三类:数值、日期/时间和字符串类型 4.1 数值类型 数值类型 类型大小用途INT4Bytes整数值FLOAT4By

【QT】基础入门学习

文章目录 浅析Qt应用程序的主函数使用qDebug()函数常用快捷键Qt 编码风格信号槽连接模型实现方案 信号和槽的工作机制Qt对象树机制 浅析Qt应用程序的主函数 #include "mywindow.h"#include <QApplication>// 程序的入口int main(int argc, char *argv[]){// argc是命令行参数个数,argv是