Flink Sql教程(4)

Flink UDF

概述

  • 什么是UDF
    • UDF是User-defined Functions的缩写,即自定义函数。
  • UDF种类
    • UDF分为三种:Scalar Functions、Table Functions、Aggregation Functions
    • Scalar Functions
      • 接收0、1、多个参数,返回一个值
    • Table Functions
      • 和上面的Scalar Functions接收的参数个数一样,不同的是可以返回多行,而不是单个值
    • Aggregation Functions
      • 从名字就可以看出来,这个是搭配GROUP BY一起使用的,将表的一个或多个列的一行或多行数据汇聚到一个值里面,看上去有点拗口,其实可以把它简单理解为SQL中的聚合函数
    • Table Aggregation Functions
      • 相当于Table Functions和Aggregation Functions的结合体,聚合之后,再返回多行多列
  • 为什么要有UDF
    • Flink SQL目前提供了很多的内置UDF,主要是为了大家更方便的编写SQL代码完成自己的业务逻辑,具体内置的UDF可以参考官方文档;同时,Flink 也支持注册自己的UDF,下面正式开始我们今天的UDF探索之旅。

Scalar Functions

	//不墨迹,我们直接贴代码
	package udf;

	import org.apache.flink.table.functions.ScalarFunction;
	
	
	public class TestScalarFunc extends ScalarFunction {
	
	    private int factor = 2020;
		//和传入数据进行计算的逻辑,参数个数任意
	    public int eval() {
	        return factor;
	    }
	
	    public int eval(int a) {
	        return a * factor;
	    }
	
	    public int eval(int... a) {
	        int res = 1;
	        for (int i : a) {
	            res *= i;
	        }
	        return res * factor;
	    }
}

  • 自定义Scalar Functions,需要继承ScalarFunction,并且有一个publiceval(),方法可以接受任意个数参数,同时也可以在一个类中重载eval()
  • 写完UDF之后需要注册到我们的运行环境中,使用姿势有两种:
    • tEnv.sqlUpdate("CREATE FUNCTION IF NOT EXISTS test AS 'udf.TestScalarFunc'");
    • tEnv.registerFunction("test",new TestScalarFunc());
    • 第一种偏向在纯SQL的环境中使用,比如我们有个Flink SQL的提交平台,只支持纯SQL语句,那我们可以把自己写的UDF打包上传到平台后,通过SQL语句CREATE FUNCTION IF NOT EXISTS test AS 'udf.TestScalarFunc'来创建UDF;同时可以把UDF注册到catalog中,这里先不深入讨论,之后我们说到Flink X Hive的时候再聊吧
    • 第二种注册方式,如果我们的类有构造方法,可以通过new 对象的时候传递变量进去,更为灵活一点

Table Functions

	package udf;

	import org.apache.flink.api.common.typeinfo.TypeInformation;
	import org.apache.flink.api.common.typeinfo.Types;
	import org.apache.flink.calcite.shaded.com.google.common.base.Strings;
	import org.apache.flink.table.functions.TableFunction;
	import org.apache.flink.types.Row;
	
	public class TestTableFunction extends TableFunction {
	
	    private String separator = ",";
	
	    public TestTableFunction(String separator) {
	        this.separator = separator;
	    }

		//和传入数据进行计算的逻辑,参数个数任意
	    public void eval(String input){
	        
	        Row row = null;
	        
	        if (Strings.isNullOrEmpty(input)){
	            
	            row = new Row(2);
	            row.setField(0,null);
	            row.setField(1,0);
	            collect(row);
	            
	        }else {
	            
	            String[] split = input.split(separator);
	            
	            for (String word : split) {
	                row = new Row(2);
	                row.setField(0,word);
	                row.setField(1,word.length());
	                collect(row);
	            }
	            
	        }
	
	    }
	
	    @Override
	    public TypeInformation getResultType() {
	        return Types.ROW(Types.STRING,Types.INT);
	    }
	}

  • 自定义Table Functions,需要继承TableFunction,并且有一个publiceval(),方法可以接受任意个数参数,同时也可以在一个类中重载eval()
  • 因为返回的是Row类型,所以需要重写getResultType()
  • 在SQL语句中使用时,有两种写法:
    • select a.age,b.name,b.name_length from t2 a, LATERAL TABLE(test2(a.name_list)) as b(name, name_length)
    • select a.age,b.name,b.name_length from t2 a LEFT JOIN LATERAL TABLE(test2(a.name_list)) as b(name, name_length) ON TRUE
    • 第一种的用法相当于用的是CROSS JOIN
    • 第二种的用法是LEFT JOIN

Aggregation Functions

	package udf;

	import org.apache.flink.table.functions.AggregateFunction;
	
	import java.util.Iterator;
	
	public class TestAggregateFunction extends AggregateFunction<Long, TestAggregateFunction.SumAll> {
		//返回最终结果
	    @Override
	    public Long getValue(SumAll acc) {
	        return acc.sum;
	    }
		//构建保存中间结果的对象
	    @Override
	    public SumAll createAccumulator() {
	        return new SumAll();
	    }
		//和传入数据进行计算的逻辑
	    public void accumulate(SumAll acc, long iValue) {
	        acc.sum += iValue;
	    }
	
		//减去要撤回的值
	    public void retract(SumAll acc, long iValue) {
	        acc.sum -= iValue;
	    }
	    
		//从每个分区把数据取出来然后合并
	    public void merge(SumAll acc, Iterable<SumAll> it) {
	
	        Iterator<SumAll> iter = it.iterator();
	
	        while (iter.hasNext()) {
	            SumAll a = iter.next();
	            acc.sum += a.sum;
	
	        }
	    }
		//重置内存中值时调用
	    public void resetAccumulator(SumAll acc) {
	        acc.sum = 0L;
	    }
	
	    public static class SumAll {
	        public long sum = 0;
	    }
	
	}

  • 自定义Aggregation Functions,需要继承AggregateFunction,并且必须要有 以下的方法
    • createAccumulator() 创建一个保留中间结果的数据结构
    • accumulate() 把每个输入行与中间结果进行计算,可以重载
    • getValue() 获取最终结果
  • 根据不同的使用情况,还需要以下的方法
    • retract() 用于bounded OVER窗口,即窗口有结束时间
    • merge()用于多次批量聚合和会话窗口合并
    • resetAccumulator()用于多次批量聚合时,清空中间结果

Table Aggregation Functions

	package udf;

	import org.apache.flink.api.common.typeinfo.TypeInformation;
	import org.apache.flink.api.common.typeinfo.Types;
	import org.apache.flink.table.functions.TableAggregateFunction;
	import org.apache.flink.types.Row;
	import org.apache.flink.util.Collector;
	
	public class TestTableAggregateFunction extends TableAggregateFunction<Row,TestTableAggregateFunction.Top2> {
		//创建保留中间结果的对象
	    @Override
	    public Top2 createAccumulator() {
	        Top2 t = new Top2();
	        t.f1 = Integer.MIN_VALUE;
	        t.f2 = Integer.MIN_VALUE;
	
	        return t;
	    }
		//与传入值进行计算的方法
	    public void accumulate(Top2 t, Integer v) {
		    //如果传入的值比内存中第一个值大,那就用第一个值替换第二个值,传入的值替换第一个值;
		    //如果传入的值比第二个值大比第一个小,那么就替换第二个值。
	        if (v > t.f1) {
	            t.f2 = t.f1;
	            t.f1 = v;
	        } else if (v > t.f2) {
	            t.f2 = v;
	        }
	    }
		
		//合并分区的值
	    public void merge(Top2 t, Iterable<Top2> iterable) {
	        for (Top2 otherT : iterable) {
	            accumulate(t, otherT.f1);
	            accumulate(t, otherT.f2);
	        }
	    }
	
		//拿到返回结果的方法
	    public void emitValue(Top2 t, Collector<Row> out) {
	        Row row = null;
	        //发射数据
	        //如果第一个值不是最小的int值,那就发出去
	        //如果第二个值不是最小的int值,那就发出去
	        if (t.f1 != Integer.MIN_VALUE) {
	            row = new Row(2);
	            row.setField(0,t.f1);
	            row.setField(1,1);
	            out.collect(row);
	        }
	        if (t.f2 != Integer.MIN_VALUE) {
	            row = new Row(2);
	            row.setField(0,t.f2);
	            row.setField(1,2);
	            out.collect(row);
	        }
	    }
		//撤回流拿结果的方法,会发射撤回数据
	    public void emitUpdateWithRetract(Top2 t, RetractableCollector<Row> out) {
	        Row row = null;
	        //如果新旧值不相等,才需要撤回,不然没必要
	        //如果旧值不等于int最小值,说明之前发射过数据,需要撤回
	        //然后将新值发射出去
	        if (!t.f1.equals(t.oldF1)) {
	            if (t.oldF1 != Integer.MIN_VALUE) {
	                row = new Row(2);
	                row.setField(0,t.oldF1);
	                row.setField(1,1);
	                out.retract(row);
	            }
	            row = new Row(2);
	            row.setField(0,t.f1);
	            row.setField(1,1);
	            out.collect(row);
	            t.oldF1 = t.f1;
	        }
		    //和上面逻辑一样,只是一个发射f1,一个f2
	        if (!t.f2.equals(t.oldF2)) {
	            // if there is an update, retract old value then emit new value.
	            if (t.oldF2 != Integer.MIN_VALUE) {
	                row = new Row(2);
	                row.setField(0,t.oldF2);
	                row.setField(1,2);
	                out.retract(row);
	            }
	            row = new Row(2);
	            row.setField(0,t.f2);
	            row.setField(1,2);
	            out.collect(row);
	            t.oldF2 = t.f2;
	        }
	    }
	    //保留中间结果的类
	    public class Top2{
	        public Integer f1;
	        public Integer f2;
	        public Integer oldF1;
	        public Integer oldF2;
	
	    }
	
	    @Override
	    public TypeInformation<Row> getResultType() {
	        return Types.ROW(Types.INT,Types.INT);
	    }
	}

  • 自定义Table Aggregation Functions,需要继承TableAggregateFunction,并且必须要有 以下的方法
    • createAccumulator() 创建一个保留中间结果的数据结构
    • accumulate() 把每个输入行与中间结果进行计算,可以重载
  • 根据不同的使用情况,还需要以下的方法
    • retract() 用于bounded OVER窗口,即窗口有结束时间
    • merge()用于多次批量聚合和会话窗口合并
    • resetAccumulator()用于多次批量聚合时,清空中间结果
    • emitValue() 用于批量和窗口聚合拿到结果
    • emitUpdateWithRetract() 用于流式计算的撤回流
  • 目前Table Aggregation Functions只支持在Table Api中使用

完整代码

	//下面贴出来的是主类的代码,具体每个UDF的类上面已经有了
	package FlinkSql;


	import org.apache.flink.api.common.typeinfo.TypeInformation;
	import org.apache.flink.api.common.typeinfo.Types;
	import org.apache.flink.api.java.tuple.Tuple2;
	import org.apache.flink.streaming.api.datastream.DataStream;
	import org.apache.flink.streaming.api.datastream.DataStreamSource;
	import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
	import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
	import org.apache.flink.table.api.Table;
	import org.apache.flink.types.Row;
	import udf.TestAggregateFunction;
	import udf.TestScalarFunc;
	import udf.TestTableAggregateFunction;
	import udf.TestTableFunction;
	
	import static util.FlinkConstant.env;
	import static util.FlinkConstant.tEnv;
	
	public class FlinkSql04 {
	    public static void main(String[] args) throws Exception {
	
	
	        DataStream<Row> source = env.addSource(new RichSourceFunction<Row>() {
	
	            @Override
	            public void run(SourceContext<Row> ctx) throws Exception {
	                    Row row = new Row(3);
	                    row.setField(0, 2);
	                    row.setField(1, 3);
	                    row.setField(2, 3);
	                    ctx.collect(row);
	            }
	
	            @Override
	            public void cancel() {
	
	            }
	        }).returns(Types.ROW(Types.INT,Types.INT,Types.INT));
	
	        tEnv.createTemporaryView("t",source,"a,b,c");
	
	//        tEnv.sqlUpdate("CREATE FUNCTION IF NOT EXISTS test AS 'udf.TestScalarFunc'");
	
	        tEnv.registerFunction("test",new TestScalarFunc());
	
	        Table table = tEnv.sqlQuery("select test() as a,test(a) as b, test(a,b,c) as c from t");
	
	        DataStream<Row> res = tEnv.toAppendStream(table, Row.class);
	
	//        res.print().name("Scalar Functions Print").setParallelism(1);
	
	        DataStream<Row> ds2 = env.addSource(new RichSourceFunction<Row>() {
	
	
	            @Override
	            public void run(SourceContext<Row> ctx) throws Exception {
	                    Row row = new Row(2);
	                    row.setField(0, 22);
	                    row.setField(1, "aa,b,cdd,dfsfdg,exxxxx");
	                    ctx.collect(row);
	            }
	
	            @Override
	            public void cancel() {
	
	            }
	        }).returns(Types.ROW(Types.INT, Types.STRING));
	
	        tEnv.createTemporaryView("t2",ds2,"age,name_list");
	
	        tEnv.registerFunction("test2",new TestTableFunction(","));
	
	//        Table table2 = tEnv.sqlQuery("select a.age,b.name,b.name_length from t2 a, LATERAL TABLE(test2(a.name_list)) as b(name, name_length)");
	
	        Table table2 = tEnv.sqlQuery("select a.age,b.name,b.name_length from t2 a LEFT JOIN LATERAL TABLE(test2(a.name_list)) as b(name, name_length) ON TRUE");
	
	        DataStream<Row> res2 = tEnv.toAppendStream(table2, Row.class);
	
	//        res2.print().name("Table Functions Print").setParallelism(1);
	
	        DataStream<Row> ds3 = env.addSource(new RichSourceFunction<Row>() {
	            @Override
	            public void run(SourceContext<Row> ctx) throws Exception {
	                Row row1 = new Row(2);
	                row1.setField(0,"a");
	                row1.setField(1,1L);
	
	                Row row2 = new Row(2);
	                row2.setField(0,"a");
	                row2.setField(1,2L);
	
	                Row row3 = new Row(2);
	                row3.setField(0,"b");
	                row3.setField(1,100L);
	
	                ctx.collect(row1);
	                ctx.collect(row2);
	                ctx.collect(row3);
	
	            }
	
	            @Override
	            public void cancel() {
	
	            }
	        }).returns(Types.ROW(Types.STRING, Types.LONG));
	
	        tEnv.createTemporaryView("t3",ds3,"name,cnt");
	
	        tEnv.registerFunction("test3",new TestAggregateFunction());
	
	        Table table3 = tEnv.sqlQuery("select name,test3(cnt) as mySum from t3 group by name");
	
	        DataStream<Tuple2<Boolean, Row>> res3 = tEnv.toRetractStream(table3, Row.class);
	
	//        res3.print().name("Aggregate Functions Print").setParallelism(1);
	
	        DataStream<Row> ds4 = env.addSource(new RichSourceFunction<Row>() {
	            @Override
	            public void run(SourceContext<Row> ctx) throws Exception {
	                Row row1 = new Row(2);
	                row1.setField(0,"a");
	                row1.setField(1,1);
	
	                Row row2 = new Row(2);
	                row2.setField(0,"a");
	                row2.setField(1,2);
	
	                Row row3 = new Row(2);
	                row3.setField(0,"a");
	                row3.setField(1,100);
	
	                ctx.collect(row1);
	                ctx.collect(row2);
	                ctx.collect(row3);
	            }
	
	            @Override
	            public void cancel() {
	
	            }
	        }).returns(Types.ROW(Types.STRING, Types.INT));
	
	        tEnv.createTemporaryView("t4",ds4,"name,cnt");
	
	        tEnv.registerFunction("test4",new TestTableAggregateFunction());
	
	        Table table4 = tEnv.sqlQuery("select * from t4");
	
	        Table table5 = table4.groupBy("name")
	                .flatAggregate("test4(cnt) as (v,rank)")
	                .select("name,v,rank");
	
	        DataStream<Tuple2<Boolean, Row>> res4 = tEnv.toRetractStream(table5, Row.class);
	
	        res4.print().name("Aggregate Functions Print").setParallelism(1);
	
	        env.execute("test udf");
	
	    }
	}

猜你喜欢

转载自blog.csdn.net/weixin_47482194/article/details/106292872