使用MapReduce实现k-means算法

2024-06-20 18:18

本文主要是介绍使用MapReduce实现k-means算法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

主要的算法流程就是:

(1)随机选择k个点,放到磁盘上供各个节点进行共享

(2)每一个map读取中心点,每一条记录找到最近的Cluster,发出的记录是<(id),(cluster)>,Reduce的功能就是重新计算新的k均值,并写到hdfs中,供下一次的迭代使用

(3)当迭代停止,根据最终的中心点,分配所有的点,形成最终的聚类。

以下是具体的代码:

package kmeans;


import java.io.DataInput;


/*
 * k-means聚类算法簇信息
 */
/*
 * Writable describing one k-means cluster: a numeric id, the number of
 * points currently assigned to it, and its center vector.
 * Text form (used between iterations on HDFS) is
 * "clusterID,numOfPoints,v1,v2,...".
 */
public class Cluster implements Writable {
    private int clusterID;
    private long numOfPoints;
    private Instance center;

    /** Placeholder cluster: id -1, zero points, empty center. */
    public Cluster() {
        this.setClusterID(-1);
        this.setNumOfPoints(0);
        this.setCenter(new Instance());
    }

    /** Cluster with a known id and center, but no observed points yet. */
    public Cluster(int clusterID, Instance center) {
        this.setClusterID(clusterID);
        this.setNumOfPoints(0);
        this.setCenter(center);
    }

    /**
     * Parses the textual form produced by {@link #toString()}:
     * "id,count,values...". Only the first two commas delimit fields;
     * everything after the second comma is the center's value list.
     */
    public Cluster(String line) {
        String[] parts = line.split(",", 3);
        clusterID = Integer.parseInt(parts[0]);
        numOfPoints = Long.parseLong(parts[1]);
        center = new Instance(parts[2]);
    }

    /** Inverse of {@link #Cluster(String)}. */
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(clusterID).append(',').append(numOfPoints).append(',')
                .append(center.toString());
        return sb.toString();
    }

    public int getClusterID() {
        return clusterID;
    }

    public void setClusterID(int clusterID) {
        this.clusterID = clusterID;
    }

    public long getNumOfPoints() {
        return numOfPoints;
    }

    public void setNumOfPoints(long numOfPoints) {
        this.numOfPoints = numOfPoints;
    }

    public Instance getCenter() {
        return center;
    }

    public void setCenter(Instance center) {
        this.center = center;
    }

    /**
     * Folds one more instance into the running mean that this cluster's
     * center represents: center = (center * n + instance) / (n + 1).
     * Any failure (e.g. a dimension mismatch raised by Instance.add)
     * is printed and otherwise ignored — best-effort by design.
     */
    public void observeInstance(Instance instance) {
        try {
            Instance total = center.multiply(numOfPoints).add(instance);
            numOfPoints++;
            center = total.divide(numOfPoints);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    @Override
    public void write(DataOutput out) throws IOException {
        out.writeInt(clusterID);
        out.writeLong(numOfPoints);
        center.write(out);
    }

    @Override
    public void readFields(DataInput in) throws IOException {
        clusterID = in.readInt();
        numOfPoints = in.readLong();
        center.readFields(in);
    }
}


package kmeans;


import java.io.DataInput;


/*
 * Writable wrapper around a dense vector of doubles; represents one
 * data point (or one cluster center) in the k-means computation.
 */
public class Instance implements Writable {
    ArrayList<Double> value;

    /** Empty (zero-dimensional) instance. */
    public Instance() {
        value = new ArrayList<Double>();
    }

    /** Parses a comma-separated list of doubles, e.g. "1.0,2.5,3". */
    public Instance(String line) {
        String[] tokens = line.split(",");
        value = new ArrayList<Double>(tokens.length);
        for (int i = 0; i < tokens.length; i++) {
            value.add(Double.parseDouble(tokens[i]));
        }
    }

    /** Deep copy constructor. */
    public Instance(Instance ins) {
        value = new ArrayList<Double>(ins.getValue().size());
        for (int i = 0; i < ins.getValue().size(); i++) {
            // Double is immutable, so sharing the boxed values is safe;
            // this avoids the deprecated new Double(...) constructor.
            value.add(ins.getValue().get(i));
        }
    }

    /** Zero vector of dimension k. */
    public Instance(int k) {
        value = new ArrayList<Double>(k);
        for (int i = 0; i < k; i++) {
            value.add(0.0);
        }
    }

    public ArrayList<Double> getValue() {
        return value;
    }

    /**
     * Element-wise sum, returned as a new Instance. An empty operand
     * acts as the identity: a copy of the other operand is returned.
     *
     * @throws IllegalArgumentException if both operands are non-empty
     *         and their dimensions differ. (Previously the mismatch was
     *         printed and null returned, which hid the error until a
     *         later NullPointerException in the caller.)
     */
    public Instance add(Instance instance) {
        if (value.size() == 0) {
            return new Instance(instance);
        }
        if (instance.getValue().size() == 0) {
            return new Instance(this);
        }
        if (value.size() != instance.getValue().size()) {
            throw new IllegalArgumentException(
                    "can not add! dimension not compatible!"
                            + value.size() + "," + instance.getValue().size());
        }
        Instance result = new Instance();
        for (int i = 0; i < value.size(); i++) {
            result.getValue().add(value.get(i) + instance.getValue().get(i));
        }
        return result;
    }

    /** Element-wise scalar multiplication, returned as a new Instance. */
    public Instance multiply(double num) {
        Instance result = new Instance();
        for (int i = 0; i < value.size(); i++) {
            result.getValue().add(value.get(i) * num);
        }
        return result;
    }

    /** Element-wise scalar division, returned as a new Instance. */
    public Instance divide(double num) {
        Instance result = new Instance();
        for (int i = 0; i < value.size(); i++) {
            result.getValue().add(value.get(i) / num);
        }
        return result;
    }

    /**
     * Comma-separated representation, e.g. "1.0,2.0". Returns "" for an
     * empty instance. (The original indexed value.get(size - 1)
     * unconditionally and threw IndexOutOfBoundsException when empty.)
     */
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < value.size(); i++) {
            if (i > 0) {
                sb.append(',');
            }
            sb.append(value.get(i));
        }
        return sb.toString();
    }

    @Override
    public void write(DataOutput out) throws IOException {
        // Length-prefixed encoding so readFields can restore any dimension.
        out.writeInt(value.size());
        for (int i = 0; i < value.size(); i++) {
            out.writeDouble(value.get(i));
        }
    }

    @Override
    public void readFields(DataInput in) throws IOException {
        int size = in.readInt();
        value = new ArrayList<Double>(size);
        for (int i = 0; i < size; i++) {
            value.add(in.readDouble());
        }
    }
}


package kmeans;


import java.io.BufferedReader;


/**
 * KMeans聚类算法
 * 
 */
/**
 * One iteration of the k-means center-update job.
 * Mapper: assigns each input instance to its nearest current center and
 * emits (clusterID, single-point Cluster). Combiner/Reducer: merge the
 * partial weighted sums back into a new center per cluster id.
 */
public class KMeans {
    public static class KMeansMapper extends
            Mapper<LongWritable, Text, IntWritable, Cluster> {
        private ArrayList<Cluster> kClusters = new ArrayList<Cluster>();

        /**
         * Loads the current cluster centers from the HDFS directory named
         * by the "clusterPath" configuration entry.
         *
         * Fixed: the original closed only the LAST opened reader, after the
         * loop, and threw NullPointerException when the directory contained
         * no plain files. Each stream is now closed in a try/finally.
         */
        @Override
        protected void setup(Context context) throws IOException,
                InterruptedException {
            super.setup(context);
            FileSystem fs = FileSystem.get(context.getConfiguration());
            FileStatus[] fileList = fs.listStatus(new Path(context
                    .getConfiguration().get("clusterPath")));
            for (int i = 0; i < fileList.length; i++) {
                if (!fileList[i].isDir()) {
                    FSDataInputStream fsi = fs.open(fileList[i].getPath());
                    BufferedReader in = new BufferedReader(
                            new InputStreamReader(fsi, "UTF-8"));
                    try {
                        String line;
                        while ((line = in.readLine()) != null) {
                            System.out.println("read a line:" + line);
                            Cluster cluster = new Cluster(line);
                            // Counts are recomputed this round; start at 0.
                            cluster.setNumOfPoints(0);
                            kClusters.add(cluster);
                        }
                    } finally {
                        in.close(); // also closes the underlying fsi
                    }
                }
            }
        }

        /**
         * Reads one instance and emits it as a one-point Cluster keyed by
         * the id of the nearest current center.
         */
        @Override
        public void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            Instance instance = new Instance(value.toString());
            int id;
            try {
                id = getNearest(instance);
                if (id == -1)
                    throw new InterruptedException("id == -1");
                else {
                    Cluster cluster = new Cluster(id, instance);
                    cluster.setNumOfPoints(1);
                    System.out.println("cluster that i emit is:"
                            + cluster.toString());
                    context.write(new IntWritable(id), cluster);
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }

        /**
         * Returns the id of the cluster whose center is closest to
         * instance (Euclidean distance), or -1 if no centers are loaded.
         */
        public int getNearest(Instance instance) throws Exception {
            int id = -1;
            double distance = Double.MAX_VALUE;
            Distance<Double> distanceMeasure = new EuclideanDistance<Double>();
            double newDis = 0.0;
            for (Cluster cluster : kClusters) {
                newDis = distanceMeasure.getDistance(cluster.getCenter()
                        .getValue(), instance.getValue());
                if (newDis < distance) {
                    id = cluster.getClusterID();
                    distance = newDis;
                }
            }
            return id;
        }

        /** Linear lookup of a loaded cluster by id; null if absent. */
        public Cluster getClusterByID(int id) {
            for (Cluster cluster : kClusters) {
                if (cluster.getClusterID() == id)
                    return cluster;
            }
            return null;
        }
    }

    /**
     * Merges partial clusters map-side. Emits a cluster whose center is
     * the weighted mean of the merged partial centers and whose count is
     * the total number of points.
     */
    public static class KMeansCombiner extends
            Reducer<IntWritable, Cluster, IntWritable, Cluster> {
        public void reduce(IntWritable key, Iterable<Cluster> value,
                Context context) throws IOException, InterruptedException {
            Instance instance = new Instance();
            // long, not int: Cluster.getNumOfPoints() is a long and the
            // compound += previously truncated it silently.
            long numOfPoints = 0;
            for (Cluster cluster : value) {
                numOfPoints += cluster.getNumOfPoints();
                System.out.println("cluster is:" + cluster.toString());
                // Undo the stored mean back to a sum before accumulating.
                instance = instance.add(cluster.getCenter().multiply(
                        cluster.getNumOfPoints()));
            }
            Cluster cluster = new Cluster(key.get(), instance
                    .divide(numOfPoints));
            cluster.setNumOfPoints(numOfPoints);
            System.out.println("combiner emit cluster:" + cluster.toString());
            context.write(key, cluster);
        }
    }

    /**
     * Final merge per cluster id; writes the recomputed center (no key)
     * so the next iteration's mappers can parse it back with
     * Cluster(String).
     */
    public static class KMeansReducer extends
            Reducer<IntWritable, Cluster, NullWritable, Cluster> {
        public void reduce(IntWritable key, Iterable<Cluster> value,
                Context context) throws IOException, InterruptedException {
            Instance instance = new Instance();
            long numOfPoints = 0;
            for (Cluster cluster : value) {
                numOfPoints += cluster.getNumOfPoints();
                instance = instance.add(cluster.getCenter().multiply(
                        cluster.getNumOfPoints()));
            }
            Cluster cluster = new Cluster(key.get(), instance
                    .divide(numOfPoints));
            cluster.setNumOfPoints(numOfPoints);
            context.write(NullWritable.get(), cluster);
        }
    }
}


package kmeans;


import java.io.BufferedReader;


/**
 * 在收敛条件满足且所有簇中心的文件最后产生后,再对输入文件 中的所有实例进行划分簇的工作,最后把所有实例按照(实例,簇id) 的方式写进结果文件
 * 
 * @author KING
 * 
 */
/**
 * Final labelling pass: after the center files have converged, assigns
 * every input instance to its nearest center and writes
 * (instance, clusterID) pairs. Map-only job (no reducer).
 */
public class KMeansCluster {
    public static class KMeansClusterMapper extends
            Mapper<LongWritable, Text, Text, IntWritable> {
        private ArrayList<Cluster> kClusters = new ArrayList<Cluster>();

        /**
         * Loads the final cluster centers from the "clusterPath"
         * configuration directory.
         *
         * Fixed: the original closed only the LAST opened reader, after
         * the loop, and threw NullPointerException when the directory
         * contained no plain files. Each stream is now closed in a
         * try/finally.
         */
        @Override
        protected void setup(Context context) throws IOException,
                InterruptedException {
            super.setup(context);
            FileSystem fs = FileSystem.get(context.getConfiguration());
            FileStatus[] fileList = fs.listStatus(new Path(context
                    .getConfiguration().get("clusterPath")));
            for (int i = 0; i < fileList.length; i++) {
                if (!fileList[i].isDir()) {
                    FSDataInputStream fsi = fs.open(fileList[i].getPath());
                    BufferedReader in = new BufferedReader(
                            new InputStreamReader(fsi, "UTF-8"));
                    try {
                        String line;
                        while ((line = in.readLine()) != null) {
                            System.out.println("read a line:" + line);
                            Cluster cluster = new Cluster(line);
                            cluster.setNumOfPoints(0);
                            kClusters.add(cluster);
                        }
                    } finally {
                        in.close(); // also closes the underlying fsi
                    }
                }
            }
        }

        /**
         * Reads one instance line and emits (instance text, nearest
         * cluster id).
         */
        @Override
        public void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            Instance instance = new Instance(value.toString());
            int id;
            try {
                id = getNearest(instance);
                if (id == -1)
                    throw new InterruptedException("id == -1");
                else {
                    context.write(value, new IntWritable(id));
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }

        /**
         * Returns the id of the cluster whose center is closest to
         * instance (Euclidean distance), or -1 if no centers are loaded.
         */
        public int getNearest(Instance instance) throws Exception {
            int id = -1;
            double distance = Double.MAX_VALUE;
            Distance<Double> distanceMeasure = new EuclideanDistance<Double>();
            double newDis = 0.0;
            for (Cluster cluster : kClusters) {
                newDis = distanceMeasure.getDistance(cluster.getCenter()
                        .getValue(), instance.getValue());
                if (newDis < distance) {
                    id = cluster.getClusterID();
                    distance = newDis;
                }
            }
            return id;
        }
    }
}


package kmeans;


import java.io.IOException;


/**
 * 调度整个KMeans运行的过程
 * 
 */
/**
 * Drives the full k-means pipeline: generate k random initial centers,
 * run iterationNum center-update jobs, then run the final labelling job.
 */
public class KMeansDriver {
    private int k;
    private int iterationNum;
    private String sourcePath;
    private String outputPath;
    private Configuration conf;

    public KMeansDriver(int k, int iterationNum, String sourcePath,
            String outputPath, Configuration conf) {
        this.k = k;
        this.iterationNum = iterationNum;
        this.sourcePath = sourcePath;
        this.outputPath = outputPath;
        this.conf = conf;
    }

    /**
     * Runs iterationNum rounds of the center-update job. Round i reads
     * centers from outputPath/cluster-i/ and writes recomputed centers
     * to outputPath/cluster-(i+1)/.
     *
     * Fixed: jobs were built with "new Job()" and therefore a fresh
     * Configuration, silently ignoring the conf injected through the
     * constructor; they now inherit it. Also stops when a round fails,
     * since later rounds would read a missing cluster directory.
     */
    public void clusterCenterJob() throws IOException, InterruptedException,
            ClassNotFoundException {
        for (int i = 0; i < iterationNum; i++) {
            Job clusterCenterJob = new Job(conf);
            clusterCenterJob.setJobName("clusterCenterJob" + i);
            clusterCenterJob.setJarByClass(KMeans.class);

            // Tell the mappers where this round's centers live.
            clusterCenterJob.getConfiguration().set("clusterPath",
                    outputPath + "/cluster-" + i + "/");

            clusterCenterJob.setMapperClass(KMeans.KMeansMapper.class);
            clusterCenterJob.setMapOutputKeyClass(IntWritable.class);
            clusterCenterJob.setMapOutputValueClass(Cluster.class);

            clusterCenterJob.setCombinerClass(KMeans.KMeansCombiner.class);
            clusterCenterJob.setReducerClass(KMeans.KMeansReducer.class);
            clusterCenterJob.setOutputKeyClass(NullWritable.class);
            clusterCenterJob.setOutputValueClass(Cluster.class);

            FileInputFormat
                    .addInputPath(clusterCenterJob, new Path(sourcePath));
            FileOutputFormat.setOutputPath(clusterCenterJob, new Path(
                    outputPath + "/cluster-" + (i + 1) + "/"));

            if (!clusterCenterJob.waitForCompletion(true)) {
                System.out.println("clusterCenterJob" + i + " failed");
                return;
            }
            System.out.println("finished!");
        }
    }

    /**
     * Runs the final map-only labelling job against the centers of the
     * last completed iteration.
     *
     * NOTE(review): method name keeps the original "Jod" typo because it
     * is part of the public interface; callers may depend on it.
     */
    public void KMeansClusterJod() throws IOException, InterruptedException,
            ClassNotFoundException {
        Job kMeansClusterJob = new Job(conf);
        kMeansClusterJob.setJobName("KMeansClusterJob");
        kMeansClusterJob.setJarByClass(KMeansCluster.class);

        kMeansClusterJob.getConfiguration().set("clusterPath",
                outputPath + "/cluster-" + (iterationNum - 1) + "/");

        kMeansClusterJob
                .setMapperClass(KMeansCluster.KMeansClusterMapper.class);
        kMeansClusterJob.setMapOutputKeyClass(Text.class);
        kMeansClusterJob.setMapOutputValueClass(IntWritable.class);

        // Map-only: mapper output goes straight to the output files.
        kMeansClusterJob.setNumReduceTasks(0);

        FileInputFormat.addInputPath(kMeansClusterJob, new Path(sourcePath));
        FileOutputFormat.setOutputPath(kMeansClusterJob, new Path(outputPath
                + "/clusteredInstances" + "/"));

        kMeansClusterJob.waitForCompletion(true);
        System.out.println("finished!");
    }

    /** Writes k randomly chosen instances to outputPath/cluster-0. */
    public void generateInitialCluster() {
        RandomClusterGenerator generator = new RandomClusterGenerator(conf,
                sourcePath, k);
        generator.generateInitialCluster(outputPath + "/");
    }

    public static void main(String[] args) throws IOException,
            InterruptedException, ClassNotFoundException {
        // Fail fast with a usage message instead of an
        // ArrayIndexOutOfBoundsException on missing arguments.
        if (args.length < 4) {
            System.err.println(
                    "Usage: KMeansDriver <k> <iterationNum> <sourcePath> <outputPath>");
            System.exit(1);
        }
        System.out.println("start");
        Configuration conf = new Configuration();
        int k = Integer.parseInt(args[0]);
        int iterationNum = Integer.parseInt(args[1]);
        String sourcePath = args[2];
        String outputPath = args[3];
        KMeansDriver driver = new KMeansDriver(k, iterationNum, sourcePath,
                outputPath, conf);
        driver.generateInitialCluster();
        System.out.println("initial cluster finished");
        driver.clusterCenterJob();
        driver.KMeansClusterJod();
    }
}


package kmeans;


import java.io.IOException;


/**
 * This class generates the initial Cluster centers as the input of successive
 * process. it randomly chooses k instances as the initial k centers and store
 * it as a sequenceFile.Specificly,we scan all the instances and each time when
 * we scan a new instance.we first check if the number of clusters no less than
 * k. we simply add current instance to our cluster if the condition is
 * satisfied or we will replace the first cluster with it with probability
 * 1/(currentNumber + 1).
 * 
 */
/**
 * Generates the initial cluster centers. Scans every instance once;
 * the first k instances become centers, and each later instance
 * replaces a randomly chosen existing center with probability
 * 1/(k + 1). The chosen centers are written as text to
 * destinationPath/cluster-0.
 *
 * NOTE(review): this is NOT an unbiased reservoir sample (a true
 * k-reservoir would replace the i-th element with probability k/i);
 * the original algorithm is preserved here — confirm whether exact
 * uniformity matters before changing it.
 */
public final class RandomClusterGenerator {
    private int k;
    private FileStatus[] fileList;
    private FileSystem fs;
    private ArrayList<Cluster> kClusters;
    private Configuration conf;
    // Single reusable RNG; the original built two new Random objects on
    // every randomChoose call.
    private final Random random = new Random();

    public RandomClusterGenerator(Configuration conf, String filePath, int k) {
        this.k = k;
        try {
            fs = FileSystem.get(URI.create(filePath), conf);
            fileList = fs.listStatus((new Path(filePath)));
            kClusters = new ArrayList<Cluster>(k);
            this.conf = conf;
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Scans all input files, samples k centers, and writes them out.
     *
     * Fixed: the original closed only the LAST opened stream in a
     * finally block, leaking one stream per extra file and throwing
     * NullPointerException when fileList was empty. Each stream is now
     * closed in its own try/finally.
     *
     * @param destinationPath directory the cluster file is stored in;
     *        the initial file is named cluster-0.
     */
    public void generateInitialCluster(String destinationPath) {
        Text line = new Text();
        try {
            for (int i = 0; i < fileList.length; i++) {
                FSDataInputStream fsi = fs.open(fileList[i].getPath());
                try {
                    LineReader lineReader = new LineReader(fsi, conf);
                    while (lineReader.readLine(line) > 0) {
                        // Decide whether this instance joins the center set.
                        System.out.println("read a line:" + line);
                        Instance instance = new Instance(line.toString());
                        makeDecision(instance);
                    }
                } finally {
                    fsi.close();
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }

        writeBackToFile(destinationPath);
    }

    /**
     * Accepts the instance as a new center while fewer than k exist;
     * otherwise replaces a random center with probability 1/(k + 1),
     * keeping the replaced center's id.
     */
    public void makeDecision(Instance instance) {
        if (kClusters.size() < k) {
            Cluster cluster = new Cluster(kClusters.size() + 1, instance);
            kClusters.add(cluster);
        } else {
            int choice = randomChoose(k);
            if (!(choice == -1)) {
                int id = kClusters.get(choice).getClusterID();
                kClusters.remove(choice);
                Cluster cluster = new Cluster(id, instance);
                kClusters.add(cluster);
            }
        }
    }

    /**
     * With probability 1/(k + 1) returns an integer uniformly drawn
     * from [0, k - 1]; otherwise (probability k/(k + 1)) returns -1.
     */
    public int randomChoose(int k) {
        if (random.nextInt(k + 1) == 0) {
            return random.nextInt(k);
        }
        return -1;
    }

    /**
     * Writes the sampled centers, one Cluster.toString() per line, to
     * destinationPath + "cluster-0".
     *
     * Fixed: the finally block no longer NPEs when fs.create failed,
     * and the bytes are encoded as UTF-8 to match the UTF-8 readers
     * used by the mappers (instead of the platform default charset).
     */
    public void writeBackToFile(String destinationPath) {
        Path path = new Path(destinationPath + "cluster-0");
        FSDataOutputStream out = null;
        try {
            out = fs.create(path);
            for (Cluster cluster : kClusters) {
                out.write((cluster.toString() + "\n").getBytes("UTF-8"));
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (out != null) {
                try {
                    out.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }
}


数据:

2,1,3,4,1,4
3,2,5,2,3,5
4,4,4,3,1,5
2,3,1,2,0,3
4,0,1,1,1,5
1,2,3,5,0,1
5,3,2,2,1,3
3,4,1,1,2,1
0,2,3,3,1,4
0,2,5,0,2,2
2,1,4,5,4,3
4,1,4,3,3,2
0,3,2,2,0,1
1,3,1,0,3,0
3,3,4,2,1,3
3,5,3,5,3,2
2,3,2,3,0,1
4,3,3,2,2,3
1,4,3,4,3,1
3,2,3,0,2,5
1,0,2,1,0,4
4,4,3,5,5,4
5,1,4,3,5,2
3,4,4,4,1,1
2,2,4,4,5,5
5,2,0,3,1,3
1,1,3,1,1,3
2,4,2,0,3,5
1,1,1,1,0,4
1,1,4,1,3,0

这篇关于使用MapReduce实现k-means算法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1078888

相关文章

Python使用getopt处理命令行参数示例解析(最佳实践)

《Python使用getopt处理命令行参数示例解析(最佳实践)》getopt模块是Python标准库中一个简单但强大的命令行参数处理工具,它特别适合那些需要快速实现基本命令行参数解析的场景,或者需要... 目录为什么需要处理命令行参数?getopt模块基础实际应用示例与其他参数处理方式的比较常见问http

python实现svg图片转换为png和gif

《python实现svg图片转换为png和gif》这篇文章主要为大家详细介绍了python如何实现将svg图片格式转换为png和gif,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录python实现svg图片转换为png和gifpython实现图片格式之间的相互转换延展:基于Py

Python利用ElementTree实现快速解析XML文件

《Python利用ElementTree实现快速解析XML文件》ElementTree是Python标准库的一部分,而且是Python标准库中用于解析和操作XML数据的模块,下面小编就来和大家详细讲讲... 目录一、XML文件解析到底有多重要二、ElementTree快速入门1. 加载XML的两种方式2.

C 语言中enum枚举的定义和使用小结

《C语言中enum枚举的定义和使用小结》在C语言里,enum(枚举)是一种用户自定义的数据类型,它能够让你创建一组具名的整数常量,下面我会从定义、使用、特性等方面详细介绍enum,感兴趣的朋友一起看... 目录1、引言2、基本定义3、定义枚举变量4、自定义枚举常量的值5、枚举与switch语句结合使用6、枚

Java的栈与队列实现代码解析

《Java的栈与队列实现代码解析》栈是常见的线性数据结构,栈的特点是以先进后出的形式,后进先出,先进后出,分为栈底和栈顶,栈应用于内存的分配,表达式求值,存储临时的数据和方法的调用等,本文给大家介绍J... 目录栈的概念(Stack)栈的实现代码队列(Queue)模拟实现队列(双链表实现)循环队列(循环数组

使用Python从PPT文档中提取图片和图片信息(如坐标、宽度和高度等)

《使用Python从PPT文档中提取图片和图片信息(如坐标、宽度和高度等)》PPT是一种高效的信息展示工具,广泛应用于教育、商务和设计等多个领域,PPT文档中常常包含丰富的图片内容,这些图片不仅提升了... 目录一、引言二、环境与工具三、python 提取PPT背景图片3.1 提取幻灯片背景图片3.2 提取

C++如何通过Qt反射机制实现数据类序列化

《C++如何通过Qt反射机制实现数据类序列化》在C++工程中经常需要使用数据类,并对数据类进行存储、打印、调试等操作,所以本文就来聊聊C++如何通过Qt反射机制实现数据类序列化吧... 目录设计预期设计思路代码实现使用方法在 C++ 工程中经常需要使用数据类,并对数据类进行存储、打印、调试等操作。由于数据类

Python实现图片分割的多种方法总结

《Python实现图片分割的多种方法总结》图片分割是图像处理中的一个重要任务,它的目标是将图像划分为多个区域或者对象,本文为大家整理了一些常用的分割方法,大家可以根据需求自行选择... 目录1. 基于传统图像处理的分割方法(1) 使用固定阈值分割图片(2) 自适应阈值分割(3) 使用图像边缘检测分割(4)

Android实现在线预览office文档的示例详解

《Android实现在线预览office文档的示例详解》在移动端展示在线Office文档(如Word、Excel、PPT)是一项常见需求,这篇文章为大家重点介绍了两种方案的实现方法,希望对大家有一定的... 目录一、项目概述二、相关技术知识三、实现思路3.1 方案一:WebView + Office Onl

C# foreach 循环中获取索引的实现方式

《C#foreach循环中获取索引的实现方式》:本文主要介绍C#foreach循环中获取索引的实现方式,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录一、手动维护索引变量二、LINQ Select + 元组解构三、扩展方法封装索引四、使用 for 循环替代