使用SQL计算AUC值

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/MiMicoa/article/details/84721116

背景

在开发一些机器学习应用时,经常需要展示模型的ROC曲线以及AUC值。我们固然可以在代码中编写函数或者直接调用已有的软件包来计算,但在某些场景下当面临的数据量很大时,网络的传输可能会影响系统的性能。这种情况下可以考虑直接在SQL语句中计算,而不需要将数据传回到客户端,从而提升效率和稳定性。

实现

计算AUC值需要两个参数:模型的输出值和样本真实的标签。我们可以假设数据库中有一个表用来保存这两个信息,然后基于这个表进行计算。为此我们先创建一张数据表(文中的例子均以PostgreSQL语法编写)

create table score_label(
  score int,
  label boolean
);

这里的score定义为整型是为了简便,因为AUC的计算与模型输出值的绝对大小无关,只与模型输出值的相对大小相关。然后我们可以插入几条数据作为例子:

insert into score_label (score, label) values (9, TRUE);
insert into score_label (score, label) values (8, FALSE);
insert into score_label (score, label) values (7, TRUE);
insert into score_label (score, label) values (6, FALSE);

这里举了个最简单的例子:总共包含四个样本,对应的AUC是0.75。如果对AUC的计算还没有完全明确的同学可以看我之前写的这篇文章。下面可以按照这篇文章中的逻辑来编写计算逻辑,如下所示:

with ROC as (WITH
    r1 AS (SELECT score, 
    count(1) FILTER (WHERE label IS TRUE) AS t, 
    count(1) FILTER (WHERE label IS FALSE) AS f
           FROM score_label
           GROUP BY score
           ORDER BY score desc),
    r2 AS (SELECT score, t, f, 
    sum(t) OVER (ORDER BY score desc) AS tsum,
    sum(f) OVER (ORDER BY score desc) AS fsum
           FROM r1),
    r3 AS (SELECT case
                    when (SELECT sum(f) FROM r2) = 0 then 0
                    else f / (SELECT sum(f) FROM r2)
                      end AS width,
                  case
                    when (SELECT sum(t) FROM r2) = 0 then 0
                    else tsum / (SELECT sum(t) FROM r2)
                      end AS y,
                  case
                    when (SELECT sum(f) FROM r2) = 0 then 0
                    else fsum / (SELECT sum(f) FROM r2)
                      end AS x
           FROM r2
           UNION SELECT 0, 0, 0),
    r4 AS (SELECT *
           FROM r3
           ORDER BY x),
    r5 as (
        SELECT cast(x as numeric(18, 3)) x, 
               cast(y as numeric(18, 3)) y, 
               (y + lag(y, 1, 0.0)
            OVER (ORDER BY x, y)) * width / 2 AS area
    FROM r4)
        SELECT array_agg(x) x, 
               array_agg(y) y, 
               cast(sum(area) as numeric(18, 2)) AS auc
    from r5)
    
select * from ROC;

整个逻辑乍看起来比较庞杂,实际上是通过WITH语法把整个计算逻辑逐步展开,先按照score进行分组排序,然后分别计算每个点对应的TPRFPR,最后再按照微积分的思想对y进行积分。这里我们不仅能计算出AUC,还可以同时把ROC曲线上的各个点以数组的形式返回。运行这段SQL可以得到以下结果:

x y auc
{0,0,0.5,0.5,1} {0,0.5,0.5,1,1} 0.75

有了这段SQL逻辑之后,我们可以将其定义成数据库中的存储过程,或者通过客户端直接调用这段SQL,从而直接地在数据库中完成AUC的计算,而不需要将对应的数据表中的数据返回到客户机上,提升系统性能。


以上就是本文的全部内容,如果您喜欢这篇文章,欢迎将它分享给朋友们。

感谢您的阅读,祝您生活愉快!

作者:小美哥
2018-12-02

猜你喜欢

转载自blog.csdn.net/MiMicoa/article/details/84721116
auc