本站原创文章,转载请说明来自《老饼讲解-机器学习》www.bbbdata.com
本文为《决策树建模完整流程》中的代码实现
本节展示决策树建模的完整代码,用于建模时进行借鉴与复制
决策树建模-完整流程代码
《决策树建模完整流程》中的完整代码如下:
# -*- coding: utf-8 -*-
from sklearn.datasets import load_iris
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
import numpy as np
import pandas as pd
import graphviz
import datetime
#----------- Distribution of misclassified samples over leaf nodes -----------
def cal_err_node(clf, X, y):
    """Summarize how misclassified samples are distributed over leaf nodes.

    Parameters
    ----------
    clf : fitted decision-tree classifier exposing ``apply`` and ``predict``
    X   : feature matrix accepted by ``clf``
    y   : array-like of true labels aligned with ``X``

    Returns
    -------
    pd.DataFrame
        One row per leaf that contains at least one error, with columns
        ``leaf_node`` (leaf id), ``num`` (samples in the leaf), ``is_err``
        (error count) and ``err_rate`` (= is_err / num), sorted by
        ``err_rate`` descending.
    """
    leaf_node = clf.apply(X)             # leaf id each sample falls into
    is_err = clf.predict(X) != y         # True where the prediction is wrong
    df = pd.DataFrame({
        "leaf_node": leaf_node,
        "num": 1,                        # one per row; grouped sum = leaf size
        "is_err": is_err.astype(int),    # cast so the grouped sum is a count
    })
    df = df.groupby("leaf_node", as_index=False).sum()
    df["err_rate"] = df["is_err"] / df["num"]
    # Keep only leaves with errors; report the worst leaves first.
    df = df[df["err_rate"] > 0].reset_index(drop=True)
    return df.sort_values(by="err_rate", ascending=False)
#-------- Load data -----------------------------------
iris = load_iris()
all_X = iris.data
all_y = iris.target
#-------- Preprocess: train / test split --------------
train_X, test_X, train_y, test_y = train_test_split(all_X, all_y, test_size=0.2, random_state=0)
#-------- Probe the model's ceiling -------------------
# Fit once on the full dataset and once on the training split to check
# how well this tree configuration can do before tuning.
clf = tree.DecisionTreeClassifier(max_depth=3, min_samples_leaf=8, random_state=20)
clf = clf.fit(all_X, all_y)
total_score = clf.score(all_X, all_y)       # accuracy on the full dataset (fixed typo: was total_socre)
clf = clf.fit(train_X, train_y)
train_score = clf.score(train_X, train_y)   # accuracy on the training split (fixed typo: was train_socre)
print("\n========模型试探============")
print("全量数据建模准确率:", total_score)
print("训练数据建模准确率:", train_score)
#------- Grid-scan for the best training parameters -------------------
clf = tree.DecisionTreeClassifier(random_state=0)
# Candidate hyper-parameter values to scan exhaustively.
param_test = {
    'max_depth': range(3, 15, 3),           # maximum tree depth
    'min_samples_leaf': range(5, 20, 3),
    'random_state': range(0, 100, 10),
    # 'min_samples_split': range(5, 20, 3),
    # 'splitter': ('best', 'random'),
    # 'criterion': ('gini', 'entropy'),     # gini / information entropy
}
gsearch = GridSearchCV(
    estimator=clf,          # model to tune
    param_grid=param_test,  # parameter grid to search
    scoring=None,           # default scorer (accuracy for classifiers)
    n_jobs=-1,              # parallel workers; -1 = one per CPU core
    cv=5,                   # 5-fold cross-validation
    verbose=0,              # no progress output
)
gsearch.fit(train_X, train_y)
print("\n========最优参数扫描结果============")
print("模型最佳评分:", gsearch.best_score_)
print("模型最佳参数:", gsearch.best_params_)
#------- Retrain the model with the best parameters found -------------
clf = tree.DecisionTreeClassifier(**gsearch.best_params_)
clf = clf.fit(train_X, train_y)
pruning_path = clf.cost_complexity_pruning_path(train_X, train_y)  # CCP alphas/impurities for pruning below
test_score = clf.score(test_X,test_y) # accuracy (fraction correct) on the held-out test set
err_node_df = cal_err_node(clf, test_X, test_y)  # which leaves the test errors concentrate in
print("\n========最优参数训练结果============")
print("\n---------决策树信息--------------")
print("叶子个数:",clf.get_n_leaves())
print("树的深度:",clf.get_depth())
print("特征权重:",clf.feature_importances_)
print("\n--------测试样本准确率:----------:\n",test_score)
print("\n----错误样本在叶子节点的分布--------:")
print(err_node_df)
print("\n------CCP路径---------------")
print("ccp_alphas:",pruning_path['ccp_alphas'])
print("impurities:",pruning_path['impurities'])
# Render the fitted tree with graphviz and open the viewer.
dot_data = tree.export_graphviz(clf, out_file=None,
                     feature_names=iris.feature_names,
                     class_names=iris.target_names,
                     filled=True, rounded=True,
                     special_characters=True)
graph = graphviz.Source(dot_data)
# Timestamped filename avoids overwriting earlier renders.
graph.view("tree"+datetime.datetime.now().strftime('%H-%M-%S'))
'''
查看模型信息,决定剪枝alpha,再执行以下剪枝代码
'''
#------- Final stage: cost-complexity pruning -------------------------
# ccp_alpha is chosen by inspecting the CCP path printed above;
# a larger alpha prunes more aggressively (fewer leaves).
clf = tree.DecisionTreeClassifier(max_depth=3,min_samples_leaf=8,random_state=20,ccp_alpha=0.1)
clf = clf.fit(train_X, train_y)
test_score = clf.score(test_X,test_y)  # accuracy on the test split after pruning
print("\n==============剪枝=====================:\n")
print("测试样本准确率:",test_score)
print("叶子节点个数",clf.get_n_leaves())
运行结果
运行结果如下:
========模型试探============
全量数据建模准确率: 1.0
训练数据建模准确率: 1.0
========最优参数扫描结果============
模型最佳评分: 0.95
模型最佳参数: {'max_depth': 3, 'min_samples_leaf': 8, 'random_state': 20}
========最优参数训练结果============
---------决策树信息--------------
叶子个数: 5
树的深度: 3
特征权重: [0.00277564 0. 0.54604969 0.45117467]
--------测试样本准确率:----------:
0.9666666666666667
----错误样本在叶子节点的分布--------:
   leaf_node  num  is_err  err_rate
0          4    9       1  0.111111
------CCP路径---------------
ccp_alphas: [0. 0.00167683 0.01384615 0.25871926 0.32988169]
impurities: [0.06073718 0.06241401 0.07626016 0.33497942 0.66486111]
==============剪枝=====================:
测试样本准确率: 0.9666666666666667
叶子节点个数 3
如果决策树画图部分报错,请参考《决策树结果可视化》
End