【Flink原理和应用】:Flink上部署K-Means机器学习模型

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/hxcaifly/article/details/86496243

1. K-Means的简单介绍

这个案例将实现一个简单的K-Means聚类算法。有必要先简单地介绍下K-Means的算法计算原理。

K-Means均值是一种迭代聚类算法,其工作原理如下:

  1. K-Means基于点数据集和一个初始的K值(簇个数)来计算;
  2. 在每次迭代中,算法计算每个数据点到每个簇中心的距离。每个点都被分配到离它最近的簇中心;
  3. 随后,将簇中的所有点的坐标平均值作为该簇的新的中心点;
  4. 被移动之后的簇中心将传递给下一轮迭代计算;

算法在固定次数的迭代后终止 (本案例采用的),或者簇中心不再怎么移动了,那么也可以终止计算。

2. 本案例说明

这个案例是在二维数据点数据集上实现的。输入文件是纯文本文件,文件格式必须要满足如下格式:

  • 二维点数据集表示为两个由空白字符分隔的双精度值。数据点用换行符分隔。
    例如:"1.2 2.3\n5.3 7.2\n"将代表两个点,分别是 (x=1.2, y=2.3)和(x=5.3, y=7.2)。

  • 簇中心将由id和点坐标来呈现。
    例如:"1 6.2 3.2\n2 2.9 5.7\n"将代表两个簇中心,分别是(id=1, x=6.2, y=3.2)和(id=2, x=2.9, y=5.7)。

通过本案例我们将主要学习如下知识:

  1. 批量迭代;
  2. 批量迭代中的广播变量;
  3. Java objects(POJOs)。

本案例主要是讲解一种应用思维方式,所以用来训练的原始数据不多。主要目的是为了展示效果。

3. 源码分析

3.1. 二维空间点坐标数据结构

public static class Point implements Serializable {
        // x坐标,y坐标
		public double x, y;

		public Point() {}

		public Point(double x, double y) {
			this.x = x;
			this.y = y;
		}
        // 点坐标的加法器
		public Point add(Point other) {
			x += other.x;
			y += other.y;
			return this;
		}
        
        // 点坐标的除法器
		public Point div(long val) {
			x /= val;
			y /= val;
			return this;
		}
        // 计算点之间的欧式距离
		public double euclideanDistance(Point other) {
			return Math.sqrt((x - other.x) * (x - other.x) + (y - other.y) * (y - other.y));
		}

		public void clear() {
			x = y = 0.0;
		}

		@Override
		public String toString() {
			return x + " " + y;
		}
	}

3.2. 簇中心的数据结构

簇中心从物理角度看是称为质心。质心的数据结构代码定义如下:

扫描二维码关注公众号,回复: 5156071 查看本文章
/**
 * 质心类, 基于点坐标和id.
 */
public static class Centroid extends Point {

	public int id;

	public Centroid() {}

	public Centroid(int id, double x, double y) {
		super(x, y);
		this.id = id;
	}

	public Centroid(int id, Point p) {
		super(p.x, p.y);
		this.id = id;
	}

	@Override
	public String toString() {
		return id + " " + super.toString();
	}
}

簇中心(质心)的定义类,是基于Point的。其由一个质心Id和质心的位置坐标组成。

3.3. 默认数据集说明

如果主程序执行时没有指定输入的CSV文件路径,那么就读取默认数据。默认数据的定义如下:

package org.apache.flink.examples.java.clustering.util;

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.examples.java.clustering.KMeans.Centroid;
import org.apache.flink.examples.java.clustering.KMeans.Point;

import java.util.LinkedList;
import java.util.List;

/**
 * 提供用于K-Means示例程序的默认数据集。如果没有为程序提供参数,则使用默认数据集。
 *
 */
public class KMeansData {

	/**
	 * 簇中心(质心)数据
	 */
	public static final Object[][] CENTROIDS = new Object[][] {
		new Object[] {1, -31.85, -44.77},
		new Object[]{2, 35.16, 17.46},
		new Object[]{3, -5.16, 21.93},
		new Object[]{4, -24.06, 6.81}
	};

	/**
	 * 输入的点数据
	 */
	public static final Object[][] POINTS = new Object[][] {
		new Object[] {-14.22, -48.01},
		new Object[] {-22.78, 37.10},
		new Object[] {56.18, -42.99},
		new Object[] {35.04, 50.29},
		new Object[] {-9.53, -46.26},
		new Object[] {-34.35, 48.25},
		new Object[] {55.82, -57.49},
		new Object[] {21.03, 54.64},
		new Object[] {-13.63, -42.26},
		new Object[] {-36.57, 32.63},
		new Object[] {50.65, -52.40},
		new Object[] {24.48, 34.04},
		new Object[] {-2.69, -36.02},
		new Object[] {-38.80, 36.58},
		new Object[] {24.00, -53.74},
		new Object[] {32.41, 24.96},
		new Object[] {-4.32, -56.92},
		new Object[] {-22.68, 29.42},
		new Object[] {59.02, -39.56},
		new Object[] {24.47, 45.07},
		new Object[] {5.23, -41.20},
		new Object[] {-23.00, 38.15},
		new Object[] {44.55, -51.50},
		new Object[] {14.62, 59.06},
		new Object[] {7.41, -56.05},
		new Object[] {-26.63, 28.97},
		new Object[] {47.37, -44.72},
		new Object[] {29.07, 51.06},
		new Object[] {0.59, -31.89},
		new Object[] {-39.09, 20.78},
		new Object[] {42.97, -48.98},
		new Object[] {34.36, 49.08},
		new Object[] {-21.91, -49.01},
		new Object[] {-46.68, 46.04},
		new Object[] {48.52, -43.67},
		new Object[] {30.05, 49.25},
		new Object[] {4.03, -43.56},
		new Object[] {-37.85, 41.72},
		new Object[] {38.24, -48.32},
		new Object[] {20.83, 57.85}
	};

	/**
	 * 得到默认的质心数据
	 * @param env
	 * @return
	 */
	public static DataSet<Centroid> getDefaultCentroidDataSet(ExecutionEnvironment env) {
		List<Centroid> centroidList = new LinkedList<Centroid>();
		for (Object[] centroid : CENTROIDS) {
			centroidList.add(
					new Centroid((Integer) centroid[0], (Double) centroid[1], (Double) centroid[2]));
		}
		return env.fromCollection(centroidList);
	}

	/**
	 * 得到默认的点数据
	 * @param env
	 * @return
	 */
	public static DataSet<Point> getDefaultPointDataSet(ExecutionEnvironment env) {
		List<Point> pointList = new LinkedList<Point>();
		for (Object[] point : POINTS) {
			pointList.add(new Point((Double) point[0], (Double) point[1]));
		}
		return env.fromCollection(pointList);
	}

}

3.4. 主程序入口

主程序入口的代码如下,下面将逐步地分析代码的逻辑。

public static void main(String[] args) throws Exception {

	// 1.解析命令行参数
	final ParameterTool params = ParameterTool.fromArgs(args);

	// 2. 构建执行环境
	ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

	// 3. 使参数在Web界面中可用
	env.getConfig().setGlobalJobParameters(params);  

	// 4. 得到输入数据:从提供的路径读取点和质心,或返回默认数据
	DataSet<Point> points = getPointDataSet(params, env);
	DataSet<Centroid> centroids = getCentroidDataSet(params, env);

	// 5. 为K-Means算法设置批量迭代次数
	IterativeDataSet<Centroid> loop = centroids.iterate(params.getInt("iterations", 10));
    // 6. K-Means算法计算过程
	DataSet<Centroid> newCentroids = points
		// 6.1. 计算每个点距离最近的质心
		.map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
		// 6.2. 每个簇内的所有点坐标求和
		.map(new CountAppender())
		.groupBy(0).reduce(new CentroidAccumulator())
		// 6.3. 根据点计数和坐标和计算新的质心
		.map(new CentroidAverager());

	// 7. 将新的质心数据反馈到下一个迭代中
	DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids);
    // 8. 将点归宿给最终的簇
	DataSet<Tuple2<Integer, Point>> clusteredPoints = points
		.map(new SelectNearestCenter()).withBroadcastSet(finalCentroids, "centroids");

	// 9. 指定输出结果路径和执行
	if (params.has("output")) {
		clusteredPoints.writeAsCsv(params.get("output"), "\n", " ");
		env.execute("KMeans Example");
	} else {
		System.out.println("Printing result to stdout. Use --output to specify output path.");
		clusteredPoints.print();
	}
}

上面是执行主函数逻辑的全部代码。代码注释中,我将逻辑代码注释成了9步。所以下面将主要解释下重要步骤的实现细节。

3.4.1. 解析命令行参数

final ParameterTool params = ParameterTool.fromArgs(args);

任务在提交执行时是可以指定参数的,主要可传参数包括:

  1. points: 表示点数据集的输入路径;
  2. centroids:初始聚集中心(质心)的数据集;
  3. iterations: 迭代运算的迭代次数;
  4. output: 计算结果最终的保存路径。

当然如果执行时,某些参数不传,那么系统会读取默认的。

3.4.2. 构建执行环境

ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

这里没有什么好说的,就是Flink任务必须要做的事情,初始化执行程序的上下文环境。

3.4.3. 使参数在Web界面中可用

env.getConfig().setGlobalJobParameters(params); 

这一步的目的也很简单,使参数在WebUI界面中可用。所以也不多说明。

3.4.4. 得到输入数据

DataSet<Point> points = getPointDataSet(params, env);
DataSet<Centroid> centroids = getCentroidDataSet(params, env);

这一步是得到输入的数据。输入的数据包含了两部分:点数据集和聚集中心(质心)数据集。

得到点数据集的函数

/**
 * 得到输入的点数据集
 * @param params
 * @param env
 * @return
 */
private static DataSet<Point> getPointDataSet(ParameterTool params, ExecutionEnvironment env) {
	DataSet<Point> points;

	// 如果有“points”这个输入参数,则从指定CSV路径中读入点数据源
	if (params.has("points")) {
		points = env.readCsvFile(params.get("points"))
			.fieldDelimiter(" ")
			.pojoType(Point.class, "x", "y");
	// 否则,读取默认的数据源
	} else {   
		System.out.println("Executing K-Means example with default point data set.");
		System.out.println("Use --points to specify file input.");
		points = KMeansData.getDefaultPointDataSet(env);
	}
	return points;
}

得到质心数据集的函数

/**
 * 得到质心数据集
 * @param params
 * @param env
 * @return
 */
private static DataSet<Centroid> getCentroidDataSet(ParameterTool params, ExecutionEnvironment env) {
	DataSet<Centroid> centroids;
    // 如果指定了质心数据集的读入csv文件路径,那么就读取。
	if (params.has("centroids")) {
		centroids = env.readCsvFile(params.get("centroids"))
			.fieldDelimiter(" ")
			.pojoType(Centroid.class, "id", "x", "y");
	// 否则,那么就读取默认数据
	} else {
		System.out.println("Executing K-Means example with default centroid data set.");
		System.out.println("Use --centroids to specify file input.");
		centroids = KMeansData.getDefaultCentroidDataSet(env);
	}
	return centroids;
}

3.4.5. 为K-Means算法设置批量迭代次数

IterativeDataSet<Centroid> loop = centroids.iterate(params.getInt("iterations", 10));

这一步是为K-Means算法设置批量迭代次数,默认是迭代10次。

3.4.6. K-Means算法计算过程

DataSet<Centroid> newCentroids = points
			// 6.1. 计算每个点距离最近的质心
			.map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
			// 6.2. 每个质心的点坐标计数和求和
			.map(new CountAppender())
			.groupBy(0).reduce(new CentroidAccumulator())
			// 6.3. 根据点计数和坐标和计算新的质心
			.map(new CentroidAverager());

这里是真正迭代运算的计算逻辑。其细节过程是是分步的。因为这里逻辑是算法的核心了,我们有必要细看下。

第一步:计算每个点距离最近的质心

.map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")

这里重点看下SelectNearestCenter类的执行逻辑。

/** 确定数据点最近的群集中心. */
@ForwardedFields("*->1")
public static final class SelectNearestCenter extends RichMapFunction<Point, Tuple2<Integer, Point>> {
	private Collection<Centroid> centroids;

	/** 将广播变量中的质心数据集读取到集合中*/
	@Override
	public void open(Configuration parameters) throws Exception {
		this.centroids = getRuntimeContext().getBroadcastVariable("centroids");
	}

	@Override
	public Tuple2<Integer, Point> map(Point p) throws Exception {

		double minDistance = Double.MAX_VALUE;
		int closestCentroidId = -1;

		// 遍历所有的簇中心
		for (Centroid centroid : centroids) {
			// 计算点和簇中心的欧式距离
			double distance = p.euclideanDistance(centroid);

			// 找到距离点最近的簇中心
			if (distance < minDistance) {
				minDistance = distance;
				closestCentroidId = centroid.id;
			}
		}
		// 输出一条心的记录,由簇中心id和Point组成.
		return new Tuple2<>(closestCentroidId, p);
	}
}

第二步: 每个簇内的所有点坐标求和

.map(new CountAppender()).groupBy(0)
    .reduce(new CentroidAccumulator())

这里重点看下CountAppender类的执行逻辑和CentroidAccumulator类的执行逻辑:

/** 对 Tuple2<Integer, Point>进行计数 */
@ForwardedFields("f0;f1")
public static final class CountAppender implements MapFunction<Tuple2<Integer, Point>, Tuple3<Integer, Point, Long>> {

	@Override
	public Tuple3<Integer, Point, Long> map(Tuple2<Integer, Point> t) {
	    // 对簇内点进行计数
		return new Tuple3<>(t.f0, t.f1, 1L);
	}
}
/** 对簇内点计数以及对簇内点的坐标进行累加 */
@ForwardedFields("0")
public static final class CentroidAccumulator implements ReduceFunction<Tuple3<Integer, Point, Long>> {

	@Override
	public Tuple3<Integer, Point, Long> reduce(Tuple3<Integer, Point, Long> val1, Tuple3<Integer, Point, Long> val2) {
	    // 这一步逻辑很关键,对簇内点坐标累计,然后对簇内元素个数计数。
		return new Tuple3<>(val1.f0, val1.f1.add(val2.f1), val1.f2 + val2.f2);
	}
}

这一步实现了对每个簇内的元素(点)个数进行了计数,然后对簇内的这些点的坐标进行了累加。

第三步: 根据点计数和坐标和计算新的质心

.map(new CentroidAverager());

这里看下CentroidAverager类的逻辑。

/** 从簇内点的个数和这些点的坐标和计算出新的簇中心*/
@ForwardedFields("0->id")
public static final class CentroidAverager implements MapFunction<Tuple3<Integer, Point, Long>, Centroid> {

	@Override
	public Centroid map(Tuple3<Integer, Point, Long> value) {
	    // 坐标和/簇内点个数作为新的簇中心
		return new Centroid(value.f0, value.f1.div(value.f2));
	}
}

这一步是根据上一步计算的簇内元素个数,以及这些元素的坐标和来求得新的簇中心坐标。计算方式是(坐标和/簇内元素个数)。

3.4.7. 将新的质心数据反馈到下一个迭代中

DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids);

将上一次迭代计算得到的新的簇中心数据newCentroids反馈给loop,然后进行下一次迭代。

其实3.4.5到3.4.7是可以一起看的,这三步定义了批量迭代计算的逻辑。也是迭代计算(iterative computation)的定义模板。

3.4.8. 将点归宿给最终的簇

DataSet<Tuple2<Integer, Point>> clusteredPoints = points
		.map(new SelectNearestCenter()).withBroadcastSet(finalCentroids, "centroids");

将过上述三步的迭代计算之后,就可以确定下来最终的稳定的簇。那么这一步就开始把每个点归宿给最终的簇了。逻辑还是一样,通过欧式定理来归属。SelectNearestCenter类的实现逻辑在前文中讲过,所以这里不做赘述了。

3.4.9. 指定输出结果方式和执行

最后一步就是制定输出结果方式和执行。

if (params.has("output")) {
		clusteredPoints.writeAsCsv(params.get("output"), "\n", " ");
		env.execute("KMeans Example");
	} else {
		System.out.println("Printing result to stdout. Use --output to specify output path.");
		clusteredPoints.print();
	}

4. 任务正确运行之后的结果

运行之后的结果是,把每个顶点都归宿到了各个簇去了。
结果如下:

(1,-14.22 -48.01)
(4,-22.78 37.1)
(2,56.18 -42.99)
(3,35.04 50.29)
(1,-9.53 -46.26)
(4,-34.35 48.25)
(2,55.82 -57.49)
(3,21.03 54.64)
(1,-13.63 -42.26)
(4,-36.57 32.63)
(2,50.65 -52.4)
(3,24.48 34.04)
(1,-2.69 -36.02)
(4,-38.8 36.58)
(2,24.0 -53.74)
(3,32.41 24.96)
(1,-4.32 -56.92)
(4,-22.68 29.42)
(2,59.02 -39.56)
(3,24.47 45.07)
(1,5.23 -41.2)
(4,-23.0 38.15)
(2,44.55 -51.5)
(3,14.62 59.06)
(1,7.41 -56.05)
(4,-26.63 28.97)
(2,47.37 -44.72)
(3,29.07 51.06)
(1,0.59 -31.89)
(4,-39.09 20.78)
(2,42.97 -48.98)
(3,34.36 49.08)
(1,-21.91 -49.01)
(4,-46.68 46.04)
(2,48.52 -43.67)
(3,30.05 49.25)
(1,4.03 -43.56)
(4,-37.85 41.72)
(2,38.24 -48.32)
(3,20.83 57.85)

5. 总结

本案例的难点在于迭代计算的应用。机器学习算法的本质就是一个迭代计算,然后在迭代中减少损失函数的不断优化过程。掌握Flink的迭代计算,将为我们设计出更多复杂有效的机器学习模型打下基础。

后续文章中会继续推出,怎么在Flink上实现更多复杂有趣的机器学习模型。

猜你喜欢

转载自blog.csdn.net/hxcaifly/article/details/86496243