spark2.0 AFTSurvivalRegression算法

spark2.0的机器学习算法比之前的改变最大的是2.0基本采用了dataframe来实现的,但事前的都是用的RDD,看官网说貌似在3.0的时候RDD就不用了,不知道真的假的。

还有一个就是hiveContext和sqlcontext进行了合并,统一是sessioncontext

val spark = SparkSession
  .builder
  .appName("AFTSurvivalRegressionExample").master("local")
  .getOrCreate()

AFTSurvivalRegression
实现了加速失效时间(AFT)模型,这是一个用于检查数据的参数生存回归模型。 它描述了生存时间对数的模型,因此它通常被称为生存分析的对数线性模型
val training = spark.createDataFrame(Seq(
  (1.218, 1.0, Vectors.dense(1.560, -0.605)),
  (2.949, 0.0, Vectors.dense(0.346, 2.158)),
  (3.627, 0.0, Vectors.dense(1.380, 0.231)),
  (0.273, 1.0, Vectors.dense(0.520, 1.151)),
  (4.199, 0.0, Vectors.dense(0.795, -0.226))
)).toDF("label", "censor", "features")
第一个label表示的是存活的时间,你可以把这个模型看做是预测你能活多长时间的,当然是需要很多方面的参数的
不然就是在扯淡了,虽然这预测听起来很扯淡。。。。。。
第二个censor是结局,1表示死亡,0表示删失数据,病历失访或者尚存活
表现在病人身上就是,你这个人得了一个癌症,根据你的各项指标,用这个模型预测你能活的时间
听起来就很残酷,1表示这个人已经去世,0可能是还活着或者其他因素而没获取到数据
后面的几个参数就是各种病症或者身体情况的症状了,最终都要转化为数据的形式,俗称归一化

分位数概率数组参数。
分位数概率数组的值应在范围内(0,1)
数组应该是非空的。
val quantileProbabilities = Array(0.3, 0.6)
val aft = new AFTSurvivalRegression()
  .setQuantileProbabilities(quantileProbabilities)
如果设置该列,则会输出相应的分位数概率的分位数

 .setQuantilesCol("quantiles")
val model = aft.fit(training)

输出模型的系数

println(s"Coefficients: ${model.coefficients}")
模型的截距 
println(s"Intercept: ${model.intercept}")

源码里面是这个 val scale = math.exp(parameters(0))
 println(s"Scale: ${model.scale}")
Coefficients: [-0.4963111466650707,0.19844437699933098]
Intercept: 2.63809461510401
Scale: 1.5472345574364692
model.transform(training).show(false)

+-----+------+--------------+------------------+--------------------------------------+
|label|censor|features      |prediction        |quantiles                             |
+-----+------+--------------+------------------+--------------------------------------+
|1.218|1.0   |[1.56,-0.605] |5.718979487635007 |[1.1603238947151664,4.99545601027477] |
|2.949|0.0   |[0.346,2.158] |18.07652118149533 |[3.667545845471739,15.789611866277625]|
|3.627|0.0   |[1.38,0.231]  |7.381861804239096 |[1.4977061305190829,6.44796261233896] |
|0.273|1.0   |[0.52,1.151]  |13.577612501425284|[2.7547621481506854,11.8598722240697] |
|4.199|0.0   |[0.795,-0.226]|9.013097744073898 |[1.8286676321297826,7.87282650587843] |
+-----+------+--------------+------------------+--------------------------------------+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
还可以通过类似sql的方式来选择展示结果
model.transform(training).selectExpr(
   "label" , "censor" ,
   "round(prediction,2) as prediction" ).orderBy( "label" )









猜你喜欢

转载自blog.csdn.net/qq_36421826/article/details/72860687