R:查询数据集
S:训练数据集
需要对R中每一个元素计算与S中每一个元素的距离
可以首先计算得到两个数据集的笛卡尔积
用map找到distance
groupBykey同一个r的distance 找出其中的KNN
public class KNN{
public static void main(String[] args) throws Exception{
//处理输入参数
//创建spark上下文对象
JavaSparkContext ctx = createJavaSparkContext("knn");
//广播共享对象 将k和d作为共享对象,在所有节点均能访问
final Broadcast<Integr> broadcastK= ctx.broadcast(k);
final Broadcast<Integer> broadcastD=ctx.broadcast(d);.
//建立RDD
JavaRDD<String> R= ctx.textfile("",1);
JavaRDD<String> S= ctx.textfile("",1);.
//计算笛卡尔积
JavaPairRDD<String,String> cart = R.cartesian(s);
//计算distance 输出是(r,(distance,classification))
JavaPairRDD<String,Tuple2<Integer,String>> knnmapped =cart.mapToPair(
new PairFunction<Tuple2<String,String>,String,Tupple<Integer,String>>(){
public Tuple2<String,Tuple2<Integer,String>>call(Tuple2<String,String> cartrecord){
String rrecord=cartrecord._1;
String srecord=cartrecord._2;
String[] rtokens=rrecord.split(";");
String rrecordid=tokens[0];
String r=tokens[1];
String[] stokens=srecord.split(";");
String sclassifiationid=stokens[1];
String s=stokens[2];
Integer d= broadcastd.value();
double distance =calculatedistance(r,s,d);
String k =rrecordid;
Tuple2<Double,String> v=new Tuple2<Double,String>(distance,sclassification);
return new Tuple2<String,Tuple2<Double,String>>(k,v);
}
})
//按r对距离分组
JavaPairRDD<String,Iterable<Tuple2<Double,String>>> knngrouped= knnmapped.groupByKey();
//找出knn对r分类
JavaPairRDD<String,String> knnoutput= knngrouped.mapvalues(
new Function<Iterable<Tuple2<Double,String>>, String>(){
public String call(terable<Tuple2<Double,String>> neighbors){
SortedMap<Double,String> nearestK= findNearest(neighbors,k);
Map<String,Integer> majority=buildClassificationCount(nearestK);
String selectedClassification= classifByMajority(majority);
return selectedClassfication;
}
})
}
//str 输入字符串 delimiter 分隔符
static List<Double > splitOnToListOfDouble(String str,String delimiter){
Splitter splitter=Splitter.on(delimiter).trimResults();
Iterable<String> tokens=splitter.split(str);
List<Double> list =new ArrayList<Double>();
for(String token: tokens){
double data = Double.parseDouble(token);
list.add(data);
}
return list;
}
static double calculateDistance(String rasstring,String,sasstring,int d){
List<Double> r=splitOnToListOfDouble(rasstring);
List<Double> s=splitOnToListOfDoble(sasstring);
double sum =0.0;
for (int i=0; i<d;i++){
double difference = r.get(i)-s.get(i);
sum +=difference *diffrerence;
}
return Math.sqrt(sum);
}
//给定{(distance,classification)} findNearestK()根据距离找出k个近邻
static SortedMap<Double, String> findNearestK(Iterable<Tuple2<Double,String>> neighbors, int k){
//只保留k个近邻
//Map的单元是对键值对的处理,之前分析过的两种Map,HashMap和LinkedHashMap都是用哈希值去寻找我们想要的键值对,优点是由O(1)的查找速度。
那如果我们在一个对查找性能要求不那么高,反而对有序性要求比较高的应用场景呢?
这个时候HashMap就不再适用了,我们需要一种新的Map,在JDK中提供了一个接口:SortedMap
//TreeMap 是一个有序的key-value集合TreeMap基于红黑树(Red-Black tree)实现。该映射根据其键的自然顺序进行排序,或者根据创建映射时提供的 Comparator 进行排序,具体取决于使用的构造方法。
SortedMap接口主要提供有序的Map实现。
Map的主要实现有HashMap,TreeMap,HashTable,LinkedHashMap。
TreeMap实现了SortedMap接口,保证了有序性。默认的排序是根据key值进行升序排序,也可以重写comparator方法来根据value进行排序
SortedMap<Double,String> nearestK=new TreeMap<Double,String>();
for(Tuple2<Double,String> neighbor:neighbors){
Double distance =neighbor._1;
String classificationID =neighbor._2;
nearstK.put(distance,classificationID);
if(nearestK.size()>k){
nearestK.remove(nearestK.lastKey());}
}
return nearestK;
}
// buildClassificationCount() 根据多数计数选择分类
Static Map<String, Integer> buildClassificationcount(Map<Double,String> nearestK){
Map<String,Integer> majority = new HashMap<String,Integer>();
for(Map.Entry<Double,String> entry:nearestK.entryset()){
String classificationID=entry.getvalue();
Integer cont =majority.get(classificationID);
if(count==null){
majority.put(classificationID,1);
}
else{
majority.put(classificationID, count+1)
}
}
return majority;
}
//classifyByMajrity()根据多数原则进行分类
static String classifyByMajority(Map<Strinig,Integer> majority){
int votes=0;
String selectedClassification=null;
for(Map.Entry<String,Interger> entry:majority.entrySet()){
if (selectedClassification==null){
selectedClassification=entry.getKey();
votes= entry.getValue();
}
else{
int count = entry.getValue();
if (count>votes){selectedClassification = entry.getKey(); votes=count;}
}
}
return selectedClassification;
}
}