ALS 的 Java 实现

用 Java 实现了一个简单的 ALS,逻辑与 Spark ALS 一致,只是 Spark 是基于 RDD 分块计算的。代码如下:


public class ALSCaleTest
{
private static final Long SEED = 10L;
private static final BLAS blas = BLAS.getInstance( );
private static final LAPACK lapack = LAPACK.getInstance( );
private static final int MAXITER = 10;
private static final int RANK = 2;
private static final double LAMBDA = 0.01;
public static <E> void main( String[] args )
{
List<RatingCale> list =Arrays.asList( 
new RatingCale( 1, 11, 3.0f ),
new RatingCale( 1, 12, 4.0f ),
new RatingCale( 2, 12, 3.0f ),
new RatingCale( 2, 13, 4.5f ),
new RatingCale( 3, 11, 3.0f ),
new RatingCale( 3, 12, 2.0f ));

List<Integer> users = userNumber( list );
Random seedGen = new XORShiftRandom( SEED );
List<Tuple2<Integer, float[]>> userFactors = initFactors( users, RANK,  seedGen.nextLong( ));
List<Tuple2<Integer, float[]>> itemFactors = null;
Map<Integer, List<Tuple2<Integer, Float>>> userBlocks = userBlocks( list );
Map<Integer, List<Tuple2<Integer, Float>>> itemBlocks = itemBlocks( list );


for (int i=0; i<MAXITER; i++)
{
itemFactors = computeFactors( itemBlocks, userFactors );
userFactors = computeFactors( userBlocks, itemFactors );
}

List<Tuple2<Integer, Float>> result = recommendProducts(1,13, userFactors, itemFactors );
System.out.println( result );
System.out.println( "Done" );
}

/**
 * Recomputes one factor vector per key of {@code map} while the opposite
 * side's factors stay fixed.
 *
 * For every entry, the ratings are accumulated into a normal equation using
 * the factor of the id each rating refers to, and the system is solved with
 * a regularization of {@code LAMBDA} times the number of ratings in that entry.
 *
 * @param map     id of the side being solved, mapped to (other-side id, rating) pairs
 * @param factors current factors of the other side
 * @return one (id, factor) tuple per entry of {@code map}, in the map's iteration order
 * @throws RuntimeException if a rating refers to an id that has no factor in {@code factors}
 */
private static List<Tuple2<Integer, float[]>> computeFactors(Map<Integer, List<Tuple2<Integer, Float>>> map, List<Tuple2<Integer, float[]>> factors)
{
    List<Tuple2<Integer, float[]>> solved = new ArrayList<Tuple2<Integer, float[]>>( );
    // A single accumulator is reused across entries; it is cleared before each one.
    NormalEquation equation = new NormalEquation( RANK );
    for (Entry<Integer, List<Tuple2<Integer, Float>>> block : map.entrySet( ))
    {
        int srcId = block.getKey( );
        List<Tuple2<Integer, Float>> ratings = block.getValue( );
        equation.reset( );

        for (Tuple2<Integer, Float> rating : ratings)
        {
            float[] otherSideFactor = getFactorFromList( factors, rating._1 );
            equation.add( otherSideFactor, rating._2, 1.0 );
        }

        double regularization = LAMBDA * ratings.size( );
        float[] solution = choleskySolver( equation, regularization );
        solved.add( new Tuple2<Integer, float[]>( srcId, solution ) );
    }
    return solved;
}

/**
 * Finds the factor vector stored under {@code id}.
 *
 * @param list (id, factor) tuples, scanned from the front
 * @param id   id to look up
 * @return the factor of the first tuple whose id equals {@code id}
 * @throws RuntimeException if no tuple carries {@code id}
 */
private static float[] getFactorFromList(List<Tuple2<Integer, float[]>> list, int id)
{
    Iterator<Tuple2<Integer, float[]>> cursor = list.iterator( );
    while (cursor.hasNext( ))
    {
        Tuple2<Integer, float[]> candidate = cursor.next( );
        if (candidate._1 != id)
        {
            continue;
        }
        return candidate._2;
    }
    throw new RuntimeException( "Error" );
}

/**
 * Groups the ratings by user.
 *
 * @param list ratings to group
 * @return user id mapped to that user's (item id, rating) pairs, in input order
 */
private static Map<Integer, List<Tuple2<Integer, Float>>> userBlocks(List<RatingCale> list)
{
    Map<Integer, List<Tuple2<Integer, Float>>> byUser = new HashMap<Integer, List<Tuple2<Integer, Float>>>( );
    for (RatingCale entry : list)
    {
        List<Tuple2<Integer, Float>> bucket = byUser.get( entry.user );
        if (bucket == null)
        {
            // First rating seen for this user: open its bucket.
            bucket = new ArrayList<Tuple2<Integer, Float>>( );
            byUser.put( entry.user, bucket );
        }
        bucket.add( new Tuple2<Integer, Float>( entry.item, entry.rating ) );
    }
    return byUser;
}

/**
 * Groups the ratings by item.
 *
 * @param list ratings to group
 * @return item id mapped to that item's (user id, rating) pairs, in input order
 */
private static Map<Integer, List<Tuple2<Integer, Float>>> itemBlocks(List<RatingCale> list)
{
    Map<Integer, List<Tuple2<Integer, Float>>> byItem = new HashMap<Integer, List<Tuple2<Integer, Float>>>( );
    for (RatingCale entry : list)
    {
        List<Tuple2<Integer, Float>> bucket = byItem.get( entry.item );
        if (bucket == null)
        {
            // First rating seen for this item: open its bucket.
            bucket = new ArrayList<Tuple2<Integer, Float>>( );
            byItem.put( entry.item, bucket );
        }
        bucket.add( new Tuple2<Integer, Float>( entry.user, entry.rating ) );
    }
    return byItem;
}

/**
 * Collects the distinct user ids found in {@code list}.
 *
 * @param list ratings to scan
 * @return a new mutable list holding each user id once, in order of first appearance
 */
private static List<Integer> userNumber(List<RatingCale> list)
{
    // An insertion-ordered set de-duplicates in constant time per rating;
    // a List.contains() check per rating would make the scan quadratic.
    // The type is fully qualified so that no extra import is required.
    java.util.LinkedHashSet<Integer> distinctUsers = new java.util.LinkedHashSet<Integer>( );
    for (RatingCale c : list)
    {
        distinctUsers.add( c.user );
    }
    return new ArrayList<Integer>( distinctUsers );
}

private static List<Tuple2<Integer, float[]>> initFactors(List<Integer> list, int rank, long seed)
{
List<Tuple2<Integer, float[]>> retValue = new ArrayList<Tuple2<Integer, float[]>>();
Random random = new XORShiftRandom( package$.MODULE$.byteswap64( seed ) );
for (int i=0; i<list.size( ); i++)
{
float[] factor = new float[rank];
for (int j=0; j<rank; j++)
{
factor[j] = ((Double)random.nextGaussian( )).floatValue( ); 
}
float nrm = blas.snrm2( rank, factor, 1 );
blas.sscal( rank, 1.0f / nrm, factor, 1 );
retValue.add( new Tuple2<Integer, float[]>(list.get( i ), factor) );
}
return retValue;
}
private static float[] choleskySolver( NormalEquation ne, double lambda )
{
int k = ne.k;
int i = 0;
int j = 2;
while (i < ne.trik)
{
ne.ata[i]  = ne.ata[i] + lambda;
i = i + j;
j = j + 1;
}
solve( ne.ata, ne.atb );
float[] x = new float[k];
i=0;
while( i< k)
{
x[i] = ((Double)ne.atb[i]).floatValue( );
i = i + 1;
}
ne.reset( );
return x;
}

/**
 * Runs LAPACK {@code dppsv} on a packed upper-triangular matrix and a single
 * right-hand side.
 *
 * @param A  matrix in packed upper storage; passed straight to LAPACK
 * @param bx right-hand side; its length is used as the system order
 * @return the same {@code bx} array that was passed in
 * @throws RuntimeException if LAPACK reports a non-zero status
 */
private static double[] solve(double[] A, double[] bx)
{
    final int order = bx.length;
    final intW status = new intW( 0 );
    lapack.dppsv( "U", order, 1, A, bx, order, status );

    if (status.val == 0)
    {
        return bx;
    }
    throw new RuntimeException( "LAPACK run error" );
}
/**
 * Immutable (user, item, rating) triple used as one training observation.
 */
private static class RatingCale
{
    // Fields are final: they are only read after construction.
    final int user;
    final int item;
    final float rating;

    /**
     * @param user   id of the rating user
     * @param item   id of the rated item
     * @param rating rating value given by {@code user} to {@code item}
     */
    public RatingCale( int user, int item, float rating )
    {
        this.user = user;
        this.item = item;
        this.rating = rating;
    }
}

/**
 * Scores the items for one user and returns the best ones.
 *
 * @param user        id of the user to recommend for
 * @param num         maximum number of results
 * @param userFactors (user id, factor) tuples
 * @param itemFactors (item id, factor) tuples
 * @return (item id, score) tuples as produced by {@code recommend}
 * @throws RuntimeException if {@code user} has no factor, or its factor is {@code null}
 */
private static List<Tuple2<Integer, Float>> recommendProducts( int user, int num, List<Tuple2<Integer, float[]>> userFactors, List<Tuple2<Integer, float[]>> itemFactors)
{
    // The shared lookup throws the same RuntimeException( "Error" ) for an unknown id.
    float[] userFactor = getFactorFromList( userFactors, user );
    if (userFactor != null)
    {
        return recommend( userFactor, itemFactors, num );
    }
    throw new RuntimeException( "Error" );
}
private static List<Tuple2<Integer, Float>> recommend(
float[] recommendToFeatures,
List<Tuple2<Integer, float[]>> recommendableFeatures, int num )
{
List<Tuple2<Integer, Float>> retValue = new ArrayList<Tuple2<Integer, Float>>( );
for (Tuple2<Integer, float[]> t: recommendableFeatures)
{
float value = blas.sdot( t._2.length, recommendToFeatures, 1, t._2, 1 );
retValue.add( new Tuple2<Integer, Float>(t._1, value) );
}
retValue.sort( new Comparator<Tuple2<Integer, Float>>()
{


@Override
public int compare( Tuple2<Integer, Float> o1,
Tuple2<Integer, Float> o2 )
{
return o1._2  < o2._2( ) ? 1:-1;
}

});
if (retValue.size( ) > num)
{
return retValue.subList( 0, num );
}
else
{
return retValue;
}


}

/**
 * Widens a float array into a new double array of the same length.
 *
 * @param fs values to convert
 * @return a fresh array holding each element of {@code fs} as a double
 */
private static double[] convert2double(float[] fs)
{
    double[] widened = new double[fs.length];
    int pos = 0;
    for (float value : fs)
    {
        widened[pos++] = value;
    }
    return widened;
}

/**
 * Accumulator for a rank-{@code k} least-squares system.
 *
 * {@code ata} holds {@code k * (k + 1) / 2} packed entries that are updated
 * through BLAS {@code dspr} with the "U" option, and {@code atb} holds the
 * {@code k} entries of the right-hand side.
 */
private static class NormalEquation
{
    /** Triangle selector handed to BLAS. */
    private static final String upper = "U";

    private int k;
    private int trik;
    private double[] ata;
    private double[] atb;
    /** Scratch buffer holding the double copy of the vector being added. */
    private double[] da;

    /**
     * @param k number of unknowns
     */
    public NormalEquation( int k )
    {
        this.k = k;
        this.trik = k * ( k + 1 ) / 2;
        this.ata = new double[this.trik];
        this.atb = new double[k];
        this.da = new double[k];
    }

    /** Copies the first {@code k} entries of {@code a} into the scratch buffer. */
    private void copyToDouble( float[] a )
    {
        for (int idx = 0; idx < k; idx++)
        {
            da[idx] = a[idx];
        }
    }

    /**
     * Adds one observation to the system.
     *
     * @param a feature vector; its first {@code k} entries are used
     * @param b target value; when it is exactly zero the right-hand side is left untouched
     * @param c weight applied to this observation
     * @return this accumulator
     */
    NormalEquation add( float[] a, double b, double c )
    {
        copyToDouble( a );
        blas.dspr( upper, k, c, da, 1, ata );
        if ( b == 0.0 )
        {
            return this;
        }
        blas.daxpy( k, c * b, da, 1, atb, 1 );
        return this;
    }

    /**
     * Adds the contents of another accumulator to this one.
     *
     * @param other accumulator whose arrays are added element-wise
     * @return this accumulator
     */
    NormalEquation merge( NormalEquation other )
    {
        blas.daxpy( ata.length, 1.0, other.ata, 1, ata, 1 );
        blas.daxpy( atb.length, 1.0, other.atb, 1, atb, 1 );
        return this;
    }

    /** Zeroes both accumulated arrays. */
    void reset( )
    {
        Arrays.fill( ata, 0.0 );
        Arrays.fill( atb, 0.0 );
    }
}


}

猜你喜欢

转载自blog.csdn.net/hhtop112408/article/details/80899366
ALS