机器学习-入门教程

【代码】CART决策树代码(自实现)

作者 : 老饼 发表日期 : 2024-03-12 06:24:26 更新日期 : 2025-03-30 13:50:56
本站原创文章,转载请说明来自《老饼讲解-机器学习》www.bbbdata.com



本文展示如何使用python自行实现CART决策树(非调包)和CCP后剪枝算法

算法逻辑来自matlab的fitctree函数,亲测结果与fitctree一致

通过本代码的理解,可以更进一步理解CART决策树的算法逻辑和详细实现细节




  01. CART决策树代码-简介  




本节介绍本文的CART决策树代码实现的内容




   CART决策树-代码功能与算法来源   


本代码实现CART决策树的构建,预测和CCP剪枝功能。
 
代码的算法流程来自matlab自带的决策树包,构建的结果与matlab的fitctree一致
最后,代码中附加一个使用Demo,展示代码各个函数功能的使用
CART决策树的相关原理可见:
CART决策树模型简介与实例
CART决策树算法流程          
笔者语
由于CART决策树的实现代码非常复杂,并且本文实现的功能较为全面,代码较多
所以笔者并不建议初学者真的去阅读和理解本代码
之所以展示此代码,仅仅是为了告诉大家,如果有必要,是可以具体实现的
另一方面,在学习CART决策树时,如果有细节不理解,也可以借助代码辅助理解





   CART决策树代码-运行后执行的内容   


代码运行后,就是运行Demo函数,
Demo函数主要是用sklearn自带的iris数据构训练一棵决策树,并进行剪枝,预测等操作
具体实现内容如下:
1、数据生成:使用sklearn自带的iris数据                                          
2、用iris数据构建一棵完全生长的决策树                                           
3、用构建好的决策树对样本进行预测,并打印出预测错误的样本编号 
4、将树去掉无效子节点                                                                    
5、计算各个节点的临界alpha                                                           
6、计算并打印CCP路径                                                                    
7、设定alpha=0.1,进行剪枝                                                             
8、根据节点编号进行剪枝                                                                







   02. CART决策树代码   





本节展示用python自实现一棵CART决策树的代码





    自实现CART决策树的代码    


用python自实现CART决策树(分类树)的代码如下:
# -*- coding: utf-8 -*-
"""
决策树CART(分类树)自实现
PASS:算法来源来matlab决策树工具包,笔者亲测,与软件包结果一致。
本代码来自老饼讲解-机器学习:www.bbbdata.com
"""
from sklearn.datasets import load_iris
import numpy as np
from copy import copy

# 节点类:决策树即为一串的树节点连接
class Node(object):
    def __init__(self,cMat,sample_idx):
        sample_class    = cMat[sample_idx].sum(axis=0)                 # 统计各类别样本个数
        class_num       = (sample_class>0).sum()                       # 类别个数
        node_class      = sample_class.argmax()                        # 节点类别
        sample_num      = len(sample_idx)                              # 样本个数
        err_num         = sample_num-sample_class[node_class]          # 错误样本个数
																	   
        self.id         = None                                         # 节点ID,删除节点时可通过该值删除
        self.is_leaf    = 0                                            # 是否是叶子节点
        self.left_node  = None                                         # 左节点(也是Node类)
        self.right_node = None                                         # 右节点(也是Node类)
        self.cut_var    = None                                         # 分割变量
        self.cut_val    = None                                         # 分割值
																	   
																	   
        self.sample_idx   = sample_idx                                 # 属于该节点的样本索引
        self.sample_num   = sample_num                                 # 属于该节点的样本个数
        self.sample_class = sample_class                               # 属于该节点的样本属于各类的个数,如[30,10]代表0类30个,1类10个
        self.class_num    = class_num                                  # 属于该节点的样本类别个数
        self.node_class   = node_class                                 # 该节点被判为属于哪一类别 
																	   
																	   
        self.err_num       = err_num                                   # 节点上判断错误的样本个数
        self.leaf_errnum   = 0                                         # 该节点下的叶子节点的总错误样本个数,用于计算alpha
        self.leaf_nodenum  = 0                                         # 该节点下的叶子节点个数,用于计算alpha
        self.alpha         = None                                      # 剪枝系数alpha
															           
    # 将节点改为叶子节点                                               
    def be_leaf(self):                                                 
        self.left_node  = None                                         # 将左节点置空
        self.right_node = None                                         # 将右节点置空
        self.cut_var    = None                                         # 将切割变量置空
        self.cut_val    = None                                         # 将切割阈值置空
        self.is_leaf    = 1                                            # 标记为叶子节点
															           
    # 将节点深拷贝(采用递归拷贝)                                       
    def copy(self):                                                    
        if(self.is_leaf==0):                                           # 如果不是叶子节点
            left_node             = self.left_node.copy()              # 拷贝左节点
            right_node            = self.right_node.copy()             # 拷贝右节点
            new_node              = copy(self)                         # 拷贝自身 
            new_node.sample_idx   = self.sample_idx.copy()             # 拷贝节点样本索引
            new_node.sample_class = self.sample_class.copy()           # 拷贝节点样本类别
            new_node.left_node    = left_node                          # 设置左节点
            new_node.right_node   = right_node                         # 设置右节点
        if(self.is_leaf==1):                                           # 如果是叶子节点
          new_node=copy(self)                                          # 拷贝自身
          new_node.sample_idx     = self.sample_idx.copy()             # 拷贝节点样本索引
          new_node.sample_class   = self.sample_class.copy()           # 拷贝节点样本类别
        return new_node                                                                      
																	                         
# 计算使用变量x,切值为cut_val时的收益函数(gini)                                              
def cal_gain(x,cMat,cut_val):                                                                
    is_left    = x<=cut_val                                                                  # 样本是否是左节点
    left_rate  = is_left.sum()/len(is_left)                                                  # 左节点样本个数占比
    right_rate = 1- left_rate                                                                # 右节点样本个数占比
    left_cMat  = cMat[is_left]                                                               # 左节点类别
    right_cMat = cMat[~is_left]                                                              # 右节点类别
    p_left     = left_cMat.sum(axis=0)/left_cMat.sum()                                       # 左节点各类别占比
    if (right_cMat.sum()>0):                                                                 # 如果右节点有样本
        p_right = right_cMat.sum(axis=0)/right_cMat.sum()                                    # 计算右节点的类别占比
    else:                                                                                    # 否则
        p_right = right_cMat.sum(axis=0)                                                     # 右节点的类别占比全为0
																                             
    g_left  = 1- (p_left**2).sum()                                                           # 左节点基尼系数
    g_right = 1- (p_right**2).sum()                                                          # 右节点基尼系数
    gain    = -(left_rate)*g_left - (right_rate)*g_right                                     # 收益值:-左右基尼系数加权和
    return gain                                                                              # 返回收益值

# 找出x变量的最佳切割点
def find_best_cut(x,cMat):
     unique_x     = np.unique(x)                                                             # 先对x去重
     best_cut_val = (unique_x[0]+unique_x[1])/2 if len(unique_x)>1 else unique_x[0]          # 初始化最佳切割点
     best_g       = cal_gain(x,cMat,best_cut_val)                                            # 初始化最佳收益
     for i in range(1,len(unique_x)-1):                                                      # 逐个切割点循环
         cut_val = (unique_x[i]+unique_x[i+1])/2                                             # 当前切割点
         g       = cal_gain(x,cMat,cut_val)                                                  # 当前切割点的收益
         if g>best_g:                                                                        # 如果当前收益大于最佳收益
             best_cut_val = cut_val                                                          # 将当前切割点作为最佳切割点
             best_g       = g                                                                # 将当前收益作为最佳收益
     return best_g,best_cut_val                                                              # 返回最佳收益和最佳切割点
																					         
# 将类别标签转为one-hot矩阵                                                                  
def class2Cmat(y):                                                                           
    c_name = np.unique(y)                                                                    # 将y去重
    c_num  = len(c_name)                                                                     # 类别个数
    cMat   = np.zeros([len(y),c_num])                                                        # 初始化one-hot矩阵
    for i in range(c_num):                                                                   # 逐个类别循环
        cMat[y==c_name[i],i]=1                                                               # 设置当前类别的样本的one-hot
    return cMat,c_name                                                                       # 返回one-hot矩阵与类别名称
																					         
# 删除无用叶子: 叶子不能降低判别误差,则删                                                  
def prune_bad(node):                                                                         
    un_use_node = [node]                                                                     # 将根节点入栈
    delete_bad  = 0                                                                          # 初始化是否删除过节点
    while(len(un_use_node)>0):                                                               # 如果栈里还有节点
        cur_node = un_use_node.pop()                                                         # 弹出栈里的节点
        if(cur_node.is_leaf==1):                                                             # 如果是叶子节点
            pass                                                                             # 不必判断
        elif((cur_node.left_node.is_leaf==1) &(cur_node.right_node.is_leaf==1)):             # 如果左右都是叶子,则判断是否删除
           if((cur_node.left_node.err_num+ cur_node.right_node.err_num) >=cur_node.err_num): # 叶子节点不能降低判别误差,则删
               cur_node.be_leaf()                                                            # 将节点置为叶子节点
               delete_bad = 1                                                                # 标记:已删过节点
        else:                                                                                # 如果不是叶子,也不是倒算第二层节点
           un_use_node.append(cur_node.left_node)                                            # 把左叶子入栈
           un_use_node.append(cur_node.right_node)                                           # 把右叶子入栈
    if(delete_bad==1):                                                                       # 如果删过节点
       prune_bad(node)                                                                       # 则将树重新裁剪

# 给节点及其子孙设置id序号       
def set_node_id(node,next_id):
     node.id = next_id                                                                       # 设置当前节点的ID
     if(node.is_leaf==0):                                                                    # 如果节点不是叶子
         next_id = set_node_id(node.left_node,next_id+1)                                     # 给左节点设置ID
         set_node_id(node.right_node,next_id)                                                # 给右节点设置ID
     return next_id+1                                                                        # 返回最新节点编号

# 计算各个节点的临界alpha,并返回最小临界alpha
def cal_alpha(node,total_num=None):
    total_num = node.sample_num if(total_num is None) else total_num                         # 初始化总样本数
	# 如果不是叶子节点,获取左右节点下的叶子节点的错误样本个数,并计算临界alpha
    if(node.is_leaf==0):                                                                     # 如果不是叶子节点,
       left_err,left_nodenum,left_min_alpha    = cal_alpha(node.left_node,total_num)         # 获取左节点下的叶子节点的错误样本个数
       right_err,right_nodenum,right_min_alpha = cal_alpha(node.right_node,total_num)        # 获取右节点下的叶子节点的错误样本个数
       node.leaf_errnum  = left_err + right_err                                              # 计算当前节点下的叶子节点的错误样本个数
       node.leaf_nodenum = left_nodenum + right_nodenum                                      # 计算当前了点下的叶子节点个数
       node.alpha = max(((node.err_num-node.leaf_errnum)/total_num)/(node.leaf_nodenum-1),0) # 计算临界alpha
       min_alpha  = min(left_min_alpha,right_min_alpha,node.alpha)                           # 记录最小临界alpha
	# 如果是叶子节点,直接获取叶子节点错误个数
    else:                                                                                    # 如果是叶子节点
        node.leaf_errnum  = node.err_num                                                     # 当前节点错误样本个数就是叶子节点错误样本个数
        node.leaf_nodenum = 1                                                                # 叶子节点个数
        min_alpha = float('inf')                                                             # 设置最小临界alpha为无穷大
    return node.leaf_errnum,node.leaf_nodenum,min_alpha                                      # 返回错误样本数、叶子数、最小临界alpha

# 对临界alpha<=alpha的节点剪枝
def prune_alpha(node,alpha):
    prune_list =[]                                                                           # 初始化剪枝节点列表
    if(node.is_leaf==1):                                                                     # 如果当前节点是叶子
        return prune_list                                                                    # 无需处理,直接返回剪枝列表
    if(node.alpha<=alpha):                                                                   # 如果节点alpha小于剪枝alpha
        prune_list.append(node.id)                                                           # 将当前节点添加到剪枝列表
        left_prune_list  = prune_alpha(node.left_node,0)                                     # 对左节点继续用alpha=0进行剪枝
        right_prune_list = prune_alpha(node.right_node,0)                                    # 对右节点继续用alpha=0进行剪枝
        node.be_leaf()                                                                       # 将节点设为叶子节点
    else:                                                                                    # 如果节点alpha不小于剪枝alpha
        left_prune_list  = prune_alpha(node.left_node,alpha)                                 # 对左节点继续用alpha进行剪枝
        right_prune_list = prune_alpha(node.right_node,alpha)                                # 对右节点继续用alpha进行剪枝
    prune_list.extend(left_prune_list)                                                       # 添加左节点的剪枝节点
    prune_list.extend(right_prune_list)                                                      # 添加右节点的剪枝节点
    return prune_list                                                                        # 返回剪枝节点列表
        
# 获取迭代剪枝的最小alpha
'''
先计算树最小临界alpha,剪掉最小临界alpha的叶子,再计算树小临界,再剪..再计算,再剪,直到只剩根节点,
返回每轮的最小临界alpha
'''
def cal_prune_list(node):
    cnode = node.copy()
    min_alpha_list = []
    prune_list = []
    while((cnode.is_leaf==0) and len(min_alpha_list)<1000):
        leaf_errnum,leaf_nodenum,min_alpha = cal_alpha(cnode)
        cur_prune_list = prune_alpha(cnode,min_alpha)
        prune_list.append(cur_prune_list)
        min_alpha_list.append(min_alpha)
    return min_alpha_list,prune_list


#根据alpha值最大剪枝
def prune(node,alpha):
    p_node = node.copy()
    min_alpha_list,prune_list = cal_prune_list(p_node)
    prune_term = np.argwhere(np.array(min_alpha_list)<=alpha)
    
    if(len(prune_term)>0):
        prune_term=prune_term[-1][0]+1
        for i in range(prune_term):
            leaf_errnum,leaf_nodenum,min_alpha = cal_alpha(p_node)
            prune_alpha(p_node,min_alpha)
    return p_node


# 剪掉指定ID的节点
def prune_nodes(node,id_list):
    p_node = node.copy()                                                                    # 入参的节点作为父节点
    un_use_node=[p_node]                                                                    # 将父节点入栈
    while((len(un_use_node)>0) and len(id_list)>0 ):                                        # 如果还有待剪节点(同时要求栈里有节点)
        cur_node = un_use_node.pop()                                                        # 从栈中弹出一个节点
        if(cur_node.id in id_list):                                                         # 如果当前节点在待剪列表中
            cur_node.be_leaf()                                                              # 将当前节点置为叶子
        elif(cur_node.is_leaf==0):                                                          # 如果当前节点不是叶子
            un_use_node.append(cur_node.left_node)                                          # 将左叶子节点入栈
            un_use_node.append(cur_node.right_node)                                         # 将右叶子节点入栈
    return p_node                                                                          
																					       
# predict,预测函数                                                                         
def predict(node,x):                                                                        
    while(node.is_leaf==0):                                                                 # 如果节点不是叶子
       node = node.left_node if x[node.cut_var]<=node.cut_val else node.right_node          # 根据切割变量与切割阈值,判断节点走向左还是右
    return node.node_class                                                                  # 返回最终节点的类别
        
# 打印树
def print_node(node,deep=0,var_name_list=[],show_sample_class=0,show_alpha_info=0): 
    node_id = '('+str(node.id)+')'
    alpha_info = ''
    if (show_alpha_info==1):
        alpha_info = ' (leaf_errnum:'+str(node.leaf_errnum)+')' +  ' (alpha:'+str(node.alpha)+')' 
    if(node.is_leaf==0):
         var_name  = 'x' + str(node.cut_var)  if(len(var_name_list)==0) else var_name_list[node.cut_var]
         left_sample_class  = "("+str(node.left_node.sample_class)+")"  if show_sample_class==1 else ''
         right_sample_class = "("+str(node.right_node.sample_class)+")" if show_sample_class==1 else ''
         
         print('  |'*deep+"--"+node_id+var_name+"<="+str(node.cut_val)+left_sample_class+alpha_info)
         print_node(node.left_node,deep+1
                    ,var_name_list=var_name_list
                    ,show_sample_class=show_sample_class
                    ,show_alpha_info=show_alpha_info)
         print('  |'*deep+"--"+node_id+var_name+">"+str(node.cut_val)+right_sample_class+alpha_info)
         print_node(node.right_node
                    ,deep+1
                    ,var_name_list=var_name_list
                    ,show_sample_class=show_sample_class
                    ,show_alpha_info=show_alpha_info)
    else:
        print('  |'*deep+"--"+node_id+"class="+str(node.node_class) +alpha_info)

# 主程序:构建树        
def build_tree(x,y):                                                                        
    min_leaf_num = 10                                                                       # 参数预设
    n_samples    = x.shape[0]                                                               # 样本个数
    n_feture     = x.shape[1]                                                               # 特征个数
    cMat,c_name  = class2Cmat(y)                                                            # 将y转为one-hot矩阵
    root_node    = Node(cMat,np.arange(n_samples))                                          # 树初始化
    un_use_node  = [root_node]                                                              # 将根节点入栈
																			                
    # 树构建主流程                                                                          
    while(len(un_use_node)>0 ):                                                             # 如果还有节点未分裂完成
        # --------- 弹出节点 ----------------                                               
        cur_node = un_use_node.pop()                                                        # 弹出一个未完成分裂的节点
        node_x   = x[cur_node.sample_idx]                                                   # 获取节点样本的x
        cur_cMat = cMat[cur_node.sample_idx]                                                # 获取节点样本的y(类别矩阵形式)
        not_leaf = (cur_node.sample_num>=min_leaf_num)& (cur_node.class_num>1)              # 判断是否未达叶子条件
        #  ---------- 分裂或设为叶子 --------                                               
        if(not_leaf):                                                                       # 如果未能成为叶子,继续分裂
            best_var = 0                                                                    # 预设第一个变量为最佳变量
            best_g,best_cut_val = find_best_cut(node_x[:,best_var],cur_cMat)                # 预设第一个变量的最佳切割为节点最佳切割
            
            # 历遍变量,找出每个变量的最佳切割,再比较哪个变量的最佳切割最好
            for i in range(1,n_feture):                                                     # 历遍变量
                g,cut_val = find_best_cut(node_x[:,i],cur_cMat)                             # 找出该变量的最佳切割,与最佳收益
                if g>best_g:                                                                # 更新最佳变量、最佳切割、最佳收益
                  best_g       = g                                                          # 更新最佳收益
                  best_cut_val = cut_val                                                    # 更新最佳切割
                  best_var     = i                                                          # 更新最佳变量
																			                
            cur_node.cut_var = best_var                                                     # 把最佳变量作为本节点最佳变量
            cur_node.cut_val = best_cut_val                                                 # 把最佳切割作为本节点最佳切割
            is_left = node_x[:,best_var]<=best_cut_val                                      # 找出左节点样本
            cur_node.left_node  = Node(cMat,cur_node.sample_idx[is_left ])                  # 新建左节点
            cur_node.right_node = Node(cMat,cur_node.sample_idx[~is_left ])                 # 新建右节点
            un_use_node.extend([cur_node.left_node,cur_node.right_node])                    # 把左右节点添加到分裂池
        else:                                                                               
           cur_node.be_leaf()                                                               # 如果节点已达叶子条件,则设为叶子节点 
    set_node_id(root_node,1)                                                                # 给树每个节点添加序号
    return root_node                                                                        # 返回根节点,即树
																			                
# 各个功能的使用demo                                                                        
def test_demo():                                                                            
    # -----加载数据-----------------                                                        
    iris = load_iris()                                                                      # 加载iris数据
    var_name_list =['sepal_length','sepal_width','petal_length','petal_width']              # 变量名称
    x = iris.data                                                                           # 变量x
    y = iris.target                                                                         # 标签y
																			                
    # ----构建完全生长的树-----------                                                       
    tree = build_tree(x,y)                                                                  # 构建一棵完全生长的决策树
																			                
    # -----预测---------------------                                                        
    predict_y = np.zeros(y.shape)                                                           # 初始化预测结果
    for i in range(x.shape[0]):                                                             # 逐个样本预测
       predict_y[i] = predict(tree,x[i])                                                    # 当前样本的预测值
																			                
    # -----打印信息-------------------------------------                                    
    print("\n------全生长树:------")                                                       # 打印
    print_node(tree,var_name_list=var_name_list)                                            # 打印树结构
    print("\n------预测错误样本:------")                                                   # 打印
    print(np.argwhere(predict_y != y ))                                                     # 打印预测错误的样本序号
																			                
    #-----剪枝-----------------------------------                                           
    # 去掉无效叶子                                                                          
    prune_bad(tree)                                                                         # 去掉无效叶子
    print("\n------去掉无效叶子的树------------")                                           # 打印树
    print_node(tree,var_name_list=var_name_list)                                            # 打印树结构
																			                
    # 根据alpha剪枝                                                                         
    # 带临界alpha信息的树                                                                   
    cal_alpha(tree,tree.sample_num)                                                         # 获取临界alpha信息
    print("\n------树的临界alpha信息----------")                                            # 打印
    print_node(tree,var_name_list=var_name_list,show_alpha_info=1)                          # 打印树结构
																			                
    # 多轮迭代式剪枝得到的CCP路径                                                           
    min_alpha_list,prune_list = cal_prune_list(tree)                                        # 模拟迭代式删除临界alpha
    print("\n--CCP路径--")                                                                  # 打印
    print("每轮alpha:",min_alpha_list)                                                      # 打印临界alpha
    print("每轮剪除节点:",prune_list)                                                       # 打印对应剪除的节点
																			                
																			                
    p_tree = prune(tree,0.1 )                                                               # 根据alpha剪枝(会多轮迭代)
    print("\n------指定alpha进行剪枝后的决策树-------")                                     # 打印
    print_node(p_tree,var_name_list=var_name_list)                                          # 打印树
																			                
    # 其它剪枝                                                                              
    p_tree = prune_nodes(p_tree,[5,9])                                                      # 指定节点剪枝
    prune_alpha(tree,0.1)                                                                   # 根据alpha剪枝(只剪一轮)

# 调用测试Demo
test_demo()




    代码运行结果    


代码运行后的输出如下:
 1、用iris作为训练数据构建的完全生长的决策树 
 
自实现CART决策树代码-运行结果1
  
 2、决策树在训练样本中预测错误的样本编号
 自实现CART决策树代码-运行结果2
 
 3、将完全生长的树去掉无效节点的后得到的决策树
 自实现CART决策树代码-运行结果3
  
4、展示树的节点信息,包括alpha

 自实现CART决策树代码-运行结果4
 
 5、打印出ccp路径
 自实现CART决策树代码-运行结果5 
  
6、根据CCP路径信息,我们指定alpha=0.1,进行剪枝后得到的决策树
 
自实现CART决策树代码-运行结果6







  03. CART决策树代码说明   



本节对代码中的相关函数进行说明



test_demo: 测试用例主函数,直接运行时就是执行该函数。


1、数据生成:使用sklearn自带的iris数据
2、用iris数据构建一棵完全生长的决策树
3、用构建好的决策树对样本进行预测,并打印出预测错误的样本编号
4、将树去掉无效子节点
5、计算各个节点的临界alpha
6、计算并打印CCP路径
7、设定alpha=0.1,进行剪枝
8、根据节点编号进行剪枝


build_tree:决策树构建主函数,用于构建一棵CART决策树


决策树构建主函数,用于构建一棵CART决策树


print_node:打印决策树


将决策树结构打印出来,并可选择是否显示节点上的相关信息。


predict:决策树的预测函数


传入决策树和要预测的x,即可得到决策树的预测结果


cal_prune_list:计算CCP路径


每次按最小alpha值,迭代剪枝,直到剪完整棵树,最后返回剪枝的路径。
注:剪枝动作是模拟进行的,并不影响树本身。


prune_bad:对无效节点进行剪枝


如果节点的分枝并不能降低判别误差,说明节点的分枝是无益的,对这些无效节点进行剪枝


prune_nodes:决策树的节点剪枝函数


传入要剪掉的节点编号,对树进行剪枝


prune:决策树的alpha剪枝函数


传入alpha值,按alpha值进行剪枝




8个辅助函数:用于辅助计算的函数。


Node:节点类
cal_alpha:计算各个节点的临界alpha,并返回最小临界alpha
prune_alpha:剪掉当前树节点alpha<=某个值的节点
set_node_id:重新设置子孙节点编号
find_best_cut:找出x变量的最佳切割点
cal_gain:找出x变量计算使用变量x,切值为cut_val时的收益函数(gini)最佳切割点
class2Cmat:将类别转为类别矩阵
cal_gain:找出x变量计算使用变量x,切值为cut_val时的收益函数(gini)最佳切割点





以上就是自现实一棵CART决策树的代码了~







 End 






内容纠正