Matplotlib 是一个常用的 Python 绘图库,可以用于绘制各种类型的图形,包括混淆矩阵。以下是绘制混淆矩阵的实现攻略:
1. 创建混淆矩阵
混淆矩阵是分类问题中一个重要的评估指标,它可以用来衡量分类器的性能。在 Python 中,我们可以使用 ConfusionMatrixDisplay 类来绘制混淆矩阵。以下是一个示例代码:
from sklearn.metrics import confusion_matrix
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
# 生成分类数据集
X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
# 训练逻辑回归模型
clf = LogisticRegression(random_state=0).fit(X_train, y_train)
# 预测并生成混淆矩阵
y_pred = clf.predict(X_test)
cm = confusion_matrix(y_test, y_pred, normalize='true')
# 绘制混淆矩阵
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Class 0", "Class 1"])
disp.plot()
plt.show()
这个示例代码中,首先我们使用 make_classification
函数生成一个随机的二分类数据集,然后将数据集分为训练集和测试集。接着我们使用逻辑回归模型训练数据,并预测测试集的标签。最后我们使用 confusion_matrix
函数生成混淆矩阵,并将它传递给 ConfusionMatrixDisplay
类。我们还需要设置 display_labels
参数来指定分类器中每个类的标签。我们调用 plot
函数绘制混淆矩阵,并使用 show
函数显示图形。
在混淆矩阵中,每一行代表真实标签,每一列代表预测标签。对角线上的元素表示分类正确的样本数量,其他元素则表示分类错误的样本数量。
2. 自定义混淆矩阵的样式
在 Matplotlib 中,我们可以使用 imshow
函数自定义混淆矩阵的样式。以下是一个示例代码:
from sklearn.metrics import confusion_matrix
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
import numpy as np
# 生成分类数据集
X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
# 训练逻辑回归模型
clf = LogisticRegression(random_state=0).fit(X_train, y_train)
# 预测并生成混淆矩阵
y_pred = clf.predict(X_test)
cm = confusion_matrix(y_test, y_pred, normalize='true')
# 自定义混淆矩阵的样式
fig, ax = plt.subplots()
im = ax.imshow(cm, cmap='Blues', interpolation='nearest')
ax.set_xticks(np.arange(cm.shape[1]))
ax.set_yticks(np.arange(cm.shape[0]))
ax.set_xticklabels(["Class 0", "Class 1"])
ax.set_yticklabels(["Class 0", "Class 1"])
ax.set_xlabel("Predicted label")
ax.set_ylabel("True label")
ax.set_title("Confusion Matrix")
fig.colorbar(im)
plt.show()
在这个示例代码中,我们使用了与前一个示例相同的数据集和逻辑回归模型。我们通过调用 imshow
函数将混淆矩阵作为图像显示。我们可以通过指定 cmap
参数来设置颜色映射,通过指定 interpolation
参数来设置插值方法。我们还可以调用 set_xticklabels
和 set_yticklabels
函数指定坐标轴的标签。最后,我们调用 colorbar
函数为混淆矩阵添加颜色条。
通过这两个示例,我们可以看到如何使用 Matplotlib 绘制混淆矩阵,并自定义混淆矩阵的样式。在实际应用中,混淆矩阵可以帮助我们更好地评估分类器的性能,帮助我们调整并改进分类器的表现。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Matplotlib绘制混淆矩阵的实现 - Python技术站