PySpark: label encoding

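from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer, StringIndexerModel

# `ss` is the SparkSession used throughout; the app name here is arbitrary
ss = SparkSession.builder.appName("label_encoding").getOrCreate()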

df = ss.createDataFrame([
    (2, "iphone"),
    (11, "小米"),
    (22, "huawei"),
    (33, "a锤子"),
    (66, "小米"),
    (50, "iphone")
], ["id", "value"])
df.show()
+---+------+
| id| value|
+---+------+
|  2|iphone|
| 11|  小米|
| 22|huawei|
| 33| a锤子|
| 66|  小米|
| 50|iphone|
+---+------+
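# handleInvalid="keep" maps values unseen during fit to an extra index instead of raising an error
stringIndexer = StringIndexer(inputCol="value", outputCol="label").setHandleInvalid("keep")
label_model = stringIndexer.fit(df)
df = label_model.transform(df)
df.show()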
+---+------+-----+
| id| value|label|
+---+------+-----+
|  2|iphone|  0.0|
| 11|  小米|  1.0|
| 22|huawei|  3.0|
| 33| a锤子|  2.0|
| 66|  小米|  1.0|
| 50|iphone|  0.0|
+---+------+-----+
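
The indices in the table above follow StringIndexer's default ordering: the most frequent value gets 0.0, and in Spark 3.0 and later values with equal frequency are ordered alphabetically. The sketch below mimics that rule in plain Python so the assignment can be checked by hand. It is not Spark's implementation, and `build_index` is a hypothetical helper name.

```python
from collections import Counter


def build_index(values):
    """Mimic StringIndexer's default frequencyDesc ordering (illustrative only)."""
    counts = Counter(values)
    # Most frequent value first; ties are broken by comparing the strings themselves.
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: float(position) for position, label in enumerate(ordered)}


values = ["iphone", "小米", "huawei", "a锤子", "小米", "iphone"]
index = build_index(values)
for label, idx in index.items():
    print(label, idx)
```

"iphone" and "小米" both appear twice, and "a锤子" and "huawei" once each, so the alphabetical tie-break is what places "a锤子" (2.0) ahead of "huawei" (3.0).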

Saving the label model to HDFS

# 'XXX' is a placeholder for the HDFS path
label_model.write().overwrite().save('XXX')
labelmodel = StringIndexerModel.load('XXX')
# df1 holds a value the model saw during fit
df1 = ss.createDataFrame([
    (12,"iphone"),
], ["id", "value"])
# df2 holds a value the model has never seen
df2 = ss.createDataFrame([
    (22, "鸿蒙"),
], ["id", "value"])
df1.show()
df2.show()

df1 = labelmodel.transform(df1)
df2 = labelmodel.transform(df2)
df1.show()
df2.show()
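
Because the indexer was configured with setHandleInvalid("keep"), transforming df2 should not fail even though "鸿蒙" was absent at fit time. Spark documents that "keep" places unseen values in one extra bucket whose index equals the number of fitted labels, which is 4 here. The following is a plain-Python sketch of that lookup rule, not Spark code. `encode` and `fitted_index` are hypothetical names, and the mapping is copied from the fitted output shown earlier.

```python
# Mapping learned at fit time, copied from the output table earlier in this post.
fitted_index = {"iphone": 0.0, "小米": 1.0, "a锤子": 2.0, "huawei": 3.0}


def encode(value, index, handle_invalid="keep"):
    """Look up a value the way the handleInvalid option is documented to behave."""
    if value in index:
        return index[value]
    if handle_invalid == "keep":
        # Unseen values share one extra bucket at index == number of known labels.
        return float(len(index))
    if handle_invalid == "skip":
        # Spark drops the row instead; None marks that here.
        return None
    raise ValueError(f"Unseen label: {value}")


print(encode("iphone", fitted_index))  # known value
print(encode("鸿蒙", fitted_index))    # unseen value, falls into the extra bucket
```

With the default handleInvalid="error", the same unseen value would raise an exception during transform, which is why "keep" is useful when a saved model is applied to new data.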

https://www.cnblogs.com/SoftwareBuilding/p/9492285.html

Origin blog.csdn.net/qq_42363032/article/details/120200380