简介
lag函数用于把指定列向后移动多少行之后和原表拼接。
lag(column,n,default)
lead(column,n,default)
其中 column表示要移动的列,n表示要移动多少行,default表示默认值,不给就是null
示例:
select gid,
lag(time,1,'0') over (partition by gid order by time) as lag_time,
lead(time,1,'0') over (partition by gid order by time) as lead_time
from user_order;
实例
考虑这样一个问题,我们要找出连续2个月都购买了某个指定商品的人,怎么处理?
很多同学可能首先考虑的是把时间处理成月份,然后根据用户id分组,统计月份大于2的就可以了。
如果问题变一点点,要找出在60天内至少购买过2次某个指定商品的人,怎么处理?
咋一看,一样的,实际上没有给出是哪60天,所以要计算所有订单时间的差值,然后找出时间差在60天范围之内的订单用户。
逻辑比较简单,但是处理起来不太还处理,不过如果知道lag或者lead函数,就变得非常简单了。
只需要把使用head函数,把后一次订单时间拼接上来,然后通过过滤器过滤一下就好了。
下面看一个代码示例,应该就清楚了。
代码实例
import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.Before;
import org.junit.Test;
import java.io.Serializable;
import java.time.LocalDate;
import java.time.format.DateTimeFormatter;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
public class SparkHiveLagTest implements Serializable {
private SparkSession sparkSession;
@Before
public void setUp() {
sparkSession = SparkSession
.builder()
.appName("test")
.master("local")
.getOrCreate();
}
private static List<Info> getInfos() {
String[] gids = {"10001","10002","10003","10004","10005"};
LocalDate base = LocalDate.of(2020, 1, 1);
DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd");
Random random = new Random(120);
LinkedList<Info> infos = new LinkedList<>();
for(int i=0;i<50;i++){
Info info = new Info();
info.setGid(gids[random.nextInt(gids.length)]);
info.setDate(base.plusDays(random.nextInt(365)).format(formatter));
infos.add(info);
}
return infos;
}
@Test
public void lag(){
List<Info> infos = getInfos();
Dataset<Info> dataset = sparkSession.createDataset(infos, Encoders.bean(Info.class));
dataset.createOrReplaceTempView("temp");
String sql = "select gid,date,lag(date,1,0) over(partition by gid order by date) lag_date from temp";
Dataset<Row> ds = sparkSession.sql(sql);
ds.show(200);
ds = ds.filter(new FilterFunction<Row>() {
@Override
public boolean call(Row row) throws Exception {
String lagDate = row.getString(2);
if (lagDate.equals("0")) {
return false;
}
String date = row.getString(1);
DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd");
LocalDate localDateBase = LocalDate.parse(date, formatter);
LocalDate localDateLag = LocalDate.parse(lagDate, formatter);
if (localDateBase.toEpochDay() - localDateLag.toEpochDay() <= 7) {
return true;
}
return false;
}
});
ds.show();
ds.select("gid").orderBy("gid").distinct().show();
}
public static class Info implements Serializable {
private String gid;
private String date;
public String getGid() {
return gid;
}
public void setGid(String gid) {
this.gid = gid;
}
public String getDate() {
return date;
}
public void setDate(String date) {
this.date = date;
}
}
}
简单解释一下:
getInfos函数是模拟生成一点订单数据,用于测试
select gid,date,lag(date,1,0) over(partition by gid order by date) lag_date from temp
gid用于标识用户,不同的用户是单独计算,所以按gid partition,
date是订单日期,按日期排序之后才能保证是相邻的订单
lag(date,1,0)是把订单日期向后移动1行,如果是第一个订单,没有从前面移动过来的上一个订单日期,就填充0
这样我们就得到一个有用户id、订单日期、前一个订单日期的表了,接下来的过滤就非常简单了。
当然我们也可以使用lead窗口函数,这样得到的就算用户id、订单日期、后一个订单日期,过滤的时候逻辑变一下就可以了。