本站原创文章,转载请说明来自
《老饼讲解-机器学习》
www.bbbdata.com
决策树可以通过后剪枝一定程度上降低过拟合,
决策树的后剪枝一般采用的是CCP算法,
本文讲解什么是CCP后剪枝和什么是CCP路径及软件中的使用。
01. 什么是CPP后剪枝
本节介绍什么是CCP后剪枝
什么是CCP剪枝法
什么是后剪枝
后剪枝指的是在树完整构建完后,对一些过度生长的节点进行裁剪,以防止决策树过拟合
CCP后剪枝的思想
后剪枝最经典的算法是《CCP剪枝法》(Cost Complexity Pruning代价复杂度剪枝法)
它的主要思路是构造一个把树的代价(不纯度)和树的复杂度(叶子节点个数)都考虑到的损失函数,
然后根据这个损失函数引导决策树进行剪枝
CCP后剪枝
如下,CCP后剪枝法先构造一个综合树代价与树的复杂度的损失函数:
其中,
是一个待定系数, T 则代表叶子节点个数
然后求解,裁掉哪个节点,能让L最小,就裁哪个
02. CPP后剪枝的使用方法
本节讲解实际中如何使用CCP后剪枝对决策树进行剪枝
如何进行CCP后剪枝
CCP后剪枝的使用一般先打印CCP路径,然后根据CCP路径自主剪枝
1、打印CCP路径
CCP路径一般包含三个信息:
👉1. alpha 值
👉2. alpha 值对应剪掉的节点编号
👉3. 剪掉节点后树的Cost(代价,或者质量)
备注:代价的定义各个软件不一,sklearn用的是所有叶子的GINI值/熵值,matlab用的是判断错误的样本占比
2、根据CCP路径自主剪枝
在节点与代价之间自主权衡,选择要裁剪的节点,进行剪枝
剪枝的调用方式各个软件不一,
sklearn是通过直接设置alpha,重新训练决策树
而matlab则是根据节点编号,直接调用剪枝接口对节点裁剪
03. 软件中如何查看CCP路径
下面我们讲解matlab和python中如何查看CCP路径和剪枝方法
matlab使用CCP后剪枝
● 查看CCP路径
matlab 中可以通过tree.PruneAlpha 和tree.PruneList查看 CCP路径:
load fisheriris;
tree = fitctree(meas,species); %建模
disp(['tree.PruneAlpha:[',sprintf('%f, ', tree.PruneAlpha),']'])
disp(['tree.PruneList:[',sprintf('%d, ', tree.PruneList),']'])
👉运行结果如下
tree.PruneAlpha:[0.000000, 0.006667, 0.013333, 0.293333, 0.333333, ]
tree.PruneList:[4, 0, 3, 2, 0, 1, 0, 0, 0, ]
👉结果解读
它代表
时(i>0),被删掉的节点为: find(PruneList==i) ,
即PruneList( i+1)记录了第 i 轮的
,PruneList(i) 记录了节点 i 在 第几轮被删掉。
● 剪枝
(1) 可以通过 prune(tree,'Alpha',0.1) 指定alpha剪枝,
如上例,alpha设为0.1,则会剪2轮,4、6节点将被剪掉。
(2) 也可以通过指定节点剪枝 cptree = prune(tree,'Nodes',[4,6]);
如上例,使用prune(tree,'Nodes',[4,6])将会剪掉4、6节点
python sklearn使用CCP后剪枝
● 查看CCP路径
python中可以通过 clf.cost_complexity_pruning_path(X, y) 计算CCP路径:
# -*- coding: utf-8 -*-
from sklearn.datasets import load_iris
from sklearn import tree
#----------------数据准备----------------------------
iris = load_iris() # 加载数据
X = iris.data
y = iris.target
#---------------模型训练---------------------------------
clf = tree.DecisionTreeClassifier(min_samples_split=10,ccp_alpha=0)
clf = clf.fit(X, y)
pruning_path = clf.cost_complexity_pruning_path(X, y)
print("ccp_alphas:",pruning_path['ccp_alphas'])
print("impurities:",pruning_path['impurities'])
👉运行结果如下
ccp_alphas: [0. 0.00415459 0.01305556 0.02966049 0.25979603 0.33333333]
impurities: [0.02666667 0.03082126 0.04387681 0.07353731 0.33333333 0.66666667]
👉结果解读
alpha = ccp_alphas[i]时,树的不纯度(所有叶子节点的加权Gini(或熵))为 impurities[i]。
● 剪枝
sklearn中的CCP后剪枝,并不能直接指定节点,
它只能通过使用新的
重新训练模型的方式来获取剪枝模型。
例如,我们选定
,重新训练模型,如下:
clf = tree.DecisionTreeClassifier(min_samples_split=10,ccp_alpha=0.1)
End