Spark 实现Apriori

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/hhtop112408/article/details/82970211

Spark实现Apriori算法,Apriori算法如果支持度高那么效果会很好,如果支持度设置的低,性能会比较差。

1:createNewItemsetsFromPreviousOnes需要优化,没有想清楚。

2:结果没有转换正确的ID,API可以优化成泛型。

3:原理不复杂,代码看看就全懂了。

代码

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

import scala.Tuple2;

/**
 * 
 */

public class SparkApriori implements Serializable
{

    private static final Pattern SPACE = Pattern.compile( "\\s+" );
    private double minSupport;
    private int itemsetNumber;

扫描二维码关注公众号,回复: 4176821 查看本文章

    public SparkApriori( double minSupport )
    {
        super( );
        this.minSupport = minSupport;
    }

    public void run( JavaRDD<String> data )
    {
        final long count = data.count( );
        JavaPairRDD<String, Long> pairs = data.flatMap( s ->
        {
            List<Tuple2<String, Long>> list = new ArrayList<Tuple2<String, Long>>( );
            String[] strs = SPACE.split( s );
            // return new Tuple2<String, Long>(str);
            for ( String str : strs )
            {
                list.add( new Tuple2<String, Long>( str, 1L ) );
            }
            return list;
        } ).mapToPair( s -> s ).aggregateByKey( 0L, ( s1, s2 ) -> s1 + s2, ( s1,
                s2 ) -> s1 + s2 );

        JavaPairRDD<String, Long> itemWithIndex = pairs.map( s -> s._1 ).zipWithIndex( );
        List<Tuple2<String, Long>> items = itemWithIndex.collect( );
        final Map<String, Long> map = itemWithIndex.collectAsMap( );
        
        System.out.println( map );
        final int elemtnCount = ( (Long) pairs.count( ) ).intValue( );

        Vector v = Vectors.dense( null );
        JavaRDD<SparseBooleanArray> rdd = data.map( s ->
        {

            String[] strs = SPACE.split( s );
            List<Long> list = new ArrayList<Long>( );
            for ( String str : strs )
            {
                try
                {
                    list.add( map.get( str ) );
                }
                catch ( Exception e )
                {
                    // Do nothing
                }
            }
            // return 1;
            return new SparseBooleanArray( elemtnCount, list.toArray( new Long[list.size( )] ) );

        } );

        JavaRDD<Tuple2<Long, Long>> caData = pairs.join( itemWithIndex ).map( s -> new Tuple2<Long, Long>( s._2._2, s._2._1 ) );

        List<long[]> itemSets = caData.filter( s -> s._2.doubleValue( )
                / count > minSupport ).map( s -> new long[]{
                        s._1
        } ).collect( );
        itemsetNumber = 1;
        while ( itemSets.size( ) > 0 )
        {
            writeData( itemSets, map );
            itemSets = calculateFrequentItemsets( rdd, itemSets, count, minSupport );
            if ( itemSets.size( ) > 0 )
            {
                itemSets = createNewItemsetsFromPreviousOnes( itemSets );
            }
            itemsetNumber = itemsetNumber + 1;
        }

    }

    private static void writeData( List<long[]> data, Map<String, Long> map)
    {
        //log the item
    }

    private static void foundFrequentItemSet( long[] itemset, long support, long numTransactions )
    {
        //write, please note the id need trans to the ori id.
        System.out.println( Arrays.toString( itemset )
                + "  ("
                + ( ( support / (double) numTransactions ) )
                + " "
                + support
                + ")" );

    }

    private static List<long[]>
            createNewItemsetsFromPreviousOnes( List<long[]> itemSets )
    {
        int currentItenLength = itemSets.get( 0 ).length;
        Map<String, long[]> map = new HashMap<String, long[]>( );
        for ( int i = 0; i < itemSets.size( ); i++ )
        {
            long[] X = itemSets.get( i );
            for ( int j = i + 1; j < itemSets.size( ); j++ )
            {
                long[] Y = itemSets.get( j );
                long[] newCand = new long[currentItenLength + 1];
                for (int k=0; k<currentItenLength; k++)
                {
                    newCand[k] = X[k];
                }
                int different = 0;

                for ( long y : Y )
                {
                    boolean find = false;
                    for ( long x : X )
                    {
                        if ( x == y )
                        {
                            find = true;
                            break;
                        }
                    }
                    if ( !find )
                    {
                        different = different + 1;
                        newCand[newCand.length - 1] = y;
                    }
                }
                
                if ( different == 1 )
                {
                    Arrays.sort( newCand );
                    map.put( Arrays.toString( newCand ), newCand );
                }

            }
        }

        return new ArrayList<long[]>( map.values( ) );
    }

    private static List<long[]> calculateFrequentItemsets(
            JavaRDD<SparseBooleanArray> rdd, List<long[]> itemSets,
            long totalCount, double minSup )
    {
        JavaPairRDD<String, Tuple2<long[], Long>> pairs = rdd.flatMap( s ->
        {
            List<Tuple2<String, Tuple2<long[], Long>>> retValue = new ArrayList<Tuple2<String, Tuple2<long[], Long>>>( );
            for ( long[] ls : itemSets )
            {
                boolean match = true;
                for ( long l : ls )
                {
                    if ( !s.contain( l ) )
                    {
                        match = false;
                        break;
                    }
                }
                if ( match )
                {
                    retValue.add( new Tuple2<String, Tuple2<long[], Long>>( Arrays.toString( ls ), new Tuple2<long[], Long>( ls, 1L ) ) );
                }
            }
            return retValue;
        } ).mapToPair( s -> s );
        
        Map a = pairs.collectAsMap( );
        List<Tuple2<long[], Long>> list = pairs.aggregateByKey( new Tuple2<long[], Long>( new long[]{}, 0L ), (
                s1, s2 ) -> new Tuple2<long[], Long>( s2._1, s1._2 + s2._2 ), (
                        s1, s2 ) -> new Tuple2<long[], Long>( s2._1, s1._2
                                + s2._2 ) ).map( s -> s._2 ).collect( );

        List<long[]> values = new ArrayList<long[]>( );
        for ( Tuple2<long[], Long> s : list )
        {
            if ( s._2.doubleValue( ) / totalCount > minSup )
            {
                foundFrequentItemSet( s._1, s._2, totalCount );
                values.add( s._1 );
            }
        }
        return values;
    }

    private static class SparseBooleanArray implements Serializable
    {

        private int count;
        private Long[] pos;

        public SparseBooleanArray( int count, Long[] pos )
        {
            super( );
            this.count = count;
            this.pos = pos;
        }

        public boolean contain( long l )
        {
            for ( long is : pos )
            {
                if ( is == l )
                {
                    return true;
                }
            }
            return false;
        }
    }

}

猜你喜欢

转载自blog.csdn.net/hhtop112408/article/details/82970211