无序（Unordered）类别特征的信息熵增益计算

用 Java 实现的无序（unordered）类别特征信息熵增益计算。原理好懂，逻辑与 Spark MLlib 决策树中的实现相同。

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;


import scala.Tuple2;




/**
 * Finds the best binary split, by information gain (Shannon entropy, base 2), of a single
 * unordered categorical feature, enumerating candidate splits the same way Spark MLlib's
 * decision tree does for unordered features.
 *
 * <p>A feature with {@code k} distinct categories has {@code 2^(k-1) - 1} candidate splits.
 * Categories are indexed in sorted order; split {@code s} sends a sample to the left child
 * when the bit for its category index is set in {@code s + 1}, and to the right child
 * otherwise.
 */
public class UnorderGainTest
{
    /**
     * Outcome of {@link #findBestSplit}: the highest information gain found and the index of
     * the split that achieves it. When there is no candidate split the gain is
     * {@code -Double.MAX_VALUE} and the index is {@code -1}.
     */
    static final class SplitResult
    {
        final double gain;
        final int splitIndex;

        SplitResult( double gain, int splitIndex )
        {
            this.gain = gain;
            this.splitIndex = splitIndex;
        }
    }

    /** Runs the split search on a small built-in sample and prints the best gain and split. */
    public static void main( String[] args )
    {
        // Sample (label, category) pairs, stored as parallel arrays.
        double[] labels = { 1.0, 1.0, 1.0, 1.0, 1.0, 0.0 };
        String[] categories = { "C", "B", "A", "B", "B", "C" };

        SplitResult best = findBestSplit( labels, categories );
        System.out.println( "Gain == " + best.gain );
        System.out.println( "BestSplit == " + best.splitIndex );
    }

    /**
     * Evaluates every candidate split of an unordered categorical feature and returns the one
     * with the highest information gain (the first one wins ties).
     *
     * @param labels     class label of each sample; any distinct double values are allowed,
     *                   they are re-indexed internally
     * @param categories category of each sample, aligned with {@code labels}; must not
     *                   contain {@code null}
     * @return the best split; gain {@code -Double.MAX_VALUE} and index {@code -1} when the
     *         feature has fewer than two distinct categories
     * @throws IllegalArgumentException if the arrays differ in length, or there are more
     *                                  than 31 distinct categories (split masks would
     *                                  overflow an int)
     * @throws NullPointerException     if either array, or any category, is {@code null}
     */
    static SplitResult findBestSplit( double[] labels, String[] categories )
    {
        if ( labels.length != categories.length )
        {
            throw new IllegalArgumentException( "labels and categories differ in length: "
                    + labels.length + " vs " + categories.length );
        }

        // Dense, sorted index spaces for classes and categories. Sorting keeps category
        // indices, and therefore split indices, deterministic. Class indices are kept
        // separate from category indices so labels never index past the histogram.
        double[] classValues = Arrays.stream( labels ).distinct( ).sorted( ).toArray( );
        String[] categoryValues = Arrays.stream( categories ).distinct( ).sorted( )
                .toArray( String[]::new );
        int numClasses = classValues.length;
        int numCategories = categoryValues.length;
        if ( numCategories >= Integer.SIZE )
        {
            throw new IllegalArgumentException( "too many categories for bitmask splits: "
                    + numCategories );
        }

        // Class histogram per category plus the overall histogram, built in a single pass.
        int[][] countsByCategory = new int[numCategories][numClasses];
        int[] totals = new int[numClasses];
        for ( int i = 0; i < labels.length; i++ )
        {
            int classIndex = Arrays.binarySearch( classValues, labels[i] );
            int categoryIndex = Arrays.binarySearch( categoryValues, categories[i] );
            countsByCategory[categoryIndex][classIndex]++;
            totals[classIndex]++;
        }

        double impurity = entropy( totals, labels.length );
        // 2^(k-1) - 1 distinct non-trivial splits; with k < 2 there is nothing to split.
        int numSplits = numCategories < 2 ? 0 : ( 1 << ( numCategories - 1 ) ) - 1;

        double bestGain = -Double.MAX_VALUE;
        int bestSplit = -1;
        int[] left = new int[numClasses];
        for ( int splitIndex = 0; splitIndex < numSplits; splitIndex++ )
        {
            // Categories whose bit is set in (splitIndex + 1) go to the left child.
            Arrays.fill( left, 0 );
            int mask = splitIndex + 1;
            for ( int c = 0; c < numCategories; c++ )
            {
                if ( ( ( mask >>> c ) & 1 ) != 0 )
                {
                    for ( int k = 0; k < numClasses; k++ )
                    {
                        left[k] += countsByCategory[c][k];
                    }
                }
            }

            double gain = splitGain( left, totals, impurity );
            if ( gain > bestGain )
            {
                bestGain = gain;
                bestSplit = splitIndex;
            }
        }
        return new SplitResult( bestGain, bestSplit );
    }

    /**
     * Information gain of splitting a node with class histogram {@code totals} into a left
     * child with histogram {@code left} and a right child holding the remainder.
     *
     * @param left     per-class counts sent to the left child
     * @param totals   per-class counts of the parent node (must have a non-zero sum)
     * @param impurity entropy of the parent node
     * @return {@code impurity} minus the size-weighted entropies of both children
     */
    private static double splitGain( int[] left, int[] totals, double impurity )
    {
        int[] right = new int[totals.length];
        int leftCount = 0;
        int rightCount = 0;
        for ( int k = 0; k < totals.length; k++ )
        {
            right[k] = totals[k] - left[k];
            leftCount += left[k];
            rightCount += right[k];
        }

        double totalCount = leftCount + rightCount;
        double leftWeight = leftCount / totalCount;
        double rightWeight = rightCount / totalCount;
        return impurity
                - leftWeight * entropy( left, leftCount )
                - rightWeight * entropy( right, rightCount );
    }

    /**
     * Shannon entropy (base 2) of a class histogram. Zero-count classes contribute nothing,
     * so an all-zero histogram yields 0.
     */
    private static double entropy( int[] counts, int totalCount )
    {
        double impurity = 0.0;
        for ( int count : counts )
        {
            if ( count != 0 )
            {
                double freq = (double) count / totalCount;
                impurity -= freq * log2( freq );
            }
        }
        return impurity;
    }

    /** Base-2 logarithm. */
    private static double log2( double d )
    {
        return Math.log( d ) / Math.log( 2.0 );
    }
}

猜你喜欢

转载自blog.csdn.net/hhtop112408/article/details/79756297