- Vector转array
- 单元素数值向量转float（取出其唯一的分量）
假数据构造
# Fake data: a single user (suuid) with four ad records, each carrying a 0/1 label y.
_fake_suuid = 'DoNewsa5a4de34-ec'
tmpllist = [
    {'suuid': _fake_suuid, 'ad': ad, 'y': label}
    for ad, label in [('110', 1), ('10', 1), ('03', 0), ('110', 1)]
]
tmpdf = ss.createDataFrame(tmpllist)
tmpdf.show()

# Group per user, collecting the ad ids and the labels into parallel lists, then
# hand each grouped row to row_suuidDense (defined elsewhere) to build the
# (suuid, userDense) pair. Judging from the captured output below, it appears to
# keep the distinct ads whose y == 1 — confirm against its definition.
_tmp_grouped = tmpdf.groupBy('suuid').agg(
    fn.collect_list('ad').alias('advertisement'),
    fn.collect_list('y').alias('y'),
)
tmpdfsuuid = _tmp_grouped.rdd.map(row_suuidDense).toDF(schema=['suuid', 'userDense'])
tmpdfsuuid.show(truncate=False)
+---+-----------------+---+
| ad| suuid| y|
+---+-----------------+---+
|110|DoNewsa5a4de34-ec| 1|
| 10|DoNewsa5a4de34-ec| 1|
| 03|DoNewsa5a4de34-ec| 0|
|110|DoNewsa5a4de34-ec| 1|
+---+-----------------+---+
+-----------------+---------+
|suuid |userDense|
+-----------------+---------+
|DoNewsa5a4de34-ec|[110, 10]|
+-----------------+---------+
# Word2Vec over each user's ad sequence ("userDense"). Each ad id gets a
# 1-dimensional embedding. minCount=0 keeps every token in the vocabulary, even
# ones seen only once. The fixed seed makes the run repeatable.
word2Vec_user = Word2Vec(
    vectorSize=1,
    minCount=0,
    seed=42,
    inputCol="userDense",
    outputCol="userDense_embedding",
)
_w2v_user_model = word2Vec_user.fit(tmpdfsuuid)

# getVectors() returns the learned vocabulary as a (word, vector) DataFrame.
tmpdfsuuid_user = _w2v_user_model.getVectors()
print(tmpdfsuuid_user.count())
tmpdfsuuid_user.show(truncate=False)
+----+----------------------+
|word|vector |
+----+----------------------+
|110 |[-0.02330401912331581]|
|10 |[0.3935629725456238] |
+----+----------------------+
# Inspect the vocabulary table returned by getVectors(): a string `word` column
# and an ML `vector` column. The schema captured from a run is kept below.
tmpdfsuuid_user.printSchema()
'''
root
|-- word: string (nullable = true)
|-- vector: vector (nullable = true)
'''
vector转array
import pyspark.sql.functions as F
import pyspark.sql.types as T

# Convert an ML Vector column into a plain array column.
# pyspark.ml.linalg vectors store float64 components, so the element type is
# DoubleType; ArrayType(FloatType()) would silently truncate every component to
# float32. A null vector yields null instead of raising AttributeError inside
# the Python worker.
# NOTE(review): on Spark >= 3.0, pyspark.ml.functions.vector_to_array performs
# this conversion natively (no Python UDF round-trip) — consider switching if
# the cluster version allows.
to_array = F.udf(
    lambda v: v.toArray().tolist() if v is not None else None,
    T.ArrayType(T.DoubleType()),
)
ttarray = tmpdfsuuid_user.withColumn('vector', to_array('vector'))
ttarray.printSchema()
# truncate=False so the full-precision values are not cut to 20 characters.
ttarray.show(truncate=False)
'''
root
 |-- word: string (nullable = true)
 |-- vector: array (nullable = true)
 |    |-- element: double (containsNull = true)
+----+----------------------+
|word|vector                |
+----+----------------------+
|110 |[-0.02330401912331581]|
|10  |[0.3935629725456238]  |
+----+----------------------+
'''
def _first_as_float(values):
    """Return the first element of *values* as a Python float, or None.

    None is returned when the array is null/empty or its first element is null,
    instead of raising IndexError/TypeError inside the Python worker.
    """
    if not values or values[0] is None:
        return None
    return float(values[0])


# Unwrap a single-element array into a scalar numeric column.
# F.udf defaults its return type to StringType, which is why this column used to
# come out as a string; DoubleType (Python's float) keeps it numeric.
# NOTE(review): ttarray.withColumn('vector', F.col('vector')[0]) would do the
# same without a Python UDF.
to_float = F.udf(_first_as_float, T.DoubleType())
ttfloat = ttarray.withColumn('vector', to_float('vector'))
ttfloat.printSchema()
ttfloat.show()
'''
root
 |-- word: string (nullable = true)
 |-- vector: double (nullable = true)
+----+--------------------+
|word|              vector|
+----+--------------------+
| 110|-0.02330401912331581|
|  10|  0.3935629725456238|
+----+--------------------+
'''