数据治理中 PyODPS 的正确使用方式

数据治理中 PyODPS 的正确使用方式

概述:表饱和度(字段是否为空)、字段阈值(数值类字段取值是否超出有效边界)是评估数据质量的关键指标,由于是单表内字段级别的校验和统计,并且几乎涉及所有表,范围大、逻辑简单、重复性强,结合 Python 开发效率高的特点,很多数据工程师会使用 PyODPS 进行相关功能的开发。本文基于 PyODPS 分别使用 3 种方式实现了“饱和度统计”功能,展示了它们的执行效率,并分析了原因。

结论:1. 除非数据量极少,否则要避免把数据拉取到本地处理;2. 执行 SQL 的方式效率最高,并且直观,如果只是饱和度的场景,推荐这种方式,但是受 SQL 语法的限制,不够灵活;3. DataFrame SDK 的方式效率虽然略低,但使用方式非常灵活,并且可以把常用处理逻辑封装成函数,代码复用率更高;

测试环境: Mac Book Pro | 4C/16G/512G

1. 通过 open_reader 和分区级别并发实现

测试表数据量 分区数 运行时间
332万 5 22分40秒

分析: 通过 open_reader 把表内数据拉取到本地进行检验和统计,虽然代码中使用了多线程,但是并没有“真正”的并发:1、没有利用 ODPS 的计算能力,而是使用了本地的计算能力;2、Python 的 GIL(全局解释器锁) 使线程之间在串行执行。如果数据量极少,这种方式的优势是节省了创建 ODPS 实例的时间和资源开销。

2. 通过 execute_sql 全表扫实现

测试表数据量 分区数 运行时间
6万 77 21秒
1600万 23 18秒
3.1亿 14 50秒

分析:与在 DataWorks 上面执行 SQL 情况相同,只要能把 SQL 拼出来,就能实现想要的功能。但是,如果所在 Project 限制了全表扫,则需要 set odps.sql.allow.fullscan=true; 操作。缺点是,检验逻辑在 SQL 中,是靠拼字符串拼出来的,代码很难复用。

3. 通过 DataFrame 实现

测试表数据量 分区数 运行时间
6万 77 53秒
1600万 23 39秒
3.1亿 14 5分钟56秒

分析:使用了 DataFrame 的 MapReduce API 与聚合函数 sum,DataFrame 会提交 1 个实例,把 DateFrame 操作转换为 UDF SQL,通过 Mapper 把原表数据转换为 int 类型的计数,然后进行 sum 操作。DataFrame 比 SQL 执行效率低的原因,在于自定义 Python 函数与内置函数的性能差异,例如,作者曾使用自己编写的 Reducer 代替 sum ,效率会大大降低。

代码

#!/usr/bin/python
# -*- coding: UTF-8 -*-
import sys

reload(sys)
sys.setdefaultencoding("utf-8")

from odps import ODPS
import datetime
from odps import options
from datetime import datetime, timedelta
from time import sleep
from odps.df import DataFrame
from odps.df import output
from queue import Queue
import threading

o = ODPS('xxx', 'xxx', 'xxx', endpoint='xxx')

options.interactive = True  # 用到 print 需要打开
options.verbose = True      # 输出进度信息和错误信息


def mapper(row):
    ret = [1]
    for i in range(len(row)):
        ret.append(1 if row[i] is None or str(row[i]).isspace() or len(str(row[i])) <= 0 else 0)
    yield tuple(ret)

def data_frame(table_name):
    findnull = DataFrame(o.get_table(table_name))
    col_num = len(findnull.dtypes)
    output_types = ['int' for i in range(col_num + 1)]
    output_names = [findnull.dtypes.columns[i].name for i in range(col_num)]
    output_names.insert(0, 'total_cnt')
    table = findnull.map_reduce(mapper, mapper_output_names=output_names, mapper_output_types=output_types)
    print(table.sum())

def check_data_by_execute_sql(table_name):
    ta = o.get_table(table_name)
    data_count = {
    
    }
    table_count = 0
    sql_str = 'select \n'
    for col in ta.schema.columns:
        col_name = col.name
        sql_str += "sum(case when (%s is null) or (%s in ('','null','NULL')) or (trim(%s) = '') then 1 else 0 end) as %s_yx,\n" % (col_name, col_name, col_name, col_name)
    sql_str += "count(1) as total_cnt \nfrom %s " %(table_name)
    print(sql_str)
    with o.execute_sql(sql_str).open_reader() as rec:
        for r in rec:
            for col in ta.schema.columns:
                print("%s\t\t%d" % (col.name, r.get_by_name(col.name + '_yx')))
            print("%s\t\t%d" % ('total_cnt', r.get_by_name('total_cnt')))

def get_last_day():
    today = datetime.today()
    last_day = today + timedelta(days=-1)
    return last_day.strftime('%Y%m%d')

count_queue = Queue()
threads = []

def check_data_by_open_reader(table_name, pt):
    ta = o.get_table(table_name)
    data_count = {
    
    }
    print(table_name + "\t:\t" + str(pt) + " STARTED")
    rec = ta.open_reader(partition=str(pt))
    table_count = rec.count
    for r in rec:
        for col in ta.schema:
            col_value = r.get_by_name(col.name)
            if col.name not in data_count:
                data_count[col.name] = 0
            if col_value == None or str(col_value).isspace() or len(str(col_value)) <= 0:
                data_count[col.name] += 1
    count_queue.put((data_count, table_count))
    print(table_name + "\t:\t" + str(pt) + " DONE")

# 假设 dt 为分区字段
def check_data(table_name):
    table_tocheck = o.get_table(table_name)
    for pt in table_tocheck.iterate_partitions("dt='" + get_last_day() + "'"):  
        t = threading.Thread(target=check_data_by_open_reader, args=(table_name, pt))
        t.setDaemon(True)
        t.start()
        threads.append(t)

    print("线程数共:" + str(len(threads)))

    while True:
        thread_num = len(threading.enumerate()) - 1
        print("线程数量是%d" % thread_num)
        if thread_num <= 0:
            break
        sleep(10)

    total_cnt = 0
    total_data_cnt = {
    
    }
    while not count_queue.empty():
        pt_data = count_queue.get()
        data_count = pt_data[0]
        total_cnt += pt_data[1]
        for col_name in data_count.keys():
            if col_name not in total_data_cnt:
                total_data_cnt[col_name] = 0
            total_data_cnt[col_name] += data_count[col_name]

    print(total_cnt, total_data_cnt)

if __name__ == '__main__':
    table_name = 'xxxx' 

    if len(sys.argv) == 2:
        if sys.argv[1] not in ('1', '2', '3'):
            print("ARG ERROR: %s 1|2|3" % sys.argv[0])
            exit()
        print(datetime.now().strftime('%Y-%m-%d %H:%M:%S  BEGIN with ' + table_name))
        if sys.argv[1] == '1':
            check_data(table_name)
        elif sys.argv[1] == '2':
            check_data_by_execute_sql(table_name)
        elif sys.argv[1] == '3':
            data_frame(table_name)
        print(datetime.now().strftime('%Y-%m-%d %H:%M:%S  DONE with ' + table_name))
    else:
        print("ARG ERROR: %s 1|2|3" % sys.argv[0])

猜你喜欢

转载自blog.csdn.net/ManWZD/article/details/106882005