以下是关于“pytorch分类模型绘制混淆矩阵以及可视化详解”的完整攻略,其中包含两个示例说明。
示例1:绘制混淆矩阵
步骤1:导入必要的库
在绘制混淆矩阵之前,我们需要导入一些必要的库,包括numpy
、matplotlib
和sklearn
。
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
步骤2:准备数据
在这个示例中,我们使用一个虚拟的分类模型来演示如何绘制混淆矩阵。我们首先生成一些随机的预测结果和真实标签。
# 生成随机的预测结果和真实标签
y_pred = np.random.randint(0, 10, size=100)
y_true = np.random.randint(0, 10, size=100)
步骤3:计算混淆矩阵
使用sklearn
库中的confusion_matrix
函数计算混淆矩阵。
# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
步骤4:绘制混淆矩阵
使用matplotlib
库中的imshow
函数绘制混淆矩阵。
# 绘制混淆矩阵
plt.imshow(cm, cmap=plt.cm.Blues)
plt.colorbar()
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.xticks(np.arange(10))
plt.yticks(np.arange(10))
plt.show()
步骤5:结果分析
绘制出来的混淆矩阵可以帮助我们分析模型的分类效果。在这个示例中,我们使用随机生成的数据,因此混淆矩阵中的数字没有实际意义。但是,在实际应用中,我们可以根据混淆矩阵中的数字来判断模型的分类效果。
示例2:可视化分类模型的预测结果
步骤1:导入必要的库
在可视化分类模型的预测结果之前,我们需要导入一些必要的库,包括numpy
、matplotlib
和torchvision
。
import numpy as np
import matplotlib.pyplot as plt
import torchvision
步骤2:准备数据
在这个示例中,我们使用torchvision
库中的CIFAR10
数据集来演示如何可视化分类模型的预测结果。我们首先加载数据集并随机选择一些样本。
# 加载CIFAR10数据集
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=None)
# 随机选择一些样本
indices = np.random.choice(len(testset), size=16, replace=False)
images = [testset[i][0] for i in indices]
labels = [testset[i][1] for i in indices]
步骤3:加载模型
在这个示例中,我们使用一个预训练的ResNet18模型来演示如何可视化分类模型的预测结果。我们首先加载模型并设置为评估模式。
# 加载预训练的ResNet18模型
model = torchvision.models.resnet18(pretrained=True)
# 设置为评估模式
model.eval()
步骤4:进行预测
使用加载的模型对随机选择的样本进行预测,并将预测结果和真实标签存储在列表中。
# 对随机选择的样本进行预测
with torch.no_grad():
outputs = model(torch.stack(images))
_, predicted = torch.max(outputs, 1)
# 将预测结果和真实标签存储在列表中
predicted = predicted.numpy()
labels = np.array(labels)
步骤5:可视化预测结果
使用matplotlib
库中的subplot
函数可视化预测结果。
# 可视化预测结果
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
ax.imshow(images[i])
ax.set_title(f'Predicted: {predicted[i]}, True: {labels[i]}')
ax.axis('off')
plt.show()
步骤6:结果分析
可视化出来的预测结果可以帮助我们分析模型的分类效果。在这个示例中,我们使用预训练的ResNet18模型对CIFAR10数据集进行预测,并将预测结果可视化出来。通过观察可视化结果,我们可以判断模型的分类效果。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch分类模型绘制混淆矩阵以及可视化详解 - Python技术站