【译】R包介绍:Online Random Forest

作者:顾全,浙江大学软件工程硕士,现任桃树科技算法工程师

地址:https://github.com/ZJUguquan/OnlineRandomForest

参与:Cynthia

翻译:本文为天善智能编译,未经容许,禁止转载

介绍

Online Random Forest(ORF) 是由Amir Saffari等人最先提出。之后,Arthur Lui使用Python实现了算法。非常感谢他们的工作。在论文内容和Lui的算法的基础上,我通过R和R6包重构了代码。此外,ORF在此包中的实现,与randomForest结合,使它同时支持增量学习和批量学习,例如:在ORF的基础上构建树,然后通过ORF进行更新。通过这种方法,它将比以前快得多。

安装

if(!require(devtools)) install.packages("devtools")

devtools::install_github("ZJUguquan/OnlineRandomForest")

快速启动

最小举例:增量学习

扫描二维码关注公众号,回复: 1034464 查看本文章

library(OnlineRandomForest)param<-list('minSamples'=1,'minGain'=0.1,'numClasses'=3,'x.rng'=dataRange(iris[1:4]))orf<-ORF$new(param,numTrees=10)for(iin1:150)orf$update(iris[i,1:4], as.integer(iris[i,5]))cat("Mean depth of trees in the forest is:",orf$meanTreeDepth(),"\n")orf$forest[[2]]$draw()

## Mean depth of trees in the forest is: 3

## Root X4 < 1.21

## |----L: X3 < 2.38

##      |----L: Leaf 1

##      |----R: Leaf 2

## |----R: X4 < 2.15

##      |----L: X1 < 4.92

##          |----L: Leaf 3

##          |----R: Leaf 3

##      |----R: Leaf 3

分类举例

library(OnlineRandomForest)#data preparationdat<-iris;dat[,5]<-as.integer(dat[,5])x.rng<-dataRange(dat[1:4])param<-list('minSamples'=2,'minGain'=0.2,'numClasses'=3,'x.rng'=x.rng)ind.gen<-sample(1:150,30)#for generate ORFind.updt<-sample(setdiff(1:150,ind.gen),100)#for uodate ORFind.test<-setdiff(setdiff(1:150,ind.gen),ind.updt)#for test#construct ORF and updaterf<-randomForest::randomForest(factor(Species)~.,data=dat[ind.gen, ],maxnodes=2,ntree=100)orf<-ORF$new(param)orf$generateForest(rf,df.train=dat[ind.gen, ],y.col="Species")cat("Mean size of trees in the forest is:",orf$meanTreeSize(),"\n")

## Mean size of trees in the forest is: 3

for(iinind.updt) {orf$update(dat[i,1:4],dat[i,5])}cat("After update, mean size of trees in the forest is:",orf$meanTreeSize(),"\n")

## After update, mean size of trees in the forest is: 11.9

#predictorf$confusionMatrix(dat[ind.test,1:4],dat[ind.test,5],pretty=T)

##

## 

##    Cell Contents

## |-------------------------|

## |                      N |

## |          N / Row Total |

## |          N / Col Total |

## |-------------------------|

##

## 

## Total Observations in Table:  20

##

## 

##              | actual

##  prediction |        1 |        2 |        3 | Row Total |

## -------------|-----------|-----------|-----------|-----------|

##            1 |        4 |        0 |        0 |        4 |

##              |    1.000 |    0.000 |    0.000 |    0.200 |

##              |    1.000 |    0.000 |    0.000 |          |

## -------------|-----------|-----------|-----------|-----------|

##            2 |        0 |        9 |        2 |        11 |

##              |    0.000 |    0.818 |    0.182 |    0.550 |

##              |    0.000 |    1.000 |    0.286 |          |

## -------------|-----------|-----------|-----------|-----------|

##            3 |        0 |        0 |        5 |        5 |

##              |    0.000 |    0.000 |    1.000 |    0.250 |

##              |    0.000 |    0.000 |    0.714 |          |

## -------------|-----------|-----------|-----------|-----------|

## Column Total |        4 |        9 |        7 |        20 |

##              |    0.200 |    0.450 |    0.350 |          |

## -------------|-----------|-----------|-----------|-----------|

##

##

#comparetable(predict(rf,newdata=dat[ind.test,])==dat[ind.test,5])

## FALSE  TRUE

##    9    11

table(orf$predicts(X=dat[ind.test,])==dat[ind.test,5])

## FALSE  TRUE

##    2    18

回归举例

#data preparationif(!require(ggplot2)) install.packages("ggplot2")data("diamonds",package="ggplot2")dat<-as.data.frame(diamonds[sample(1:53000,1000), c(1:6,8:10,7)])for(colinc("cut","color","clarity"))dat[[col]]<-as.integer(dat[[col]])#Don't forget thisx.rng<-dataRange(dat[1:9])param<-list('minSamples'=10,'minGain'=1,'maxDepth'=10,'x.rng'=x.rng)ind.gen<-sample(1:1000,800)ind.updt<-sample(setdiff(1:1000,ind.gen),100)ind.test<-setdiff(setdiff(1:1000,ind.gen),ind.updt)

#construct ORFrf<-randomForest::randomForest(price~.,data=dat[ind.gen, ],maxnodes=20,ntree=100)orf<-ORF$new(param)orf$generateForest(rf,df.train=dat[ind.gen, ],y.col="price")orf$meanTreeSize()

## [1] 39

#and updatefor(iinind.updt) {orf$update(dat[i,1:9],dat[i,10])}orf$meanTreeSize()

## [1] 105.7

#predict and compareif(!require(Metrics)) install.packages("Metrics")preds.rf<-predict(rf,newdata=dat[ind.test,])Metrics::rmse(preds.rf,dat$price[ind.test])

## [1] 988.8055

preds<-orf$predicts(dat[ind.test,1:9])Metrics::rmse(preds,dat$price[ind.test])#make progress

## [1] 869.9613

其他用途

在 Tree 类中

ta<-Tree$new("abc",NULL,NULL)tb<-Tree$new(1,Tree$new(36),Tree$new(3))tc<-Tree$new(89,tb,ta)tc$draw()

#update tctc$right$updateChildren(Tree$new("666"),Tree$new(999))tc$right$right$updateChildren(Tree$new("666"),Tree$new(999))tc$draw()

通过random Forest包配置一个Online random Tree,并升级

#data preparationlibrary(randomForest)dat1<-iris;dat1[,5]<-as.integer(dat1[,5])rf<-randomForest(factor(Species)~.,data=dat1,maxnodes=3)treemat1<-getTree(rf,1,labelVar=F)treemat1<-cbind(treemat1,node.ind=1:nrow(treemat1))x.rng1<-dataRange(dat1[1:4])param1<-list('minSamples'=5,'minGain'=0.1,'numClasses'=3,'x.rng'=x.rng1)ind.gen<-sample(1:150,50)#for generate ORTind.updt<-setdiff(1:150,ind.gen)#for update ORT#originort2<-ORT$new(param1)ort2$draw()

## Root 1

##  Leaf 1

#generate a treeort2$generateTree(treemat1,df.node=dat1[ind.gen,])ort2$draw()

## Root X3 < 2.45

## |----L: Leaf 1

## |----R: X3 < 4.75

##      |----L: Leaf 2

##      |----R: Leaf 3

#update this treefor(iinind.updt) {ort2$update(dat1[i,1:4],dat1[i,5])}ort2$draw()

## Root X3 < 2.45

## |----L: Leaf 1

## |----R: X3 < 4.75

##      |----L: Leaf 2

##      |----R: X4 < 2.19

##          |----L: X2 < 3.68

##                |----L: X1 < 7.12

##                    |----L: X3 < 4.06

##                          |----L: Leaf 1

##                          |----R: Leaf 3

##                    |----R: Leaf 3

##                |----R: Leaf 1

##          |----R: Leaf 3

猜你喜欢

转载自blog.csdn.net/r3ee9y2oefcu40/article/details/80423917
今日推荐