机器学习-专题详述

【例子】一个简单的决策树分类例子

作者 : 老饼 发表日期 : 2022-06-26 09:46:57 更新日期 : 2024-11-17 09:56:02
本站原创文章,转载请说明来自《老饼讲解-机器学习》www.bbbdata.com 


  CART决策树是常用的机器学习算法,它包括CART分类树与回归树,

回归树与分类树不同的地方在于,回归树的输出是数值,分类树输出的是类别。

本文展示一个用python(sklearn)实现的简单的CART分类树例子,用于学习sklearn分类树的调用方法



  01. 决策树分类例子-问题简介  




本节先展示本文用于决策树建模的问题背景与数据介绍




       数  据 介 绍      


现已采集150组 鸢尾花数据,包括鸢尾花的四个特征与鸢尾花的类别
数据如下(即sk-learn中的iris数据):
  
 花萼长度 sepal length (cm) 、花萼宽度 sepal width (cm)   
花瓣长度 petal length (cm) 、花瓣宽度 petal width (cm)  
山鸢尾:0,杂色鸢尾:1,弗吉尼亚鸢尾:2    
               



      决 策 树 建 模 目 标      


我们希望通过采集的数据,训练一个决策树模型,
之后应用该模型,可以根据鸢尾花的四个特征去预测它的类别。





   02. 决策树分类例子-实现代码   



本节讲解决策树解决分类问题的实现代码



     决策树分类例子-实现代码    


用决策树对以上分类问题进行建模的流程如下:
1. 建立决策树模型                   
2. 用数据训练决策树模型         
3. 用训练好的决策树模型预测  
在python中通过sklearn具体实现的代码如下:
from sklearn.datasets import load_iris
from sklearn import tree

#----------------数据准备----------------------------
iris = load_iris()                          # 加载数据

#---------------模型训练----------------------------------
clf = tree.DecisionTreeClassifier()         # sk-learn的决策树模型
clf = clf.fit(iris.data, iris.target)        # 用数据训练树模型构建()
r = tree.export_text(clf, feature_names=iris['feature_names'])


#---------------模型预测结果------------------------
text_x = iris.data[[0,1,50,51,100,101], :]
pred_target_prob = clf.predict_proba(text_x)        # 预测类别概率
pred_target = clf.predict(text_x)              # 预测类别

#---------------打印结果---------------------------
print("\n===模型======")
print(r)
print("\n===测试数据:=====")
print(text_x)
print("\n===预测所属类别概率:=====")
print(pred_target_prob)
print("\n===预测所属类别:======")
print(pred_target)



    运行结果    


运行代码后,输出如下:
===模型======
|--- petal length (cm) <= 2.45
|   |--- class: 0
|--- petal length (cm) >  2.45
|   |--- petal width (cm) <= 1.75
|   |   |--- petal length (cm) <= 4.95
|   |   |   |--- petal width (cm) <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- petal width (cm) >  1.65
|   |   |   |   |--- class: 2
|   |   |--- petal length (cm) >  4.95
|   |   |   |--- petal width (cm) <= 1.55
|   |   |   |   |--- class: 2
|   |   |   |--- petal width (cm) >  1.55
|   |   |   |   |--- sepal length (cm) <= 6.95
|   |   |   |   |   |--- class: 1
|   |   |   |   |--- sepal length (cm) >  6.95
|   |   |   |   |   |--- class: 2
|   |--- petal width (cm) >  1.75
|   |   |--- petal length (cm) <= 4.85
|   |   |   |--- sepal width (cm) <= 3.10
|   |   |   |   |--- class: 2
|   |   |   |--- sepal width (cm) >  3.10
|   |   |   |   |--- class: 1
|   |   |--- petal length (cm) >  4.85
|   |   |   |--- class: 2

===测试数据:=====
[[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [7.  3.2 4.7 1.4]
 [6.4 3.2 4.5 1.5]
 [6.3 3.3 6.  2.5]
 [5.8 2.7 5.1 1.9]]

===预测所属类别概率:=====
[[1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 0. 1.]
 [0. 0. 1.]]

===预测所属类别:======
[0 0 1 1 2 2]





以上就是决策树的最简例子






 End 




联系老饼