Hadoop实现KNN算法

本人java基础较弱,有什么需要改进的欢迎大家评论

一.环境

ubuntu虚拟机,使用的是伪分布式的hadoop集群(对于做实验使用伪分布式的更方便),代码通过eclipse来提交
在这里插入图片描述

二.数据说明

使用的是著名的鸢尾花数据集。据集内包含 3 类共 150 条记录,每类各 50 个数据,每条记录都有 4 项特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度,可以通过这4个特征预测鸢尾花卉属于(iris-setosa, iris-versicolour, iris-virginica)中的哪一品种。

链接:https://pan.baidu.com/s/11ocUk8hFrT7oCQMvQBL85A
提取码:seqm
在这里插入图片描述
其中训练集有120条记录,测试集有30条记录

iris_train.csv:包括属性和标签
在这里插入图片描述


iris_test_data.csv:只有测试集的属性
在这里插入图片描述


iris_test_lable.csv:只有测试集的标签,用来检验结果的好坏
在这里插入图片描述

三.MapReduce设计

1.KNN算法的基本思想即传统KNN算法的的性能瓶颈

KNN算法是一种很简单的分类算法,不需要构建模型,直接通过训练数据对测试数据进行分类,首先要定义一种度量样本之间距离的方法,对于鸢尾花数据集我所选用的是欧氏距离,只需要计算测试数据到所有训练数据的欧式距离,然后升序排列,取前N个数据的标签通过投票的方式来决定测试数据的样本,哪个类票数多就分为哪个类。
传统KNN算法在分析大数据的时候,当训练数据或者测试数据很大的时候,由于单机内存有限和单机计算资源有限,导致传统KNN算法失效,所以我们需要对其进行并行化实现。

2.并行化KNN设计思想

KNN的并行化设计主要由三种情况,分别是训练数据量大,测试数据量大,训练数据量和测试数据量都大的情况,这里我分析的是训练数据量大测试数据量小的情况。
MapReduce的核心思想是分而治之,KNN之所以能够实现并行化是因为每个训练样本不受其他训练样本影响。因为训练数据量大所以将训练数据分布式存储读入map,因为测试数据量小所以将测试数据量作为全局文件读入,在map中每输入一个训练样本就计算它和所有测试数据的距离并传到reduce,然后reduce将同一个测试数据的距离合并然后排序计数得到标签,再输出。

3.map函数设计

4.reduce函数设计

在这里插入图片描述

四.实现步骤

1.main函数

为了代码的可扩展性,KNN算法的N值通过configuration参数传入。训练集是将其放在分布式文件系统上的,而测试集的属性和标签都是将其地址传入conf中,之后在MapReduce中读取

public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    //String[] otherArgs = new GenericOptionsParser(conf, args).getRemainingArgs();
    //考虑的是测试集少量的情况,所以将测试集和测试集的标签的文件位置传入conf,在mapreduce中读取
    conf.setStrings("test", "hdfs://localhost:9000/iris/iris_test_data.csv");
    conf.setStrings("label", "hdfs://localhost:9000/iris/iris_test_lable.csv");
    // 从命令行传入参数N
    conf.setInt("n", Integer.parseInt(args[0]));
    String[] otherArgs = new String[]{"hdfs://localhost:9000/iris/iris_train.csv","hdfs://localhost:9000/iris/output/"};
    if (otherArgs.length < 2) {
      System.err.println("Usage: wordcount <in> [<in>...] <out>");
      System.exit(2);
    }
    Job job = Job.getInstance(conf, "KNN");
    job.setJarByClass(KNN.class);
    job.setMapperClass(TokenizerMapper.class);
    job.setReducerClass(IntSumReducer.class);
    job.setOutputKeyClass(IntWritable.class);
    job.setOutputValueClass(Text.class);
    for (int i = 0; i < otherArgs.length - 1; ++i) {
      //由于是训练集大的情况,所以将训练集输入
      FileInputFormat.addInputPath(job, new Path(otherArgs[i]));
    }
    FileOutputFormat.setOutputPath(job,
      new Path(otherArgs[otherArgs.length - 1]));
    System.exit(job.waitForCompletion(true) ? 0 : 1);
  }

2.map函数实现

防止测试集在map重复读取而降低效率,所以在setup函数中读取,一共只需要执行一次,将每一个测试数据的属性作为一个大小为4的数据然后再将每个测试数据的数组放到一个数组中,代表所有的测试数据。
在map中key是行偏移量(在这里不需要使用),value是每一行的数据,需要先根据","来分割,然后将属性转换为double型放在一个数组里,标签单独存起来。计算每一个测试数据与该训练数据的欧氏距离,将测试数据的索引作为key,将该训练数据的标签以及两者的距离作为value传入reduce。

public static class TokenizerMapper 
       extends Mapper<Object, Text, IntWritable,Text >{
    // 存放测试集路径
    private String localFiles;
    // 存放测试数据
    private List test = new ArrayList();
    @Override
	public void setup(Context context) throws IOException,InterruptedException{
    	Configuration conf = context.getConfiguration();
		// 获取测试集所在的hdfs路径
		localFiles  = conf.getStrings("test")[0];
		FileSystem fs = FileSystem.get(URI.create(localFiles), conf);  
		FSDataInputStream hdfsInStream = fs.open(new Path(localFiles));  
		// 从hdfs中读取测试集
		InputStreamReader isr = new InputStreamReader(hdfsInStream, "utf-8");	
		String line;
		BufferedReader br = new BufferedReader(isr);
		while ((line = br.readLine()) != null) {
			StringTokenizer itr = new StringTokenizer(line);
			while (itr.hasMoreTokens()) {
//				System.out.println(itr.nextToken().split(",").getClass().getName().toString());
				//每一行作为一个数组
				String[] tmp = itr.nextToken().split(",");
				List data = new ArrayList();
				for (String i : tmp){
					data.add(Double.parseDouble(i));
				}
				test.add(data);
			}
		}
		// 存储了所有的测试集
		System.out.println("测试数据");
		System.out.println(test);
    }
      
    public void map(Object key, Text value, Context context
                    ) throws IOException, InterruptedException {
      StringTokenizer itr = new StringTokenizer(value.toString());
      while (itr.hasMoreTokens()) {
    	    // 将训练数据分割
    	    String[] tmp = itr.nextToken().split(",");
    	    // 记录该训练集的标签
    	    String label = tmp[4];
    	    // 记录该训练集的属性值
			List data = new ArrayList();
			for (int i = 0;i<=3;i++){
				data.add(Double.parseDouble(tmp[i]));
			}
//			System.out.println(label);
//			System.out.println(data);
			for (int i = 0;i<test.size();i++){
				// 获得每个测试数据
				List tmp2 = (List) test.get(i);
				// 每个测试数据和训练数据的距离(这里使用欧氏距离)
				double dis = 0;
				for (int j=0;j<4;j++){
					dis += Math.pow( (double)tmp2.get(j)-(double)data.get(j),2);
				}
				dis = Math.sqrt(dis);
				// out 为类标签,距离
				String out = label + "," + String.valueOf(dis);
//				System.out.println(out.toString());
				// i为测试数据的标号
//				System.out.println(i);
				context.write(new IntWritable(i), new Text(out));
			}
      }
    }
  }

3.reduce函数实现

因为java不是很熟悉,所以用的方法比较笨。
在setup中读取测试数据的标签并且读取了N值。
在reduce中将values的每个值放入list中并进行排序,重写了排序compare方法,因为每个value是"标签,距离",让compare方法根据","后面的值来排序。排序后取前N个值的标签放入一个list中对其中的标签进行统计,数量最多的作为该测试数据的预测标签。然后将测试数据的索引值作为key,预测标签和真实标签作为value写入输出文件。

public static class IntSumReducer 
       extends Reducer<IntWritable,Text,IntWritable,Text> {
	  
	    private String localFiles;
	    private List tgt = new ArrayList();
	    private int n;
	    //读取测试集的标签
	    @Override
		public void setup(Context context) throws IOException,InterruptedException{
	    	Configuration conf = context.getConfiguration();
			// 获取测试集标签所在的hdfs路径
			localFiles  = conf.getStrings("label")[0];
			// 读取n值
			n = conf.getInt("n", 3);
			FileSystem fs = FileSystem.get(URI.create(localFiles), conf);  
			FSDataInputStream hdfsInStream = fs.open(new Path(localFiles));  
			// 从hdfs中读取测试集
			InputStreamReader isr = new InputStreamReader(hdfsInStream, "utf-8");
			String line;
			BufferedReader br = new BufferedReader(isr);
			while ((line = br.readLine()) != null) {
				StringTokenizer itr = new StringTokenizer(line);
				while (itr.hasMoreTokens()) {
//					System.out.println(itr.nextToken().split(",").getClass().getName().toString());
					//每一行作为一个数组
					tgt.add(itr.nextToken());
				}
			}
			// 测试集标签
			System.out.println("测试集标签");
			System.out.println(tgt);
	    }


    public void reduce(IntWritable key, Iterable<Text> values, 
                       Context context
                       ) throws IOException, InterruptedException {
    
    List<String> sortvalue = new ArrayList<String>();
    // 将每个值放入list中方便排序
    for (Text val : values) {
//    	System.out.println("###");
//    	System.out.println(val.toString());
        sortvalue.add(val.toString());
      }
    
    // 对距离进行排序
    Collections.sort(sortvalue, new Comparator<String>() {
    	 
        @Override
        public int compare(String o1, String o2) {
            // 升序
            //return o1.getAge()-o2.getAge();
        	double x = Double.parseDouble(o1.split(",")[1]); 
        	double y = Double.parseDouble(o2.split(",")[1]); 
            return Double.compare(x, y);
            // 降序
            // return Double.compare(y, x);
        }
    });
//    System.out.println(sortvalue.toString());
    // 存放前n个数据的标签
    List<String> labels = new ArrayList<String>();
    for (int i =0;i<n;i++){
    	labels.add(sortvalue.get(i).split(",")[0]);
    }
    // 将标签转换成集合方便计数
    Set<String> set = new LinkedHashSet<>();
    set.addAll(labels);
    List<String> labelset = new ArrayList<>(set);
    int[] count = new int[labelset.size()];
    // 将计数数组全部初始化为0
    for (int i=0;i<count.length;i++){
    	count[i] = 0;
    }
    // 对每个标签计数得到count,位置对应labelset
    for(int i=0;i<labelset.size();i++){
    	for (int j=0;j<labels.size();j++){
    		if (labelset.get(i).equals(labels.get(j))){
    			count[i] += 1;
    		}
    	}
    }
    
    // 求count最大值所在的索引
    int max = 0;
    for(int i=1;i<count.length;i++){
    	if(count[i] > count[max]){
    		max = i;
    	}
    }
    
    context.write(key, new Text("预测标签:" + labelset.get(max) + "\t" + "真实标签:" + String.valueOf(tgt.get(key.get()))));
    }
  }

五.运行结果

效果还不错,三十个测试数据只有一个预测错误
在这里插入图片描述

六.代码总览

package test;


/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


import java.awt.datatransfer.StringSelection;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.StringTokenizer;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.Mapper.Context;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;
import org.apache.hadoop.yarn.webapp.example.MyApp;

import com.google.common.base.Strings;

public class KNN {

  public static class TokenizerMapper 
       extends Mapper<Object, Text, IntWritable,Text >{
    // 存放测试集路径
    private String localFiles;
    // 存放测试数据
    private List test = new ArrayList();
    @Override
	public void setup(Context context) throws IOException,InterruptedException{
    	Configuration conf = context.getConfiguration();
		// 获取测试集所在的hdfs路径
		localFiles  = conf.getStrings("test")[0];
		FileSystem fs = FileSystem.get(URI.create(localFiles), conf);  
		FSDataInputStream hdfsInStream = fs.open(new Path(localFiles));  
		// 从hdfs中读取测试集
		InputStreamReader isr = new InputStreamReader(hdfsInStream, "utf-8");	
		String line;
		BufferedReader br = new BufferedReader(isr);
		while ((line = br.readLine()) != null) {
			StringTokenizer itr = new StringTokenizer(line);
			while (itr.hasMoreTokens()) {
//				System.out.println(itr.nextToken().split(",").getClass().getName().toString());
				//每一行作为一个数组
				String[] tmp = itr.nextToken().split(",");
				List data = new ArrayList();
				for (String i : tmp){
					data.add(Double.parseDouble(i));
				}
				test.add(data);
			}
		}
		// 存储了所有的测试集
		System.out.println("测试数据");
		System.out.println(test);
    }
      
    public void map(Object key, Text value, Context context
                    ) throws IOException, InterruptedException {
      StringTokenizer itr = new StringTokenizer(value.toString());
      while (itr.hasMoreTokens()) {
    	    // 将训练数据分割
    	    String[] tmp = itr.nextToken().split(",");
    	    // 记录该训练集的标签
    	    String label = tmp[4];
    	    // 记录该训练集的属性值
			List data = new ArrayList();
			for (int i = 0;i<=3;i++){
				data.add(Double.parseDouble(tmp[i]));
			}
//			System.out.println(label);
//			System.out.println(data);
			for (int i = 0;i<test.size();i++){
				// 获得每个测试数据
				List tmp2 = (List) test.get(i);
				// 每个测试数据和训练数据的距离(这里使用欧氏距离)
				double dis = 0;
				for (int j=0;j<4;j++){
					dis += Math.pow( (double)tmp2.get(j)-(double)data.get(j),2);
				}
				dis = Math.sqrt(dis);
				// out 为类标签,距离
				String out = label + "," + String.valueOf(dis);
//				System.out.println(out.toString());
				// i为测试数据的标号
//				System.out.println(i);
				context.write(new IntWritable(i), new Text(out));
			}
      }
    }
  }
  
  public static class IntSumReducer 
       extends Reducer<IntWritable,Text,IntWritable,Text> {
	  
	    private String localFiles;
	    private List tgt = new ArrayList();
	    private int n;
	    //读取测试集的标签
	    @Override
		public void setup(Context context) throws IOException,InterruptedException{
	    	Configuration conf = context.getConfiguration();
			// 获取测试集标签所在的hdfs路径
			localFiles  = conf.getStrings("label")[0];
			// 读取n值
			n = conf.getInt("n", 3);
			FileSystem fs = FileSystem.get(URI.create(localFiles), conf);  
			FSDataInputStream hdfsInStream = fs.open(new Path(localFiles));  
			// 从hdfs中读取测试集
			InputStreamReader isr = new InputStreamReader(hdfsInStream, "utf-8");
			String line;
			BufferedReader br = new BufferedReader(isr);
			while ((line = br.readLine()) != null) {
				StringTokenizer itr = new StringTokenizer(line);
				while (itr.hasMoreTokens()) {
//					System.out.println(itr.nextToken().split(",").getClass().getName().toString());
					//每一行作为一个数组
					tgt.add(itr.nextToken());
				}
			}
			// 测试集标签
			System.out.println("测试集标签");
			System.out.println(tgt);
	    }


    public void reduce(IntWritable key, Iterable<Text> values, 
                       Context context
                       ) throws IOException, InterruptedException {
    
    List<String> sortvalue = new ArrayList<String>();
    // 将每个值放入list中方便排序
    for (Text val : values) {
//    	System.out.println("###");
//    	System.out.println(val.toString());
        sortvalue.add(val.toString());
      }
    
    // 对距离进行排序
    Collections.sort(sortvalue, new Comparator<String>() {
    	 
        @Override
        public int compare(String o1, String o2) {
            // 升序
            //return o1.getAge()-o2.getAge();
        	double x = Double.parseDouble(o1.split(",")[1]); 
        	double y = Double.parseDouble(o2.split(",")[1]); 
            return Double.compare(x, y);
            // 降序
            // return Double.compare(y, x);
        }
    });
//    System.out.println(sortvalue.toString());
    // 存放前n个数据的标签
    List<String> labels = new ArrayList<String>();
    for (int i =0;i<n;i++){
    	labels.add(sortvalue.get(i).split(",")[0]);
    }
    // 将标签转换成集合方便计数
    Set<String> set = new LinkedHashSet<>();
    set.addAll(labels);
    List<String> labelset = new ArrayList<>(set);
    int[] count = new int[labelset.size()];
    // 将计数数组全部初始化为0
    for (int i=0;i<count.length;i++){
    	count[i] = 0;
    }
    // 对每个标签计数得到count,位置对应labelset
    for(int i=0;i<labelset.size();i++){
    	for (int j=0;j<labels.size();j++){
    		if (labelset.get(i).equals(labels.get(j))){
    			count[i] += 1;
    		}
    	}
    }
    
    // 求count最大值所在的索引
    int max = 0;
    for(int i=1;i<count.length;i++){
    	if(count[i] > count[max]){
    		max = i;
    	}
    }
    
    context.write(key, new Text("预测标签:" + labelset.get(max) + "\t" + "真实标签:" + String.valueOf(tgt.get(key.get()))));
    }
  }

  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    //String[] otherArgs = new GenericOptionsParser(conf, args).getRemainingArgs();
    //考虑的是测试集少量的情况,所以将测试集和测试集的标签的文件位置传入conf,在mapreduce中读取
    conf.setStrings("test", "hdfs://localhost:9000/iris/iris_test_data.csv");
    conf.setStrings("label", "hdfs://localhost:9000/iris/iris_test_lable.csv");
    // 从命令行传入参数N
    conf.setInt("n", Integer.parseInt(args[0]));
    String[] otherArgs = new String[]{"hdfs://localhost:9000/iris/iris_train.csv","hdfs://localhost:9000/iris/output/"};
    if (otherArgs.length < 2) {
      System.err.println("Usage: wordcount <in> [<in>...] <out>");
      System.exit(2);
    }
    Job job = Job.getInstance(conf, "KNN");
    job.setJarByClass(KNN.class);
    job.setMapperClass(TokenizerMapper.class);
    job.setReducerClass(IntSumReducer.class);
    job.setOutputKeyClass(IntWritable.class);
    job.setOutputValueClass(Text.class);
    for (int i = 0; i < otherArgs.length - 1; ++i) {
      //由于是训练集大的情况,所以将训练集输入
      FileInputFormat.addInputPath(job, new Path(otherArgs[i]));
    }
    FileOutputFormat.setOutputPath(job,
      new Path(otherArgs[otherArgs.length - 1]));
    System.exit(job.waitForCompletion(true) ? 0 : 1);
  }
}

七.问题与解决

1.测试集怎样传入?

因为考虑的是测试集较小的情况,所以在这里我是先将其传入hdfs上,然后将其地址作为conf的一个参数,再在setup函数中读取测试集,确保所有的节点都能获取所有测试集。除此之外还可以使用分布式缓存的方法。

2.怎样对"标签+距离"的List数组直接进行排序?

这个问题困扰了我很久,发现可以使用Collections.sort方法并重写其compare方法即可,将其compare方法改为将字符串中距离部分提取出来进行排序返回结果,实现了对定制数据的排序。使用这个方法就可以对各种数据按自己的方法来进行排序了。

3.怎样计算List数组中每个元素的数目并找到最大值?

这个我本来想找java自带的方法但是没找到,java不像python直接调用就可以了,所以就用了比较笨的方法,先将数组去重得到所有的标签,然后再新建一个数组进行计数,两个数组的索引是对应的,对原数组遍历来进行计数。得到计数数组后还要求最大数所在的索引,就用最简单的方法,设置一个max变量记录最大值所在的索引,遍历数组,如果比当前的max大则将max改为该值的索引。(方法真的太笨了,怀念用python的日子,一直想为什么不实现一个基于python的hadoop)

八.总结与感悟

1.java用的还不熟

用惯了python再用java就很多简单的方法都需要自己去实现去重写,在实现算法的过程中感觉最困难的不是逻辑上的困难而是各种数据类型的转换,各种简单算法的实现上,用起来就很不方便,也可能是我属于java初学者,有简单方法和各种技巧还不会用吧,还是得加强java的学习。

2.先设计再动手

在实现算法的时候一定不要看到就直接动手敲代码,可能敲到最后发现逻辑上有问题再改就很麻烦了,应该先把map函数和reduce函数都设计好,输入是什么,输出是什么,什么数据放在分布式上,什么数据作为全局文件,都要先设计好然后再动手,逻辑上就会顺利很多,至少大方向上不会出现问题。

猜你喜欢

转载自blog.csdn.net/weixin_43622131/article/details/106966924