from pyspark.ml.feature import OneHotEncoder, StringIndexer, StringIndexerModel
# Small demo DataFrame of (id, phone brand) rows. Several brand strings are
# non-ASCII, and "iphone" / "小米" each appear twice.
df = ss.createDataFrame(
    data=[
        (2, "iphone"),
        (11, "小米"),
        (22, "huawei"),
        (33, "a锤子"),
        (66, "小米"),
        (50, "iphone"),
    ],
    schema=["id", "value"],
)
df.show()
# +---+------+
# | id| value|
# +---+------+
# |  2|iphone|
# | 11|  小米|
# | 22|huawei|
# | 33| a锤子|
# | 66|  小米|
# | 50|iphone|
# +---+------+
# Encode the string column "value" as numeric indices in a new column
# "label". handleInvalid="keep" is passed so that, per the Spark StringIndexer
# docs, strings not seen during fit get an extra index at transform time
# instead of raising an error.
stringIndexer = StringIndexer(
    inputCol="value",
    outputCol="label",
    handleInvalid="keep",
)
label_model = stringIndexer.fit(df)
# Rebind df to the transformed frame, which carries the extra "label" column.
df = label_model.transform(df)
df.show()
# +---+------+-----+
# | id| value|label|
# +---+------+-----+
# |  2|iphone|  0.0|
# | 11|  小米|  1.0|
# | 22|huawei|  3.0|
# | 33| a锤子|  2.0|
# | 66|  小米|  1.0|
# | 50|iphone|  0.0|
# +---+------+-----+
# Save the fitted label (StringIndexer) model to HDFS, then load it back.
# Persist the fitted model (overwrite() replaces any existing save at the
# path) and reload it as a StringIndexerModel.
# NOTE(review): 'XXX' looks like a placeholder for the real HDFS path.
label_model.write().overwrite().save('XXX')
labelmodel = StringIndexerModel.load('XXX')

# New data to score with the reloaded model. "iphone" was present in the
# data the model was fitted on; "鸿蒙" was not.
df1 = ss.createDataFrame([
    (12, "iphone"),
], ["id", "value"])
df2 = ss.createDataFrame([
    (22, "鸿蒙"),
], ["id", "value"])
df1.show()
df2.show()

# Transform only after df1/df2 exist; calling transform earlier raises
# NameError on a fresh run.
df1 = labelmodel.transform(df1)
# With handleInvalid="keep", the unseen "鸿蒙" is expected to receive the
# extra index (the number of fitted labels, 4.0 here) rather than raising.
df2 = labelmodel.transform(df2)
df1.show()
df2.show()