机器学习算法实践-决策树(Decision Tree)
》中对ID3以及C4.5算法进行了介绍并使用ID3算法处理了分类问题。本文主要使用决策树解决回归问题,使用CART(Classification And Regression Trees)算法。
https://github.com/PytLab/MLBox/tree/master/classification_and_regression_trees
首先是加载数据的部分,这里的所有测试数据我均使用的《Machine Learning in Action》中的数据,格式比较规整加载方式也比较一致, 这里由于做树回归,自变量和因变量都放在同一个二维数组中:
1 2 3 4 5 6 7 8 9
def load_data (filename ): ''' 加载文本文件中的数据. ''' dataset = [] with open(filename, 'r' ) as f: for line in f: line_data = [float(data) for data in line.split()] dataset.append(line_data) return dataset
树回归中再找到分割特征和分割值之后需要将数据进行划分以便构建子树或者叶子节点:
1 2 3 4 5 6 7 8 9 10
def split_dataset (dataset, feat_idx, value ): ''' 根据给定的特征编号和特征值对数据集进行分割 ''' ldata, rdata = [], [] for data in dataset: if data[feat_idx] < value: ldata.append(data) else : rdata.append(data) return ldata, rdata
然后就是重要的选取最佳分割特征和分割值了,这里我们通过找打使得分割后的方差最小的分割点最为最佳分割点:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
def choose_best_feature (dataset, fleaf, ferr, opt ): ''' 选取最佳分割特征和特征值 dataset: 待划分的数据集 fleaf: 创建叶子节点的函数 ferr: 计算数据误差的函数 opt: 回归树参数. err_tolerance: 最小误差下降值; n_tolerance: 数据切分最小样本数 ''' dataset = np.array(dataset) m, n = dataset.shape err_tolerance, n_tolerance = opt['err_tolerance' ], opt['n_tolerance' ] err = ferr(dataset) best_feat_idx, best_feat_val, best_err = 0 , 0 , float('inf' ) for feat_idx in range(n-1 ): values = dataset[:, feat_idx] for val in values: ldata, rdata = split_dataset(dataset.tolist(), feat_idx, val) if len(ldata) < n_tolerance or len(rdata) < n_tolerance: continue new_err = ferr(ldata) + ferr(rdata) if new_err < best_err: best_feat_idx = feat_idx best_feat_val = val best_err = new_err if abs(err - best_err) < err_tolerance: return None , fleaf(dataset) ldata, rdata = split_dataset(dataset.tolist(), best_feat_idx, best_feat_val) if len(ldata) < n_tolerance or len(rdata) < n_tolerance: return None , fleaf(dataset) return best_feat_idx, best_feat_val
其中,停止选取的条件有两个: 一个是当分割的子数据集的大小小于一定值;一个是当选取的最佳分割点分割的数据的方差减小量小于一定的值。
fleaf
是创建叶子节点的函数引用,不同的树结构此函数也是不同的,例如本部分的回归树,创建叶子节点就是根据分割后的数据集平均值,而对于模型树来说,此函数返回值是根据数据集得到的回归系数。
ferr
是计算数据集不纯度的函数,不同的树模型该函数也会不同,对于回归树,此函数计算数据集的方差来判定数据集的纯度,而对于模型树来说我们需要计算线性模型拟合程度也就是线性模型的残差平方和。
然后就是最主要的回归树的生成函数了,树结构肯定需要通过递归创建的,选不出新的分割点的时候就触底:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
def create_tree (dataset, fleaf, ferr, opt=None ): ''' 递归创建树结构 dataset: 待划分的数据集 fleaf: 创建叶子节点的函数 ferr: 计算数据误差的函数 opt: 回归树参数. err_tolerance: 最小误差下降值; n_tolerance: 数据切分最小样本数 ''' if opt is None : opt = {'err_tolerance' : 1 , 'n_tolerance' : 4 } feat_idx, value = choose_best_feature(dataset, fleaf, ferr, opt) if feat_idx is None : return value tree = {'feat_idx' : feat_idx, 'feat_val' : value} ldata, rdata = split_dataset(dataset, feat_idx, value) ltree = create_tree(ldata, fleaf, ferr, opt) rtree = create_tree(rdata, fleaf, ferr, opt) tree['left' ] = ltree tree['right' ] = rtree return tree
https://github.com/PytLab/MLBox/tree/master/classification_and_regression_trees
1 2 3 4
dataset = load_data('ex0.txt' ) dataset = np.array(dataset) plt.scatter(dataset[:, 0 ], dataset[:, 1 ])
1 2
tree = create_tree(dataset, fleaf, ferr, opt={'n_tolerance' : 4 , 'err_tolerance' : 1 })
https://github.com/PytLab/MLBox/blob/master/classification_and_regression_trees/regression_tree.py#L159
然后获取树结构图:
1 2 3 4 5
datafile = 'ex0.txt' dotfile = '{}.dot' .format(datafile.split('.' )[0 ]) with open(dotfile, 'w' ) as f: content = dotify(tree) f.write(content)
生成回归树图片:
1
dot -Tpng ex0.dot -o ex0_tree.png
其中节点上数字代表:
特征编号: 特征分割值
1 2 3 4 5
x = np.linspace(0 , 1 , 50 ) y = [tree_predict([i], tree) for i in x] plt.plot(x, y, c='r' ) plt.show()
数据文件左
&
数据文件右
):
左右两边的数据的分布基本相同但是使用相同的参数得到的回归树却完全不同左边的回归树只有两个分支,而右边的分支则有很多,甚至有时候会为所有的数据点得到一个分支,这样回归树将会非常的庞大, 如下是可视化得到的两个回归树:
如果一棵树的节点过多则表明该模型可能对数据进行了“过拟合”。那么我们需要降低决策树的复杂度来避免过拟合,此过程就是
剪枝
。剪枝技术又分为
预剪枝
和
后剪枝
。
https://github.com/PytLab/MLBox/blob/master/classification_and_regression_trees/compare.py
)
相关系数计算:
1 2 3 4
def get_corrcoef (X, Y ): cov = np.mean(X*Y) - np.mean(X)*np.mean(Y) return cov/(np.var(X)*np.var(Y))**0.5
获得的相关系数:
1 2
linear regression correlation coefficient: 0.9434684235674773 regression tree correlation coefficient: 0.9780307932704089
绘制线性回归和树回归的回归曲线(黄色会树回归曲线,红色会线性回归):
可见树回归方法在预测复杂数据的时候会比简单的线性模型更有效。
CART分类与回归树的原理与实现