本站原创文章,转载请说明来自《老饼讲解-机器学习》www.bbbdata.com
CART决策树是常用的机器学习算法,它包括CART分类树与回归树,
回归树与分类树不同的地方在于,回归树的输出是数值,分类树输出的是类别。
本文展示一个用python(sklearn)实现的简单的CART分类树例子,用于学习sklearn分类树的调用方法
本节先展示本文用于决策树建模的问题背景与数据介绍
数 据 介 绍
现已采集150组 鸢尾花数据,包括鸢尾花的四个特征与鸢尾花的类别
数据如下(即sk-learn中的iris数据):
花萼长度 sepal length (cm) 、花萼宽度 sepal width (cm)
花瓣长度 petal length (cm) 、花瓣宽度 petal width (cm)
山鸢尾:0,杂色鸢尾:1,弗吉尼亚鸢尾:2
决 策 树 建 模 目 标
我们希望通过采集的数据,训练一个决策树模型,
之后应用该模型,可以根据鸢尾花的四个特征去预测它的类别。
本节讲解决策树解决分类问题的实现代码
决策树分类例子-实现代码
用决策树对以上分类问题进行建模的流程如下:
1. 建立决策树模型
2. 用数据训练决策树模型
3. 用训练好的决策树模型预测
在python中通过sklearn具体实现的代码如下:
from sklearn.datasets import load_iris
from sklearn import tree
#----------------数据准备----------------------------
iris = load_iris() # 加载数据
#---------------模型训练----------------------------------
clf = tree.DecisionTreeClassifier() # sk-learn的决策树模型
clf = clf.fit(iris.data, iris.target) # 用数据训练树模型构建()
r = tree.export_text(clf, feature_names=iris['feature_names'])
#---------------模型预测结果------------------------
text_x = iris.data[[0,1,50,51,100,101], :]
pred_target_prob = clf.predict_proba(text_x) # 预测类别概率
pred_target = clf.predict(text_x) # 预测类别
#---------------打印结果---------------------------
print("\n===模型======")
print(r)
print("\n===测试数据:=====")
print(text_x)
print("\n===预测所属类别概率:=====")
print(pred_target_prob)
print("\n===预测所属类别:======")
print(pred_target)
运行结果
运行代码后,输出如下:
===模型======
|--- petal length (cm) <= 2.45
| |--- class: 0
|--- petal length (cm) > 2.45
| |--- petal width (cm) <= 1.75
| | |--- petal length (cm) <= 4.95
| | | |--- petal width (cm) <= 1.65
| | | | |--- class: 1
| | | |--- petal width (cm) > 1.65
| | | | |--- class: 2
| | |--- petal length (cm) > 4.95
| | | |--- petal width (cm) <= 1.55
| | | | |--- class: 2
| | | |--- petal width (cm) > 1.55
| | | | |--- sepal length (cm) <= 6.95
| | | | | |--- class: 1
| | | | |--- sepal length (cm) > 6.95
| | | | | |--- class: 2
| |--- petal width (cm) > 1.75
| | |--- petal length (cm) <= 4.85
| | | |--- sepal width (cm) <= 3.10
| | | | |--- class: 2
| | | |--- sepal width (cm) > 3.10
| | | | |--- class: 1
| | |--- petal length (cm) > 4.85
| | | |--- class: 2
===测试数据:=====
[[5.1 3.5 1.4 0.2]
[4.9 3. 1.4 0.2]
[7. 3.2 4.7 1.4]
[6.4 3.2 4.5 1.5]
[6.3 3.3 6. 2.5]
[5.8 2.7 5.1 1.9]]
===预测所属类别概率:=====
[[1. 0. 0.]
[1. 0. 0.]
[0. 1. 0.]
[0. 1. 0.]
[0. 0. 1.]
[0. 0. 1.]]
===预测所属类别:======
[0 0 1 1 2 2]
以上就是决策树的最简例子
End