Java 实现 BP 神经网络完成 Iris 数据分类

2024-05-23 21:32

本文主要是介绍Java 实现 BP 神经网络完成 Iris 数据分类,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

继了解了 BP 神经网络的原理后,笔者之前用 Java 实现三层的 BP 神经网络完成 Iris 鸢尾花数据集的分类预测,特此记录了实现过程,附源码。

1. Iris 鸢尾花数据集

Iris 也称鸢尾花卉数据集,是一类多重变量分析的数据集,来自 UCI 机器学习库,下载地址请戳这里。通过 sepal length(花萼长度),sepal width (花萼宽度),petal length (花瓣长度),petal width (花瓣宽度)4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。

这里写图片描述

该数据集一共有150条记录,选取 Iris 数据集中的124条数据作为训练集,剩余的26条数据作为测试集。

注:选取训练集时尽量覆盖全面,不要出现只包含一类的情况。

测试集
这里写图片描述

训练集
这里写图片描述

2. BP 算法模型的建立

  • 输入层和输出层节点数量分别为数据集的属性数量和类别数量,采用一个隐层,隐层节点数=√(输入节点数+输出节点数)+5求得;
  • 激活函数选择单极性S型函数;
  • 学习率 η η =0.5;
  • 初始权值随机生成,值在-0.5~0.5之间,初始阈值设为0;
  • 设置最大训练次数为2000次;
  • 误差允许范围:Iris:0.015;
  • 动量常数 α α =0.1;
  • 输入数据归一化处理:(0.1,0.9)范围内;
  • 输出层节点处理,进行one-hot编程:
    这里写图片描述

3. Java 实现代码

一共包含三个类: BPNN.java 、DataUtil.java 、Test.java

BPNN.java

BP 神经网络核心代码以及预测处理代码,注释部分是附加动量项的处理代码:

import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;class BPNN {// private static int LAYER = 3; // 三层神经网络private static int NodeNum = 10; // 每层的最多节点数private static final int ADJUST = 5; // 隐层节点数调节常数private static final int MaxTrain = 2000; // 最大训练次数private static final double ACCU = 0.015; // 每次迭代允许的误差 iris:0.015private double ETA_W = 0.5; // 权值学习效率0.5private double ETA_T = 0.5; // 阈值学习效率private double accu;// 附加动量项//private static final double ETA_A = 0.3; // 动量常数0.1//private double[][] in_hd_last; // 上一次的权值调整量//private double[][] hd_out_last;private int in_num; // 输入层节点数private int hd_num; // 隐层节点数private int out_num; // 输入出节点数private ArrayList<ArrayList<Double>> list = new ArrayList<>(); // 输入输出数据private double[][] in_hd_weight; // BP网络in-hidden突触权值private double[][] hd_out_weight; // BP网络hidden_out突触权值private double[] in_hd_th; // BP网络in-hidden阈值private double[] hd_out_th; // BP网络hidden-out阈值private double[][] out; // 每个神经元的值经S型函数转化后的输出值,输入层就为原值private double[][] delta; // delta学习规则中的值// 获得网络三层中神经元最多的数量public int GetMaxNum() {return Math.max(Math.max(in_num, hd_num), out_num);}// 设置权值学习率public void SetEtaW() {ETA_W = 0.5;}// 设置阈值学习率public void SetEtaT() {ETA_T = 0.5;}// BPNN训练public void Train(int in_number, int out_number,ArrayList<ArrayList<Double>> arraylist) throws IOException {list = arraylist;in_num = in_number;out_num = out_number;GetNums(in_num, out_num); // 获取输入层、隐层、输出层的节点数// SetEtaW(); // 设置学习率// SetEtaT();InitNetWork(); // 初始化网络的权值和阈值int datanum = list.size(); // 训练数据的组数int createsize = GetMaxNum(); // 比较创建存储每一层输出数据的数组out = new double[3][createsize];for (int iter = 0; iter < MaxTrain; iter++) {for (int cnd = 0; cnd < datanum; cnd++) {// 第一层输入节点赋值for (int i = 0; i < in_num; i++) {out[0][i] = list.get(cnd).get(i); // 为输入层节点赋值,其输入与输出相同}Forward(); // 前向传播Backward(cnd); // 误差反向传播}System.out.println("This is the " + (iter + 1)+ " th trainning NetWork !");accu = GetAccu();System.out.println("All Samples Accuracy is " + accu);if (accu < ACCU)break;}}// 获取输入层、隐层、输出层的节点数,in_number、out_number分别为输入层节点数和输出层节点数public void GetNums(int in_number, int out_number) {in_num = in_number;out_num = out_number;hd_num = (int) Math.sqrt(in_num + out_num) + ADJUST;if (hd_num > NodeNum)hd_num = NodeNum; // 隐层节点数不能大于最大节点数}// 初始化网络的权值和阈值public void InitNetWork() {// 初始化上一次权值量,范围为-0.5-0.5之间//in_hd_last = new double[in_num][hd_num];//hd_out_last = new double[hd_num][out_num];in_hd_weight = new double[in_num][hd_num];for (int i = 0; i < in_num; i++)for (int j = 0; j < hd_num; j++) {int flag = 1; // 符号标志位(-1或者1)if ((new Random().nextInt(2)) == 1)flag = 1;elseflag = -1;in_hd_weight[i][j] = (new Random().nextDouble() / 2) * flag; // 初始化in-hidden的权值//in_hd_last[i][j] = 0;}hd_out_weight = new double[hd_num][out_num];for (int i = 0; i < hd_num; i++)for (int j = 0; j < out_num; j++) {int flag = 1; // 符号标志位(-1或者1)if ((new Random().nextInt(2)) == 1)flag = 1;elseflag = -1;hd_out_weight[i][j] = (new Random().nextDouble() / 2) * flag; // 初始化hidden-out的权值//hd_out_last[i][j] = 0;}// 阈值均初始化为0in_hd_th = new double[hd_num];for (int k = 0; k < hd_num; k++)in_hd_th[k] = 0;hd_out_th = new double[out_num];for (int k = 0; k < out_num; k++)hd_out_th[k] = 0;}// 计算单个样本的误差public double GetError(int cnd) {double ans = 0;for (int i = 0; i < out_num; i++)ans += 0.5 * (out[2][i] - list.get(cnd).get(in_num + i))* (out[2][i] - list.get(cnd).get(in_num + i));return ans;}// 计算所有样本的平均精度public double GetAccu() {double ans = 0;int num = list.size();for (int i = 0; i < num; i++) {int m = in_num;for (int j = 0; j < m; j++)out[0][j] = list.get(i).get(j);Forward();int n = out_num;for (int k = 0; k < n; k++)ans += 0.5 * (list.get(i).get(in_num + k) - out[2][k])* (list.get(i).get(in_num + k) - out[2][k]);}return ans / num;}// 前向传播public void Forward() {// 计算隐层节点的输出值for (int j = 0; j < hd_num; j++) {double v = 0;for (int i = 0; i < in_num; i++)v += in_hd_weight[i][j] * out[0][i];v += in_hd_th[j];out[1][j] = Sigmoid(v);}// 计算输出层节点的输出值for (int j = 0; j < out_num; j++) {double v = 0;for (int i = 0; i < hd_num; i++)v += hd_out_weight[i][j] * out[1][i];v += hd_out_th[j];out[2][j] = Sigmoid(v);}}// 误差反向传播public void Backward(int cnd) {CalcDelta(cnd); // 计算权值调整量UpdateNetWork(); // 更新BP神经网络的权值和阈值}// 计算delta调整量public void CalcDelta(int cnd) {int createsize = GetMaxNum(); // 比较创建数组delta = new double[3][createsize];// 计算输出层的delta值for (int i = 0; i < out_num; i++) {delta[2][i] = (list.get(cnd).get(in_num + i) - out[2][i])* SigmoidDerivative(out[2][i]);}// 计算隐层的delta值for (int i = 0; i < hd_num; i++) {double t = 0;for (int j = 0; j < out_num; j++)t += hd_out_weight[i][j] * delta[2][j];delta[1][i] = t * SigmoidDerivative(out[1][i]);}}// 更新BP神经网络的权值和阈值public void UpdateNetWork() {// 隐含层和输出层之间权值和阀值调整for (int i = 0; i < hd_num; i++) {for (int j = 0; j < out_num; j++) {hd_out_weight[i][j] += ETA_W * delta[2][j] * out[1][i]; // 未加权值动量项/* 动量项* hd_out_weight[i][j] += (ETA_A * hd_out_last[i][j] + ETA_W* delta[2][j] * out[1][i]); hd_out_last[i][j] = ETA_A ** hd_out_last[i][j] + ETA_W delta[2][j] * out[1][i];*/}}for (int i = 0; i < out_num; i++)hd_out_th[i] += ETA_T * delta[2][i];// 输入层和隐含层之间权值和阀值调整for (int i = 0; i < in_num; i++) {for (int j = 0; j < hd_num; j++) {in_hd_weight[i][j] += ETA_W * delta[1][j] * out[0][i]; // 未加权值动量项/* 动量项* in_hd_weight[i][j] += (ETA_A * in_hd_last[i][j] + ETA_W* delta[1][j] * out[0][i]); in_hd_last[i][j] = ETA_A ** in_hd_last[i][j] + ETA_W delta[1][j] * out[0][i];*/}}for (int i = 0; i < hd_num; i++)in_hd_th[i] += ETA_T * delta[1][i];}// 符号函数signpublic int Sign(double x) {if (x > 0)return 1;else if (x < 0)return -1;elsereturn 0;}// 返回最大值public double Maximum(double x, double y) {if (x >= y)return x;elsereturn y;}// 返回最小值public double Minimum(double x, double y) {if (x <= y)return x;elsereturn y;}// log-sigmoid函数public double Sigmoid(double x) {return (double) (1 / (1 + Math.exp(-x)));}// log-sigmoid函数的倒数public double SigmoidDerivative(double y) {return (double) (y * (1 - y));}// tan-sigmoid函数public double TSigmoid(double x) {return (double) ((1 - Math.exp(-x)) / (1 + Math.exp(-x)));}// tan-sigmoid函数的倒数public double TSigmoidDerivative(double y) {return (double) (1 - (y * y));}// 分类预测函数public ArrayList<ArrayList<Double>> ForeCast(ArrayList<ArrayList<Double>> arraylist) {ArrayList<ArrayList<Double>> alloutlist = new ArrayList<>();ArrayList<Double> outlist = new ArrayList<Double>();int datanum = arraylist.size();for (int cnd = 0; cnd < datanum; cnd++) {for (int i = 0; i < in_num; i++)out[0][i] = arraylist.get(cnd).get(i); // 为输入节点赋值Forward();for (int i = 0; i < out_num; i++) {if (out[2][i] > 0 && out[2][i] < 0.5)out[2][i] = 0;else if (out[2][i] > 0.5 && out[2][i] < 1) {out[2][i] = 1;}outlist.add(out[2][i]);}alloutlist.add(outlist);outlist = new ArrayList<Double>();outlist.clear();}return alloutlist;}}

DataUtil.java

数据处理类,将训练数据和测试数据进行处理。

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;class DataUtil {private ArrayList<ArrayList<Double>> alllist = new ArrayList<ArrayList<Double>>(); // 存放所有数据private ArrayList<String> outlist = new ArrayList<String>(); // 存放输出数据,索引对应每个everylist的输出private ArrayList<String> checklist = new ArrayList<String>();  //存放测试集的真实输出字符串private int in_num = 0;private int out_num = 0; // 输入输出数据的个数private int type_num = 0; // 输出的类型数量private double[][] nom_data; //归一化输入数据中的最大值和最小值private int in_data_num = 0; //提前获得输入数据的个数// 获取输出类型的个数public int GetTypeNum() {return type_num;}// 设置输出类型的个数public void SetTypeNum(int type_num) {this.type_num = type_num;}// 获取输入数据的个数public int GetInNum() {return in_num;}// 获取输出数据的个数public int GetOutNum() {return out_num;}// 获取所有数据的数组public ArrayList<ArrayList<Double>> GetList() {return alllist;}// 获取输出为字符串形式的数据public ArrayList<String> GetOutList() {return outlist;}// 获取输出为字符串形式的数据public ArrayList<String> GetCheckList() {return checklist;}//返回归一化数据所需最大最小值public double[][] GetMaxMin(){return nom_data;}// 读取文件初始化数据public void ReadFile(String filepath, String sep, int flag)throws Exception {ArrayList<Double> everylist = new ArrayList<Double>(); // 存放每一组输入输出数据int readflag = flag; // flag=0,train;flag=1,testString encoding = "GBK";File file = new File(filepath);if (file.isFile() && file.exists()) { // 判断文件是否存在InputStreamReader read = new InputStreamReader(new FileInputStream(file), encoding);// 考虑到编码格式BufferedReader bufferedReader = new BufferedReader(read);String lineTxt = null;while ((lineTxt = bufferedReader.readLine()) != null) {int in_number = 0;String splits[] = lineTxt.split(sep); // 按','截取字符串if (readflag == 0) {for (int i = 0; i < splits.length; i++)try {everylist.add(Normalize(Double.valueOf(splits[i]),nom_data[i][0],nom_data[i][1]));in_number++;} catch (Exception e) {if (!outlist.contains(splits[i]))outlist.add(splits[i]); // 存放字符串形式的输出数据for (int k = 0; k < type_num; k++) {everylist.add(0.0);}everylist.set(in_number + outlist.indexOf(splits[i]),1.0);}} else if (readflag == 1) {for (int i = 0; i < splits.length; i++)try {everylist.add(Normalize(Double.valueOf(splits[i]),nom_data[i][0],nom_data[i][1]));in_number++;} catch (Exception e) {checklist.add(splits[i]); // 存放字符串形式的输出数据}}alllist.add(everylist); // 存放所有数据in_num = in_number;out_num = type_num;everylist = new ArrayList<Double>();everylist.clear();}bufferedReader.close();}}//向文件写入分类结果public void WriteFile(String filepath, ArrayList<ArrayList<Double>> list, int in_number,  ArrayList<String> resultlist) throws IOException{File file = new File(filepath);FileWriter fw = null;BufferedWriter writer = null;try {fw = new FileWriter(file);writer = new BufferedWriter(fw);for(int i=0;i<list.size();i++){for(int j=0;j<in_number;j++)writer.write(list.get(i).get(j)+",");writer.write(resultlist.get(i));writer.newLine();}writer.flush();} catch (IOException e) {e.printStackTrace();}finally{writer.close();fw.close();}}//学习样本归一化,找到输入样本数据的最大值和最小值public void NormalizeData(String filepath) throws IOException{//提前获得输入数据的个数   GetBeforIn(filepath);int flag=1;nom_data = new double[in_data_num][2];String encoding = "GBK";File file = new File(filepath);if (file.isFile() && file.exists()) { // 判断文件是否存在InputStreamReader read = new InputStreamReader(new FileInputStream(file), encoding);// 考虑到编码格式BufferedReader bufferedReader = new BufferedReader(read);String lineTxt = null;while ((lineTxt = bufferedReader.readLine()) != null) {String splits[] = lineTxt.split(","); // 按','截取字符串for (int i = 0; i < splits.length-1; i++){if(flag==1){nom_data[i][0]=Double.valueOf(splits[i]);nom_data[i][1]=Double.valueOf(splits[i]);}else{if(Double.valueOf(splits[i])>nom_data[i][0])nom_data[i][0]=Double.valueOf(splits[i]);if(Double.valueOf(splits[i])<nom_data[i][1])nom_data[i][1]=Double.valueOf(splits[i]);}}flag=0;}bufferedReader.close();}}//归一化前获得输入数据的个数public void GetBeforIn(String filepath) throws IOException{String encoding = "GBK";File file = new File(filepath);if (file.isFile() && file.exists()) { // 判断文件是否存在InputStreamReader read = new InputStreamReader(new FileInputStream(file), encoding);// 考虑到编码格式//提前获得输入数据的个数BufferedReader beforeReader = new BufferedReader(read);String beforetext = beforeReader.readLine();String splits[] = beforetext.split(",");in_data_num = splits.length-1;beforeReader.close();}}//归一化公式public double Normalize(double x, double max, double min){double y = 0.1+0.8*(x-min)/(max-min);return y;}
}

Test.java

import java.util.ArrayList;public class Test {public static void main(String args[]) throws Exception {ArrayList<ArrayList<Double>> alllist = new ArrayList<ArrayList<Double>>(); // 存放所有数据ArrayList<String> outlist = new ArrayList<String>(); // 存放分类的字符串int in_num = 0, out_num = 0; // 输入输出数据的个数DataUtil dataUtil = new DataUtil(); // 初始化数据dataUtil.NormalizeData("E:\\BP_data\\train.txt");dataUtil.SetTypeNum(3); // 设置输出类型的数量dataUtil.ReadFile("E:\\BP_data\\train.txt", ",", 0);in_num = dataUtil.GetInNum(); // 获得输入数据的个数out_num = dataUtil.GetOutNum(); // 获得输出数据的个数(个数代表类型个数)alllist = dataUtil.GetList(); // 获得初始化后的数据outlist = dataUtil.GetOutList();System.out.print("分类的类型:");for(int i =0 ;i<outlist.size();i++)System.out.print(outlist.get(i)+"  ");System.out.println();System.out.println("训练集的数量:"+alllist.size());BPNN bpnn = new BPNN();// 训练System.out.println("Train Start!");System.out.println(".............");bpnn.Train(in_num, out_num, alllist);System.out.println("Train End!");// 测试DataUtil testUtil = new DataUtil();testUtil.NormalizeData("E:\\BP_data\\test.txt");testUtil.SetTypeNum(3); // 设置输出类型的数量testUtil.ReadFile("E:\\BP_data\\test.txt", ",", 1);ArrayList<ArrayList<Double>> testList = new ArrayList<ArrayList<Double>>();ArrayList<ArrayList<Double>> resultList = new ArrayList<ArrayList<Double>>();ArrayList<String> normallist = new ArrayList<String>(); // 存放测试集标准的输出字符串ArrayList<String> resultlist = new ArrayList<String>(); // 存放测试集计算后的输出字符串double right = 0; // 分类正确的数量int type_num = 0; // 类型的数量double all_num = 0; //测试集的数量type_num = outlist.size();testList = testUtil.GetList(); // 获取测试数据normallist = testUtil.GetCheckList(); int errorcount=0; // 分类错误的数量resultList = bpnn.ForeCast(testList); // 测试all_num=resultList.size();for (int i = 0; i < resultList.size(); i++) {String checkString = "unknow";for (int j = 0; j < type_num; j++) {if(resultList.get(i).get(j)==1.0){checkString = outlist.get(j);resultlist.add(checkString);}/*else{resultlist.add(checkString);}*/}/*if(checkString.equals("unknow"))errorcount++;*/if(checkString.equals(normallist.get(i)))right++;}testUtil.WriteFile("E:\\BP_data\\result.txt",testList,in_num,resultlist);System.out.println("测试集的数量:"+ (new Double(all_num)).intValue());System.out.println("分类正确的数量:"+(new Double(right)).intValue());System.out.println("算法的分类正确率为:"+right/all_num);System.out.println("分类结果存储在:E:\\BP_data\\result.txt");      }
}

在这里笔者只通过 Java 代码建立了 BP 神经网络的基本模型,实现 Iris 数据集的分类预测,效果如下:

这里写图片描述
….
这里写图片描述

其实,也可以用交叉预测去判断模型的分类性能。通过简单的代码可以对 BP 神经网络的数学原理有一个更好的巩固。

这篇关于Java 实现 BP 神经网络完成 Iris 数据分类的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/996465

相关文章

JVM 的类初始化机制

前言 当你在 Java 程序中new对象时,有没有考虑过 JVM 是如何把静态的字节码(byte code)转化为运行时对象的呢,这个问题看似简单,但清楚的同学相信也不会太多,这篇文章首先介绍 JVM 类初始化的机制,然后给出几个易出错的实例来分析,帮助大家更好理解这个知识点。 JVM 将字节码转化为运行时对象分为三个阶段,分别是:loading 、Linking、initialization

Spring Security 基于表达式的权限控制

前言 spring security 3.0已经可以使用spring el表达式来控制授权,允许在表达式中使用复杂的布尔逻辑来控制访问的权限。 常见的表达式 Spring Security可用表达式对象的基类是SecurityExpressionRoot。 表达式描述hasRole([role])用户拥有制定的角色时返回true (Spring security默认会带有ROLE_前缀),去

浅析Spring Security认证过程

类图 为了方便理解Spring Security认证流程,特意画了如下的类图,包含相关的核心认证类 概述 核心验证器 AuthenticationManager 该对象提供了认证方法的入口,接收一个Authentiaton对象作为参数; public interface AuthenticationManager {Authentication authenticate(Authenti

Spring Security--Architecture Overview

1 核心组件 这一节主要介绍一些在Spring Security中常见且核心的Java类,它们之间的依赖,构建起了整个框架。想要理解整个架构,最起码得对这些类眼熟。 1.1 SecurityContextHolder SecurityContextHolder用于存储安全上下文(security context)的信息。当前操作的用户是谁,该用户是否已经被认证,他拥有哪些角色权限…这些都被保

Spring Security基于数据库验证流程详解

Spring Security 校验流程图 相关解释说明(认真看哦) AbstractAuthenticationProcessingFilter 抽象类 /*** 调用 #requiresAuthentication(HttpServletRequest, HttpServletResponse) 决定是否需要进行验证操作。* 如果需要验证,则会调用 #attemptAuthentica

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

Java架构师知识体认识

源码分析 常用设计模式 Proxy代理模式Factory工厂模式Singleton单例模式Delegate委派模式Strategy策略模式Prototype原型模式Template模板模式 Spring5 beans 接口实例化代理Bean操作 Context Ioc容器设计原理及高级特性Aop设计原理Factorybean与Beanfactory Transaction 声明式事物

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

基于MySQL Binlog的Elasticsearch数据同步实践

一、为什么要做 随着马蜂窝的逐渐发展,我们的业务数据越来越多,单纯使用 MySQL 已经不能满足我们的数据查询需求,例如对于商品、订单等数据的多维度检索。 使用 Elasticsearch 存储业务数据可以很好的解决我们业务中的搜索需求。而数据进行异构存储后,随之而来的就是数据同步的问题。 二、现有方法及问题 对于数据同步,我们目前的解决方案是建立数据中间表。把需要检索的业务数据,统一放到一张M

关于数据埋点,你需要了解这些基本知识

产品汪每天都在和数据打交道,你知道数据来自哪里吗? 移动app端内的用户行为数据大多来自埋点,了解一些埋点知识,能和数据分析师、技术侃大山,参与到前期的数据采集,更重要是让最终的埋点数据能为我所用,否则可怜巴巴等上几个月是常有的事。   埋点类型 根据埋点方式,可以区分为: 手动埋点半自动埋点全自动埋点 秉承“任何事物都有两面性”的道理:自动程度高的,能解决通用统计,便于统一化管理,但个性化定