Learning To Rank算法和评价指标

排序学习是推荐、搜索、广告的核心方法,而LTR就是专门做排序任务的一个有监督的机器学习算法。所以,LTR仍然是传统的机器学习处理范式,构造特征,学习目标,训练模型,预测。LTR一般分为三种类型:PointWise,PairWise和ListWise。这三种算法并不是特定的算法,而是三种设计思路,主要区别体现在损失函数、标签标注方式和优化方法的不同。

1. PointWise

以搜索任务为例,PointWise只考虑当前Qeury与每个文档的绝对相关度,而没有考虑其他文档与Qeury的相关度。PW的方法通常将文档编码成特征向量,根据训练数据训练分类模型或者回归模型,在预测阶段,直接对文档进行打分,按照此得分排序就是搜索的结果。

处理逻辑如下图:

PointWise

  • 实施细节

训练数据格式为三元组: ( q i , d j , y i j ) (q_i,d_j,y_{ij}) ,标签 y i j y_{ij} 是2个数值,表示相关/不相关。 训练一个二分类模型或者回归模型直接拟合 y i j y_{ij} 。 loss函数:分类模型Loss函数可以使用交叉熵,回归模型Loss函数可以使用均方误差(MSE)。 预测阶段;得分直接用作排序。

  • PointWise的问题
  1. PointWise只考虑query和单个文档document之间的相关性 s i m q , d sim_{q,d} 没有考虑候选文档之间的关系。既然我们追求的目标是对候选结果进行排序,其实是想计算相对得分,直接使用 s i m q , d sim_{q,d} 的大小来排序,往往没有那么准确。实际上 s i m q , d sim_{q,d} 只是准确度概率,而不是真正的相对顺序概率。
  2. PointWise没有考虑同一个query对应的文档之间的内部依赖性。这回导致如下问题:1.导致输入空间内的样本不是独立同步分(IID)的,违反了机器学习的基本假设。2.当不同query有不同数量的文档时,整体loss容易被那些有更多文档(训练数据)的query组所支配。
  3. 排序问题关注的是topk的准确率,所以loss函数的设置需要加入相对位置排序的信息。

2. PairWise

PairWise的基本思路是对样本进行两两比较,构建偏序文档对,从比较中学习顺序。正如在PointWise中分析的,对于一个查询来说,我们需要的是检索结果正确的顺序,而不是检索结果与query的相关得分。PairWise就是希望通过正确估计一对文档的顺序,而得到整体的正确顺序。比如一个正确的排序为:“A>B>C”,PairWise通过学习两两之间的关系“A>B”,“B>C”和“A>C”来推断“A>B>C”。

处理逻辑如下: 此处输入图片的描述

  • 实施细节

训练数据格式为 ( q i , d i + , d i ) (q_i,d^+_i,d^-_i) ,是一个query的正例和负例。通常又被称为:(anchor,positive,negative)。 PairWise实际上是一种metric learning的思路来直接学习他们的相对距离,而不在乎实际的值。 loss函数:大概有两种, 1:输入pair对的Ranking Loss: L ( r 0 , r 1 , y ) = y d ( r 0 r 1 ) + ( 1 y ) m a x ( 0 , m a r g i n d ( r 0 r 1 ) ) L(r_0,r_1,y)=yd(r_0-r_1)+(1-y) max(0,margin-d(r_0-r_1)) 其中y的取值为0或者1。 2: 输入三元组的Triplet Loss或者Contrastive Loss: L ( r a , r p , r n ) = m a x ( 0 , m a r g i n + d ( r a , r p ) d ( r a , r n ) ) L(r_a,r_p,r_n)=max(0,margin+d(r_a,r_p)-d(r_a,r_n)) 预测阶段;和PointWise一样,得分直接用作排序。

  • PairWise的问题
  1. 由于需要构造pair格式的数据集,数量可能是doc数量的n倍(依据不同的构造策略),而PointWise中存在的*“当不同query有不同数量的文档时,整体loss容易被那些有更多文档(训练数据)的query组所支配”*的问题依然没存在,甚至进一步扩大。
  2. PairWise相对于PointWise对于噪音数据更敏感,即一个错误标注将会导致多个pair的错误。
  3. PairWise仍然只是考虑一对doc的相对位置,损失函数还是没有考虑候选文档之间的关系。可以认为是PointWise的优化版,基本思路没有变化。
  4. 同样的,PairWise没有考虑同一个query对应的文档之间的内部依赖性。导致输入空间内的样本不是独立同步分(IID)的,违反了机器学习的基本假设。

3. ListWise

PointWise和PairWise都是直接学习每个样本是否相关,或者两个正负样本的相关关系,更像是metric learning的思路,都是试图通过抽样的学习试图推理出全局的排序结果,这种思路是有根本的劣势。而ListWise的基本思路是试图直接优化像NDCG的排序指标,从而学习到最佳的排序结果。

  • 实施细节

输入的一个sample的格式是query以及他所有的候选doc。如给定: q i q_i ,和他的候选doc及标签: C ( d i 1 , . . , d i m ) C(d_{i1},..,d_{im}) Y ( y i 1 , . . , y i m ) Y(y_{i1},..,y_{im}) 。标签 Y Y 的值就是表示所有候选doc的顺序。比如某个候选集为 { a , d , c , b , e } \{a,d,c,b,e\} ,如果就是自然顺序,其对应的标签为 { 5 , 2 , 3 , 4 , 1 } \{5,2,3,4,1\} 。 通过各种ListWise算法训练模型。 预测阶段;根据得分来排序。

  • ListWise三种基本思路:
  1. 第一种为Measure-specific

这种方法就是直接对比如NDCG这样的指标优化。 这种方法是典型的“理想很丰满,现实很骨干”,因为NDCG、MAP和AUC这类排序指标,他们在数学形式上,是“不连续”(Non-Continuous)的,以及“不可微”(Non-Differentiable)的,基于这个现实,通常有三种解决办法: 第一种方法:找到一个近似NDCG指标的“连续”和“可微”的替代函数,通过最优化这个替代函数来优化NDCG。代表算法:SoftRank 和 AppRank。 第二种方法:尝试从数学上写出一个NDCG等指标的“边界”,然后优化这个“边界”。比如,如果推导出一个上界,那就可以通过最小化这个上界来优化 NDCG。代表算法:SVM-MAP 和 SVM-NDCG。 第三种方法:直接优化算法,可以用来处理“不连续”和“不可微”的NDCG类指标。代表算法:AdaRank 和 RankGP。 2. 第二种为Non-Measure-specific 这种方法是根据一个已知的最优排序,尝试重建这个顺序,然后衡量两者的差距,即优化模型来试图减少这个差距,比如使用KL散度作为Loss。 代表算法:ListNet 和 ListMLE 3. 第三种,ListWise和PairWise结合的算法 这类方法的核心目标仍然是优化NDCG类的排序指标,设计出一种替代的目标函数,有了替代函数之后,优化和计算过程直接使用某种PairWise的方式处理。 代表算法: LambdaRank 和 LambdaMART。

  • ListWise的优缺点
    1. 在很多场景构造训练数据比较困难。
    2. 因为要计算排序的loss,通常计算复杂度更高。
    3. 在有充足质量好的数据基础上,ListWise相比较PairWise和PointWise,直接对目标任务,也就是排序,进行学习和优化,往往表现更好。

4. 常用评价指标

nDCG

关于nDCG的解释

Mean Average Precision(MAP)

排序任务中,每个query都会有一个排序列表。顾名思义,MAP,就是测试集上所有query的AP的平均,那我们先看一下AP:

A P ( π , l ) = k = 1 m P @ k I { l π 1 ( k ) = 1 } m 1 AP(\pi,l)=\frac{\sum^m_{k=1}{P@k*I_{\{ l_{\pi^{-1}(k)}=1\}}}}{m_1}

其中, π \pi 表示item list,即推送的结果列表。 m表示结果列表总数量, m 1 m_1 表示结果列表中与query相关的item数量。 I { l π 1 ( k ) = 1 } I_{\{l_{\pi^{-1}(k)}=1\}} ,表示排在位置k处的标签是否相关,1表示相关,0表示不相关。 P @ k P@k 就是topk的Precision: P @ k ( π , l ) = t < = k I { l π 1 ( k ) = 1 } k P@k(\pi,l)=\frac{\sum_{t<=k}{I_{\{ l_{\pi^{-1}(k)}=1\}}}}{k}

另附一张图讲的很清楚: map.png-176.4kB

代码实现:

def _ap(ranked_list, ground_truth):
    # ranked_list: 结果列表,如['a', 'b', 'd', 'c', 'e']
    # ground_truth: 相关的item列表,如 ['a', 'd']
    hits = 0
    sum_precs = 0
    for n, item in enumerate(ranked_list):
        if item in ground_truth:
            hits += 1
            sum_precs += hits / (n + 1.0)
    return sum_precs / max(1.0, len(ground_truth))

复制代码

  • 参考
  1. 搜索评价指标——NDCG

猜你喜欢

转载自juejin.im/post/7018734362409582628