机器学习算法笔记(三十七):随机森林与Extra-Trees

前面的笔记我们讨论了 Bagging 这种方法,即用随机取样、在特征空间中随机取特征,创建诸多子模型并将它们结合在一起。之前我们在实现 Bagging 时运用到基础的分类器是决策树,于是我们就集成了成百上千个决策树,对于这样的模型,就有一个更加形象的名字——随机森林。

一、sklearn中的随机森林

只要集成学习的底层算法是决策树算法,最终得到的模型都可以称为随机森林。sklearn 专门为我们封装了一个随机森林的类,我们通过这个类可以非常方便的创建随机森林模型。与此同时,sklearn 提供的随机森林还有着更多的随机性。具体来说,所有子模型在节点划分时,都是在随机的特征子集上寻找最优的划分特征,也就是在迭代寻找划分维度及其阈值时,不是对全部特征进行搜索,而是对部分特征进行搜索,这样也就极大增加了每一个子模型的随机性与差异性。具体 sklearn 中实现的代码如下:

import numpy as np

from sklearn import datasets

X, y = datasets.make_moons(n_samples=500, noise=0.3, random_state=666)

from sklearn.ensemble import RandomForestClassifier

rf_clf = RandomForestClassifier(n_estimators=500, oob_score=True, random_state=666, n_jobs=-1)
rf_clf.fit(X, y)
print(rf_clf.oob_score_) #prints: 0.892

如果我们要进一步对随机森林进行调参,可以考虑打印第十行的 fit,就能得到如下所示的参数:

随机森林分类器所拥有的参数,可以看到集合了 Bagging 和决策树的参数。

进行一些简单的参数调整:

rf_clf2 = RandomForestClassifier(n_estimators=500, max_leaf_nodes=16, oob_score=True, random_state=666, n_jobs=-1)
rf_clf2.fit(X, y)
print(rf_clf2.oob_score_) #prints: 0.906

可以看到准确率有些许提高。

二、Extra-Trees

和随机森林非常类似的还有另一种方法,其实也可以称它为随机森林,不过有时我们称它为 Extra-Trees (极其随机的森林)。这里的“及其随机”表现在决策树的结点划分上,它干脆直接使用随机的特征和随机的阈值划分,这样我们每一棵决策树形状、差异就会更大、更随机。

也就是说,节点划分时,选择的特征及对应的特征值不是搜索比较所得,而是随机抽取一个特征,再从该特征中随机抽取一个特征值,作为该节点划分的依据。

虽然子模型如此随机,但只要子模型的准确率大于 50%,并且集成的子模型的数量足够多,最终整个集成系统的准确率就能达到要求。这样做的优点是提供额外的随机性,抑制过拟合,并且具有更快的训练速度;缺点也很明显:增大了偏差(bias)。

Extra-Trees 的使用方法也与随机森林类似,代码如下:

from sklearn.ensemble import ExtraTreesClassifier

et_clf = ExtraTreesClassifier(n_estimators=500, bootstrap=True, oob_score=True, random_state=666, n_jobs=-1)
et_clf.fit(X, y)
print(et_clf.oob_score_) #prints: 0.892

三、运用集成学习解决回归问题

大多数集成学习算法都可以解决回归问题。在 sklearn 中,解决分类问题时我们导入的包为:

from sklearn.ensemble import BaggingClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import ExtraTreesClassifier

而在解决回归问题时,导入的包为:

from sklearn.ensemble import BaggingRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import ExtraTreesRegressor

在调用时,他们的接口大致相同,这里就不再赘述了。

发表评论

电子邮件地址不会被公开。 必填项已用*标注