本站原创文章,转载请说明来自《老饼讲解-机器学习》www.bbbdata.com
混淆矩阵(Confusion Matrix)是机器学习中常用的一个概念,用于评估分类模型的性能
本文讲解混淆矩阵的用途,怎么看混淆矩阵,以及如何使用代码计算混淆矩阵并画出热力图
通过本文可以快速了解混淆矩阵是什么,如何使用混淆矩阵及热力图来评估分类模型的性能
本节直观地介绍什么是混淆矩阵,快速了解混淆矩阵怎么看
什么是混淆矩阵
混淆矩阵CM(Confusion Matrix)是一个用于查看多分类模型预测效果的矩阵
什么是混淆矩阵CM
下图是二分类的解释,直接看图即明白
混淆矩阵CM是一个类别个数类别个数的矩阵,如下图
其中, 代表第 类样本被判为类的个数与占比
易知,混淆矩阵CM的对角线就为判断准确的个数
从混淆矩阵可以很清晰看到各类别的判别情况,因为多分类时一般都会通过混淆矩阵来查看模型效果
本节展示如何使用python计算混淆矩阵以及画热力图
使用python画混淆矩阵
在python中可以使用confusion_matrix来计算混淆矩阵,再用heatmap画出热力图
具体示例代码如下
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
# 真实标签和预测标签
y_true = [0,1,2,2,0,0,1,2,1,0] # 样本的真实标签
y_pred = [0,0,2,2,1,1,0,2,1,0] # 样本的预测标签
cm = confusion_matrix(y_true, y_pred) # 生成混淆矩阵
sns.heatmap(cm, annot=True, fmt='d') # 使用Seaborn的heatmap来画混淆矩阵
plt.title('Confusion Matrix')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()
运行结果如下:
End