引言
在数据分析的过程中,我们经常需要对数据建模并做预测。在众多的选择中,randomForest, gbm 和 glmnet 是三个尤其流行的 R 包,它们在 Kaggle 的各大数据挖掘竞赛中的出现频率独占鳌头,被坊间人称为 R 数据挖掘包中的三驾马车。根据我的个人经验,gbm 包比同样是使用树模型的 randomForest 包占用的内存更少,同时训练速度较快,尤其受到大家的喜爱。在 python 的机器学习库 sklearn 里也有 GradientBoostingClassifier 的存在。
Boosting 分类器属于集成学习模型,它基本思想是把成百上千个分类准确率较低的树模型组合起来,成为一个准确率很高的模型。这个模型会不断地迭代,每次迭代就生成一颗新的树。对于如何在每一步生成合理的树,大家提出了很多的方法,我们这里简要介绍由 Friedman 提出的 Gradient Boosting Machine。它在生成每一棵树的时候采用梯度下降的思想,以之前生成的所有树为基础,向着最小化给定目标函数的方向多走一步。在合理的参数设置下,我们往往要生成一定数量的树才能达到令人满意的准确率。在数据集较大较复杂的时候,我们可能需要几千次迭代运算,如果生成一个树模型需要几秒钟,那么这么多迭代的运算耗时,应该能让你专心地想静静……
现在,我们希望能通过 xgboost 工具更好地解决这个问题。xgboost 的全称是 eXtreme Gradient Boosting。正如其名,它是 Gradient Boosting Machine 的一个 c++ 实现,作者为正在华盛顿大学研究机器学习的大牛陈天奇。他在研究中深感自己受制于现有库的计算速度和精度,因此在一年前开始着手搭建 xgboost 项目,并在去年夏天逐渐成型。xgboost 最大的特点在于,它能够自动利用 CPU 的多线程进行并行,同时在算法上加以改进提高了精度。它的处女秀是 Kaggle 的希格斯子信号识别竞赛,因为出众的效率与较高的预测准确度在比赛论坛中引起了参赛选手的广泛关注,在 1700 多支队伍的激烈竞争中占有一席之地。随着它在 Kaggle 社区知名度的提高,最近也有队伍借助 xgboost 在比赛中夺得第一。
为了方便大家使用,陈天奇将 xgboost 封装成了 python 库。我有幸和他合作,制作了 xgboost 工具的 R 语言接口,并将其提交到了 CRAN 上。也有用户将其封装成了 julia 库。python 和 R 接口的功能一直在不断更新,大家可以通过下文了解大致的功能,然后选择自己最熟悉的语言进行学习。
功能介绍
一、基础功能
首先,我们从 github 上安装这个包:
devtools::install_github('dmlc/xgboost',subdir='R-package')
动手时间到!第一步,运行下面的代码载入样例数据:
require(xgboost)
data(agaricus.train, package='xgboost')
data(agaricus.test, package='xgboost')
train <- agaricus.train
test <- agaricus.test
这份数据需要我们通过一些蘑菇的若干属性判断这个品种是否有毒。数据以 1 或 0 来标记某个属性存在与否,所以样例数据为稀疏矩阵类型:
class(train$data)
[1] "dgCMatrix"
attr(,"package")
[1] "Matrix"
不用担心,xgboost 支持稀疏矩阵作为输入。下面就是训练模型的命令
bst <- xgboost(data = train$data, label = train$label, max.depth = 2, eta = 1,nround = 2, objective = "binary:logistic")
[0] train-error:0.046522
[1] train-error:0.022263
我们迭代了两次,可以看到函数输出了每一次迭代模型的误差信息。这里的数据是稀疏矩阵,当然也支持普通的稠密矩阵。如果数据文件太大不希望读进 R 中,我们也可以通过设置参数data = 'path_to_file'
使其直接从硬盘读取数据并分析。目前支持直接从硬盘读取 libsvm 格式的文件。
做预测只需要一句话:
pred <- predict(bst, test$data)
做交叉验证的函数参数与训练函数基本一致,只需要在原有参数的基础上设置nfold
:
cv.res <- xgb.cv(data = train$data, label = train$label, max.depth = 2, eta = 1, nround = 2, objective = "binary:logistic",
nfold = 5)
[0] train-error:0.046522+0.001102 test-error:0.046523+0.004410
[1] train-error:0.022264+0.000864 test-error:0.022266+0.003450
cv.res
train.error.mean train.error.std test.error.mean test.error.std
1: 0.046522 0.001102 0.046523 0.004410
2: 0.022264 0.000864 0.022266 0.003450
交叉验证的函数会返回一个 data.table 类型的结果,方便我们监控训练集和测试集上的表现,从而确定最优的迭代步数。
二、高速准确
上面的几行代码只是一个入门,使用的样例数据没法表现出 xgboost 高效准确的能力。xgboost 通过如下的优化使得效率大幅提高:
- xgboost 借助 OpenMP,能自动利用单机 CPU 的多核进行并行计算。需要注意的是,Mac 上的 Clang 对 OpenMP 的支持较差,所以默认情况下只能单核运行。
- xgboost 自定义了一个数据矩阵类 DMatrix,会在训练开始时进行一遍预处理,从而提高之后每次迭代的效率。
在尽量保证所有参数都一致的情况下,我们使用希格斯子竞赛的数据做了对照实验。
Model and Parameter | gbm | xgboost(1 thread) | xgboost(2 threads) | xgboost(4 threads) | xgboost(8 threads) |
---|---|---|---|---|---|
Time (in secs) | 761.48 | 450.22 | 102.41 | 44.18 | 34.04 |
以上实验使用的 CPU 是 i7-4700MQ。python 的 sklearn 速度与 gbm 相仿。如果想要自己对这个结果进行测试,可以在比赛的官方网站下载数据,并参考这份 demo 中的代码。
除了明显的速度提升外,xgboost 在比赛中的效果也非常好。在这个竞赛初期,大家惊讶地发现 R 和 python 中的 gbm 竟然难以突破组织者预设的 benchmark。而 xgboost 横空出世,用不到一分钟的训练时间便打入当时的 top 10,引起了大家的兴趣与关注。准确度提升的主要原因在于,xgboost 的模型和传统的 GBDT 相比加入了对于模型复杂度的控制以及后期的剪枝处理,使得学习出来的模型更加不容易过拟合。更多算法上的细节可以参考这份陈天奇给出的介绍性讲义。
三、进阶特征
除了速度快精度高,xgboost 还有一些很有用的进阶特性。下面的 “demo” 链接对应着相应功能的简单样例代码。
- 只要能够求出目标函数的梯度和 Hessian 矩阵,用户就可以自定义训练模型时的目标函数。demo
- 允许用户在交叉验证时自定义误差衡量方法,例如回归中使用 RMSE 还是 RMSLE,分类中使用 AUC,分类错误率或是 F1-score。甚至是在希格斯子比赛中的 “奇葩” 衡量标准 AMS。demo
- 交叉验证时可以返回模型在每一折作为预测集时的预测结果,方便构建 ensemble 模型。demo
- 允许用户先迭代 1000 次,查看此时模型的预测效果,然后继续迭代 1000 次,最后模型等价于一次性迭代 2000 次。demo
- 可以知道每棵树将样本分类到哪片叶子上,facebook 介绍过如何利用这个信息提高模型的表现。demo
- 可以计算变量重要性并画出树状图。demo
- 可以选择使用线性模型替代树模型,从而得到带 L1+L2 惩罚的线性回归或者 logistic 回归。demo
这些丰富的功能来源于对日常使用场景的总结,数据挖掘比赛需求以及许多用户给出的精彩建议。
四、未来计划
现在,机器学习工具在实用中会不可避免地遇到 “单机性能不够” 的问题。目前,xgboost 的多机分布式版本正在开发当中。基础设施搭建完成之日,便是新一轮 R 包开始设计与升级之时。
结语
我为 xgboost 制作 R 接口的目的就是希望引进好的工具,让大家使用 R 的时候心情更愉悦。总结下来,xgboost 的特点有三个:速度快,效果好,功能多,希望它能受到大家的喜爱,成为一驾新的马车。
xgboost 功能较多,参数设置比较繁杂,希望在上手之后有更全面了解的读者可以参考项目 wiki。欢迎大家多多交流,在项目 issue 区提出疑问与建议。我们也邀请有兴趣的读者提交代码完善功能,让 xgboost 成为更好用的工具。
另外,在使用 github 开发的过程中,我深切地感受到了协作写代码带来的变化。一群人在一起的时候,可以写出更有效率的代码,在丰富的使用场景中发现新的需求,在极端情况发现隐藏很深的 bug,甚至在主代码手_拖延症_较为忙碌的时候有人挺身而出拿下一片 issue。这样的氛围,能让一个语言社区的交流丰富起来,从而充满生命力地活跃下去。
发表 / 查看评论