当我们要评估一个多分类模型时,一个有用的工具就是混淆矩阵。混淆矩阵提供了模型在每个类别上的分类表现。在本文中,我们将解释如何使用Python实现两种多分类混淆矩阵。
多分类混淆矩阵
在多分类问题中,混淆矩阵是一个表格,用于显示实际标签和预测标签之间的关系。
假设,我们有一个多分类模型,其中包含三个类别:汽车、卡车和自行车。我们通过混淆矩阵来看看模型在这三个类别上的表现。
下图是一个三类混淆矩阵的示例:
汽车 卡车 自行车
实际标签 汽车 20 5 0
卡车 3 12 2
自行车 1 0 13
例如,该模型错误地将一辆卡车分类为汽车,这个错误被标识为(卡车,汽车)。
接下来,我们介绍两种不同的方法来计算多分类混淆矩阵的指标。
方法1: sklearn库中的混淆矩阵
本方法中,我们将使用scikit-learn库的confusion_matrix()
函数来计算多分类混淆矩阵。
下面是代码示例:
from sklearn.metrics import confusion_matrix
import numpy as np
y_true = [0, 1, 2, 0, 1, 2, 0, 1, 2]
y_pred = [0, 0, 0, 0, 1, 1, 2, 2, 2]
confusion_matrix = confusion_matrix(y_true, y_pred)
print(confusion_matrix)
上述代码演示了一个针对3个类别的多分类混淆矩阵的计算。
以下是输出结果:
[[3 0 0]
[1 1 1]
[0 1 2]]
方法2:自定义混淆矩阵
本方法中,我们将使用numpy的bincount()
函数来手动计算多分类混淆矩阵。下文代码演示了这一操作方法:
import numpy as np
y_true = [0, 1, 2, 0, 1, 2, 0, 1, 2]
y_pred = [0, 0, 0, 0, 1, 1, 2, 2, 2]
n_classes = 3
confusion_matrix = np.zeros((n_classes, n_classes))
for true, pred in zip(y_true, y_pred):
confusion_matrix[true][pred] += 1
print(confusion_matrix)
以上代码的输出结果如下:
[[3. 0. 0.]
[1. 1. 1.]
[0. 1. 2.]]
至此,我们已经介绍了两种多分类混淆矩阵的实现方法。这两种方法都可以计算多分类模型的表现。您可以根据具体情况选择使用哪一种方法。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python实现两种多分类混淆矩阵 - Python技术站