CART决策树剪枝算法
由于生成的决策树存在过拟合问题,需要对它进行剪枝,以简化学到的决策树。决策树的剪枝,往往从已生成的树上剪掉一些叶节点或叶节点以上的子树,并将其父节点或根节点作为新的叶节点,从而简化生成的决策树。
《统计学习方法》书上讲的关于CART决策树的剪枝算法有些不好懂,结合网上资料和自己的理解记录一下。
CART剪枝算法由两步组成:
- 首先从生成算法产生的决策树
底端开始不断剪枝,直到 的根节点,形成一个子树序列 ; - 然后通过交叉验证法在独立的验证数据集上对子树序列进行测试,从中选择最优子树。
1 剪枝,形成一个子树序列
1.1 损失函数
我们剪枝是跟据损失函数这一指标来进行的。在剪枝过程中计算子树的损失函数:
其中,
具体的,训练数据的预测误差的计算公式为:
其中,
因为CART树是二叉树,所以每一个叶节点
对于二分类问题(
1.2 嵌套子树序列
Breiman等人证明:可以用递归的方法对树进行剪枝。将
这个过程怎么解释呢?从整体树
因为此时
,所以等式后面加的就是
以
我们可以观察当
当
当
只要
对于这段话怎么理解呢,看一下下面这个图
![]()
在交点之前, 的损失函数要更小一些,也就是没剪枝要好一些;而在交点 之后, 的损失函数要更小一些,也就是该剪枝了。
(P.S.关于一开始为什么找遍全网没有一个地方解释清楚,最后在scikit的文档里面找到一个算是能解释的说法吧:
In general, the impurity of a node is greater than the sum of impurities of its terminal nodes.
为此,对
它表示剪枝后整体损失函数减少的程度。在
这一段怎么理解呢,看一下下面这个图
为什么要选择最小的g(t)呢?以图中两个点为例,结点1和结点2,g(t)2大于g(t)1, 假设在所有结点中g(t)1最小,g(t)2最大,两种选择方法:
当选择最大值g(t)2,即结点2进行剪枝,但此时结点1的不修剪的误差大于修剪之后的误差,即如果不修剪的话,误差变大,依次类推,对其它所有的结点的g(t)都是如此,从而造成整体的累计误差更大。
反之,如果选择最小值g(t)1,即结点1进行剪枝,则其余结点不剪的误差要小于剪后的误差,不修剪为好,且整体的误差最小。
从而以最小g(t)剪枝获得的子树是该值下的最优子树!这样一步一步出来的树也是嵌套的
完整的剪枝算法过程如图
2 在剪枝得到的子树序列 中通过交叉验证选取最优子树
具体地,利用独立的验证数据集,测试子树序列
完整的选取过程如图
参考
https://blog.csdn.net/zhengzhenxian/article/details/79083643
https://zhuanlan.zhihu.com/p/85731206
https://scikit-learn.org/stable/modules/tree.html#minimal-cost-complexity-pruning
https://scikit-learn.org/stable/auto_examples/tree/plot_cost_complexity_pruning.html#sphx-glr-auto-examples-tree-plot-cost-complexity-pruning-py
https://en.wikipedia.org/wiki/Decision_tree_pruning#Cost_complexity_pruning
http://mlwiki.org/index.php/Cost-Complexity_Pruning
https://online.stat.psu.edu/stat508/lesson/11/11.8/11.8.2