A Spark implementation of the Apriori algorithm. Apriori works well when the minimum-support threshold is high; with a low threshold the candidate sets grow quickly and performance suffers.
1: createNewItemsetsFromPreviousOnes needs optimization; I have not worked out how yet.
2: The results are not translated back to the original item IDs (a hypothetical helper for this is sketched inside the listing), and the API could be generalized with generics.
3: The idea is simple; reading the code should make the whole thing clear.
Code
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import scala.Tuple2;
/**
 * Apriori frequent itemset mining on Spark.
 */
public class SparkApriori implements Serializable
{
private static final Pattern SPACE = Pattern.compile( "\\s+" );
private double minSupport;
private int itemsetNumber;
public SparkApriori( double minSupport )
{
super( );
this.minSupport = minSupport;
}
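/**
 * Mines frequent itemsets from the given transactions. Each input line is one
 * transaction, with items separated by whitespace.
 */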
public void run( JavaRDD<String> data )
{
// Total number of transactions, used to turn counts into support values.
final long count = data.count( );
// Count the global frequency of every individual item.
JavaPairRDD<String, Long> pairs = data.flatMapToPair( s ->
{
List<Tuple2<String, Long>> list = new ArrayList<Tuple2<String, Long>>( );
for ( String str : SPACE.split( s ) )
{
list.add( new Tuple2<String, Long>( str, 1L ) );
}
return list;
} ).aggregateByKey( 0L, ( s1, s2 ) -> s1 + s2, ( s1, s2 ) -> s1 + s2 );
// Assign every distinct item a dense numeric index; the algorithm works on
// these indices internally.
JavaPairRDD<String, Long> itemWithIndex = pairs.map( s -> s._1 ).zipWithIndex( );
final Map<String, Long> map = itemWithIndex.collectAsMap( );
// Number of distinct items.
final int elementCount = (int) pairs.count( );
// Encode each transaction as the set of indices of the items it contains.
JavaRDD<SparseBooleanArray> rdd = data.map( s ->
{
List<Long> list = new ArrayList<Long>( );
for ( String str : SPACE.split( s ) )
{
// Map.get returns null for an unknown item (it does not throw), so an
// explicit null check is needed here, not a try/catch.
Long index = map.get( str );
if ( index != null )
{
list.add( index );
}
}
return new SparseBooleanArray( elementCount, list.toArray( new Long[list.size( )] ) );
} );
// Seed the search with the frequent 1-itemsets: pair each item's index with
// its count, and keep the items whose support clears the threshold.
JavaRDD<Tuple2<Long, Long>> caData = pairs.join( itemWithIndex ).map( s -> new Tuple2<Long, Long>( s._2._2, s._2._1 ) );
List<long[]> itemSets = caData.filter( s -> s._2.doubleValue( ) / count > minSupport ).map( s -> new long[]{ s._1 } ).collect( );
itemsetNumber = 1;
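// Level-wise Apriori loop: log the size-k candidates, count them against the
// transactions, keep the frequent ones, then join them into (k+1)-candidates.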
while ( itemSets.size( ) > 0 )
{
writeData( itemSets, map );
itemSets = calculateFrequentItemsets( rdd, itemSets, count, minSupport );
if ( itemSets.size( ) > 0 )
{
itemSets = createNewItemsetsFromPreviousOnes( itemSets );
}
itemsetNumber = itemsetNumber + 1;
}
}
private static void writeData( List<long[]> data, Map<String, Long> map )
{
// Log the itemsets for this level (left as a stub).
}
private static void foundFrequentItemSet( long[] itemset, long support, long numTransactions )
{
// Write the result. Note: the indices still need to be translated back to
// the original item IDs (see note 2 above and the sketch below).
System.out.println( Arrays.toString( itemset )
+ " ("
+ ( ( support / (double) numTransactions ) )
+ " "
+ support
+ ")" );
}
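// A hypothetical helper (not part of the original code) sketching the missing
// ID translation from note 2: invert the item -> index map and render an
// itemset with the original item names.
private static String toOriginalIds( long[] itemset, Map<String, Long> map )
{
Map<Long, String> inverse = new HashMap<Long, String>( );
for ( Map.Entry<String, Long> e : map.entrySet( ) )
{
inverse.put( e.getValue( ), e.getKey( ) );
}
List<String> names = new ArrayList<String>( );
for ( long id : itemset )
{
names.add( inverse.get( id ) );
}
return names.toString( );
}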
// Join pairs of frequent k-itemsets that differ in exactly one item into
// (k+1)-candidates; keying the map on the sorted contents deduplicates
// candidates reachable from more than one pair.
private static List<long[]>
createNewItemsetsFromPreviousOnes( List<long[]> itemSets )
{
int currentItemLength = itemSets.get( 0 ).length;
Map<String, long[]> map = new HashMap<String, long[]>( );
for ( int i = 0; i < itemSets.size( ); i++ )
{
long[] X = itemSets.get( i );
for ( int j = i + 1; j < itemSets.size( ); j++ )
{
long[] Y = itemSets.get( j );
// Copy X and leave the last slot free for the item that Y adds.
long[] newCand = Arrays.copyOf( X, currentItemLength + 1 );
int different = 0;
for ( long y : Y )
{
boolean find = false;
for ( long x : X )
{
if ( x == y )
{
find = true;
break;
}
}
if ( !find )
{
different = different + 1;
newCand[newCand.length - 1] = y;
}
}
if ( different == 1 )
{
Arrays.sort( newCand );
map.put( Arrays.toString( newCand ), newCand );
}
}
}
return new ArrayList<long[]>( map.values( ) );
}
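// Count, in one distributed pass, how many transactions contain each candidate
// itemset, and return the candidates whose support clears the threshold.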
private static List<long[]> calculateFrequentItemsets(
JavaRDD<SparseBooleanArray> rdd, List<long[]> itemSets,
long totalCount, double minSup )
{
// For every transaction, emit (key, (itemset, 1)) for each candidate itemset
// fully contained in that transaction.
JavaPairRDD<String, Tuple2<long[], Long>> pairs = rdd.flatMapToPair( s ->
{
List<Tuple2<String, Tuple2<long[], Long>>> retValue = new ArrayList<Tuple2<String, Tuple2<long[], Long>>>( );
for ( long[] ls : itemSets )
{
boolean match = true;
for ( long l : ls )
{
if ( !s.contain( l ) )
{
match = false;
break;
}
}
if ( match )
{
retValue.add( new Tuple2<String, Tuple2<long[], Long>>( Arrays.toString( ls ), new Tuple2<long[], Long>( ls, 1L ) ) );
}
}
return retValue;
} );
// Sum the per-transaction occurrence counts for each candidate itemset.
List<Tuple2<long[], Long>> list = pairs.aggregateByKey(
new Tuple2<long[], Long>( new long[]{}, 0L ),
( s1, s2 ) -> new Tuple2<long[], Long>( s2._1, s1._2 + s2._2 ),
( s1, s2 ) -> new Tuple2<long[], Long>( s2._1, s1._2 + s2._2 ) ).map( s -> s._2 ).collect( );
List<long[]> values = new ArrayList<long[]>( );
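// Keep only candidates whose support exceeds the threshold and report them.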
for ( Tuple2<long[], Long> s : list )
{
if ( s._2.doubleValue( ) / totalCount > minSup )
{
foundFrequentItemSet( s._1, s._2, totalCount );
values.add( s._1 );
}
}
return values;
}
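/**
 * Sparse encoding of a transaction: despite the "boolean array" name, it
 * stores only the indices of the items the transaction contains.
 */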
private static class SparseBooleanArray implements Serializable
{
private int count;
private Long[] pos;
public SparseBooleanArray( int count, Long[] pos )
{
super( );
this.count = count;
this.pos = pos;
}
public boolean contain( long l )
{
for ( long is : pos )
{
if ( is == l )
{
return true;
}
}
return false;
}
}
}
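For reference, a minimal driver sketch showing how the class might be used; the local master, the input path transactions.txt, and the 0.3 threshold are assumptions for illustration, not part of the original code. The expected input format is one whitespace-separated transaction per line.
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
public class SparkAprioriDemo
{
public static void main( String[] args )
{
SparkConf conf = new SparkConf( ).setAppName( "SparkApriori" ).setMaster( "local[*]" );
JavaSparkContext sc = new JavaSparkContext( conf );
// One transaction per line, items separated by whitespace.
JavaRDD<String> transactions = sc.textFile( "transactions.txt" );
new SparkApriori( 0.3 ).run( transactions );
sc.stop( );
}
}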