KNN计算

用代码写了暴力计算和分区计算K近邻。

1:比较的时候其实可以不用开方。

2: Jing Wang、Jingdong Wang 2012 年的论文《Scalable k-NN graph construction for visual descriptors》(CVPR 2012) 中复杂的分区方法没有实现。

代码如下:

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;


import org.apache.commons.collections.IteratorUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.StorageLevels;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.graphx.Edge;
import org.apache.spark.graphx.Graph;
import org.apache.spark.mllib.linalg.BLAS;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;


import scala.Option;
import scala.Some;
import scala.Tuple2;
import scala.Tuple4;
import scala.reflect.ClassManifestFactory;


/**
 * 
 */


public class LMeansTest
{


public static void main( String[] args )
{
SparkConf conf = new SparkConf().setAppName( "ShortPaths" ).setMaster( "local" );
JavaSparkContext ctx = new JavaSparkContext( conf );

List<KnnVertex> data = generateKnnData( );
Graph<KnnVertex, Double> g = knnGraph( data, 2, ctx );
Graph<KnnVertex, Double> g1 = knnGraphApprox( data, 2, ctx );

ctx.stop( );
System.out.println( "done" );
}

private static Graph<KnnVertex, Double> knnGraph(List<KnnVertex> a, int k, JavaSparkContext ctx)
{
JavaRDD<Tuple2<Long, KnnVertex>> v = ctx.parallelize( a ).zipWithIndex( ).map( s-> new Tuple2<Long, KnnVertex>(s._2, s._1));
List<Tuple2<Long, KnnVertex>> a2 = v.collect( );
JavaRDD<Edge<Double>> edge = v.map( v1-> 
{
List<Long> list = a2.stream( ).map( v2->{

return new Tuple2<Long, Double>(v2._1, v1._2( ).dist( v2._2 ));
}).sorted( (s1, s2) -> s1._2 - s2._2 >= 0?1:-1).map( s->s._1 )
.collect( Collectors.toList( ) ).subList( 1, k+1);
return new Tuple2<Long, List<Long>>(v1._1, list);
}).flatMap( s->{
return s._2.stream( ).map( vid2->new Edge<Double>(s._1, vid2, 1.0/(1 + a2.get( ((Long)vid2).intValue( ) )._2( ).dist( a2.get( ((Long)s._1 ).intValue( ))._2 ))) ).collect( Collectors.toList( ) );
});

return Graph.apply( v.map( s->new Tuple2<Object, KnnVertex>(s._1, s._2 )).rdd( ), edge.rdd( ), new KnnVertex( ) , StorageLevels.MEMORY_ONLY, StorageLevels.MEMORY_ONLY,
ClassManifestFactory.classType( KnnVertex.class ), ClassManifestFactory.classType( Double.class ) );
}


private static Graph<KnnVertex, Double> knnGraphApprox(List<KnnVertex> a, int k, JavaSparkContext ctx)
{

JavaRDD<Tuple2<Long, KnnVertex>> v = ctx.parallelize( a ).zipWithIndex( ).map( s-> new Tuple2<Long, KnnVertex>(s._2, s._1));
List<Tuple2<Long, KnnVertex>> a2 = v.collect( );
int n = 2;


Tuple4<Double, Double, Double, Double> minMax = v.map( s-> 
{
return new Tuple4<Double, Double, Double, Double>(s._2.pos[0], s._2( ).pos[0], s._2( ).pos[1],s._2( ).pos[1]);
} )
.reduce( (s1, s2 ) -> 
new Tuple4<Double, Double, Double, Double>(Math.min( s1._1( ), s2._1( ) ), Math.max( s1._2( ), s2._2( ) ), Math.min( s1._3( ), s2._3( ) ), Math.max( s1._4( ), s2._4( ) ))) ;

JavaRDD<Edge<Double>> edge = calcEdges( v, minMax, 0.0, n, k, a2 ).union( calcEdges( v, minMax, 0.5, n, k, a2 ) ).distinct( ).mapToPair( s->new Tuple2<Long, Edge>(s.srcId( ), s))
.groupByKey( ).flatMap( s->
{
List<Edge<Double>> list = IteratorUtils.toList( s._2.iterator( ) );
List<Edge<Double>> result = list.stream( ).sorted( (s1, s2) -> s1.attr( ) - s2.attr( )>-0?1:-1 ).collect( Collectors.toList( ) );
List<Edge<Double>> retValue ;
if (result.size( ) > k)
{
retValue = result.subList( 0, k );
}
else
{
retValue = result;
}

return retValue;
});

return Graph.apply( v.map( s->new Tuple2<Object, KnnVertex>(s._1, s._2 )).rdd( ), edge.rdd( ), new KnnVertex( ) , StorageLevels.MEMORY_ONLY, StorageLevels.MEMORY_ONLY,
ClassManifestFactory.classType( KnnVertex.class ), ClassManifestFactory.classType( Double.class ) );


}

private static JavaRDD<Edge<Double>> calcEdges(JavaRDD<Tuple2<Long, KnnVertex>> v, Tuple4<Double, Double, Double, Double> minMax, double offset, int n, int k, List<Tuple2<Long, KnnVertex>> a2 )
{
double xRange = minMax._2( ) - minMax._1( );
double yRang = minMax._4( ) - minMax._3( );

List lsit = v.map( s->{
double dis = Math.floor( (s._2( ).pos[0] - minMax._1( ))/xRange*(n - 1)  + offset)*n + Math.floor( (s._2.pos[1] - minMax._3( ) )/yRang * (n-1) + offset);
return new Tuple2<Double, Tuple2<Long, KnnVertex>>(dis, s);
}).collect( );
System.out.println( "lsit" );
return v.map( s->{
double dis = Math.floor( (s._2( ).pos[0] - minMax._1( ))/xRange*(n - 1)  + offset)*n + Math.floor( (s._2.pos[1] - minMax._3( ) )/yRang * (n-1) + offset);
return new Tuple2<Double, Tuple2<Long, KnnVertex>>(dis, s);
}).mapToPair( s->s ).groupByKey( n*n ).mapPartitions(  new FlatMapFunction<Iterator<Tuple2<Double,Iterable<Tuple2<Long,KnnVertex>>>>, Edge<Double>>( )
{


@Override
public Iterable<Edge<Double>>
call( Iterator<Tuple2<Double, Iterable<Tuple2<Long, KnnVertex>>>> t )
throws Exception
{
List<Tuple2<Long, KnnVertex>> list = new ArrayList<Tuple2<Long, KnnVertex>>();
while(t.hasNext( ))
{
Iterator<Tuple2<Long, KnnVertex>> itor = t.next( )._2.iterator( );
list.addAll( IteratorUtils.toList( itor ) );
}

return list.stream( ).map( s->  
{
int size = list.size( );
return new Tuple2<Long, List<Long>>(s._1, calcKDist(list, s._2, k));
}).flatMap( s->
{
return s._2.stream( ).map( vid2->new Edge<Double>( s._1, vid2, 1.0/(1 + a2.get( ((Long)s._1).intValue( ) )._2( ).dist( a2.get( ((Long)vid2 ).intValue( ))._2 )) ) );
}).collect( Collectors.toList( ) );
}
});
}

private static List<Long> calcKDist(List<Tuple2<Long, KnnVertex>> list, KnnVertex v, int k)
{
List<Long> retValue = list.stream( ).map( s->new Tuple2<Long, Double>(s._1, v.dist( s._2( ) )) ).sorted( (s1, s2) -> s1._2 - s2._2 >= 0?1:-1).map( s->s._1 )
.collect( Collectors.toList( ) );
if (retValue.size( ) > k + 1)
{
retValue = retValue.subList( 1, k + 1 );
}
else
{
retValue.remove( 0 );
}
 
return retValue;

}

public static List<KnnVertex> generateKnnData()
{
List<KnnVertex> retValue = new ArrayList<KnnVertex>( );
for (int i=0; i<6;i++)
{
Option<Integer> opt = Option.empty( );
double[] ds = new double[2];
for (int j=0; j<2; j++)
{
ds[j] = i;
}
retValue.add( new KnnVertex( opt, ds) );
}
return retValue;
}

public static List<KnnVertex> generateRandomKnnData()
{
Random random = new Random( 17L );
List<KnnVertex> retValue = new ArrayList<KnnVertex>( );
int n=20;
for (int i=0; i<2*n; i++)
{
double x = random.nextDouble( );
if (i<=n)
{

Option<Integer> opt = i%n == 0?new Some(0):Option.empty( );
int arrayCount = ((Double)(x*50)).intValue();
if (arrayCount < 2)
{
arrayCount = 2;
}
double[] ds = new double[arrayCount];
for (int j=0; j<arrayCount; j++)
{
ds[j] = 20 + (Math.sin( x*Math.PI ) + random.nextDouble( )/2.0) * 25;
}

retValue.add( new KnnVertex( opt, ds) );
}
else
{
Option<Integer> opt = i%n == 0?new Some(1):Option.empty( );
int arrayCount = ((Double)(x*50)).intValue() + 25;
double[] ds = new double[arrayCount];
for (int j=0; j<arrayCount; j++)
{
ds[j] = 30 - (Math.sin( x*Math.PI ) + random.nextDouble( )/2.0) * 25;
}

retValue.add( new KnnVertex( opt, ds) );
}
}
return retValue;
}

private static class KnnVertex implements Serializable
{
private Option<Integer> classNum;
private double[] pos;

public KnnVertex()
{

}

public KnnVertex( Option<Integer> classNum, double[] pos )
{
super( );
this.classNum = classNum;
this.pos = pos;
}


public double dist(KnnVertex other)
{
int len = Math.min( pos.length, other.pos.length );
double sum = 0.0;
for (int i=0; i<len; i++)
{
double dis = pos[i] - other.pos[i];
sum = sum + dis*dis;
}

return Math.sqrt( sum );
}
}
}

猜你喜欢

转载自blog.csdn.net/hhtop112408/article/details/80312042
kNN