pyspark @udf 循环使用变量问题

问题描述

通过@udf的方式新增两列,udf应用每一列的广播变量是不一样的。但是几轮循环中,udf中接受的全局变量都是第一轮循环。伪代码如下:

from pyspark.sql.functions import udf

tlowdata = ss.createDataFrame([{
    
    'suuid': 'DDD1', 'oaid': '00-01', 'y': 1},
                    {
    
    'suuid': 'DOOD', 'oaid': '00-02', 'y': 0}, 
                    {
    
    'suuid': '009-1234', 'oaid': 'default1', 'y': 0},
                    {
    
    'suuid': 'DDD1', 'oaid': 'ttt', 'y': 0},
                    {
    
    'suuid': 'www', 'oaid': 'fwao', 'y': 0},
                    {
    
    'suuid': 'www', 'oaid': 'fff1', 'y': 0},])
tlowdata.show()

@udf
def tmp_udf(uid):
    return str(tmplst.value)

cols = ['suuid', 'oaid']
lst = [[1, 2, 3], [4, 5, 6]]
global tmplst
for i in range(len(lst)):
    tmplst = lst[i]
    
    tmplst = sc.broadcast(tmplst)
    print(tmplst.value)
    
    tlowdata = tlowdata.withColumn(c[i] + '_flag', tmp_udf(fn.col(c[i])))
    
    tmplst.unpersist()
    
tlowdata.show()
+--------+--------+---+
|    oaid|   suuid|  y|
+--------+--------+---+
|   00-01|    DDD1|  1|
|   00-02|    DOOD|  0|
|default1|009-1234|  0|
|     ttt|    DDD1|  0|
|    fwao|     www|  0|
|    fff1|     www|  0|
+--------+--------+---+

[1, 2, 3]
[4, 5, 6]
+--------+--------+---+----------+---------+
|    oaid|   suuid|  y|suuid_flag|oaid_flag|
+--------+--------+---+----------+---------+
|   00-01|    DDD1|  1| [1, 2, 3]|[1, 2, 3]|
|   00-02|    DOOD|  0| [1, 2, 3]|[1, 2, 3]|
|default1|009-1234|  0| [1, 2, 3]|[1, 2, 3]|
|     ttt|    DDD1|  0| [1, 2, 3]|[1, 2, 3]|
|    fwao|     www|  0| [1, 2, 3]|[1, 2, 3]|
|    fff1|     www|  0| [1, 2, 3]|[1, 2, 3]|
+--------+--------+---+----------+---------+

理论上,在新增oaid_flag这一列的时候,应该是[4, 5, 6],但是无论循环几次,udf里接受的广播变量始终是[1, 2, 3]。我都怀疑是spark2.4版本的问题了…

解决方案

放弃@udf,使用lambda udf或者使用UserDefinedFunction

  • lambda udf
# lambda udf
tmp_udf = fn.udf(lambda x: str(tmplst.value))
tlowdata = tlowdata.withColumn(c[i] + '_flag', tmp_udf(fn.col(c[i])))
  • UserDefinedFunction
# UserDefinedFunction
from pyspark.sql.udf import UserDefinedFunction

def tmp_udf(uid):
    return str(tmplst.value)

tlowdata = tlowdata.withColumn(c[i] + '_flag', UserDefinedFunction(lambda x: tmp_udf(x))(fn.col(c[i])))
+--------+--------+---+
|    oaid|   suuid|  y|
+--------+--------+---+
|   00-01|    DDD1|  1|
|   00-02|    DOOD|  0|
|default1|009-1234|  0|
|     ttt|    DDD1|  0|
|    fwao|     www|  0|
|    fff1|     www|  0|
+--------+--------+---+

[1, 2, 3]
[4, 5, 6]
+--------+--------+---+----------+---------+
|    oaid|   suuid|  y|suuid_flag|oaid_flag|
+--------+--------+---+----------+---------+
|   00-01|    DDD1|  1| [1, 2, 3]|[4, 5, 6]|
|   00-02|    DOOD|  0| [1, 2, 3]|[4, 5, 6]|
|default1|009-1234|  0| [1, 2, 3]|[4, 5, 6]|
|     ttt|    DDD1|  0| [1, 2, 3]|[4, 5, 6]|
|    fwao|     www|  0| [1, 2, 3]|[4, 5, 6]|
|    fff1|     www|  0| [1, 2, 3]|[4, 5, 6]|
+--------+--------+---+----------+---------+

没想明白是什么原因,暂时是以上述方法解决的。想明白了再来更新。

猜你喜欢

转载自blog.csdn.net/qq_42363032/article/details/125403637