pytorch分类模型绘制混淆矩阵以及可视化详解

以下是关于“pytorch分类模型绘制混淆矩阵以及可视化详解”的完整攻略,其中包含两个示例说明。

示例1:绘制混淆矩阵

步骤1:导入必要的库

在绘制混淆矩阵之前,我们需要导入一些必要的库,包括numpymatplotlibsklearn

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

步骤2:准备数据

在这个示例中,我们使用一个虚拟的分类模型来演示如何绘制混淆矩阵。我们首先生成一些随机的预测结果和真实标签。

# 生成随机的预测结果和真实标签
y_pred = np.random.randint(0, 10, size=100)
y_true = np.random.randint(0, 10, size=100)

步骤3:计算混淆矩阵

使用sklearn库中的confusion_matrix函数计算混淆矩阵。

# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)

步骤4:绘制混淆矩阵

使用matplotlib库中的imshow函数绘制混淆矩阵。

# 绘制混淆矩阵
plt.imshow(cm, cmap=plt.cm.Blues)
plt.colorbar()
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.xticks(np.arange(10))
plt.yticks(np.arange(10))
plt.show()

步骤5:结果分析

绘制出来的混淆矩阵可以帮助我们分析模型的分类效果。在这个示例中,我们使用随机生成的数据,因此混淆矩阵中的数字没有实际意义。但是,在实际应用中,我们可以根据混淆矩阵中的数字来判断模型的分类效果。

示例2:可视化分类模型的预测结果

步骤1:导入必要的库

在可视化分类模型的预测结果之前,我们需要导入一些必要的库,包括numpymatplotlibtorchvision

import numpy as np
import matplotlib.pyplot as plt
import torchvision

步骤2:准备数据

在这个示例中,我们使用torchvision库中的CIFAR10数据集来演示如何可视化分类模型的预测结果。我们首先加载数据集并随机选择一些样本。

# 加载CIFAR10数据集
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=None)

# 随机选择一些样本
indices = np.random.choice(len(testset), size=16, replace=False)
images = [testset[i][0] for i in indices]
labels = [testset[i][1] for i in indices]

步骤3:加载模型

在这个示例中,我们使用一个预训练的ResNet18模型来演示如何可视化分类模型的预测结果。我们首先加载模型并设置为评估模式。

# 加载预训练的ResNet18模型
model = torchvision.models.resnet18(pretrained=True)

# 设置为评估模式
model.eval()

步骤4:进行预测

使用加载的模型对随机选择的样本进行预测,并将预测结果和真实标签存储在列表中。

# 对随机选择的样本进行预测
with torch.no_grad():
    outputs = model(torch.stack(images))
    _, predicted = torch.max(outputs, 1)

# 将预测结果和真实标签存储在列表中
predicted = predicted.numpy()
labels = np.array(labels)

步骤5:可视化预测结果

使用matplotlib库中的subplot函数可视化预测结果。

# 可视化预测结果
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
    ax.imshow(images[i])
    ax.set_title(f'Predicted: {predicted[i]}, True: {labels[i]}')
    ax.axis('off')
plt.show()

步骤6:结果分析

可视化出来的预测结果可以帮助我们分析模型的分类效果。在这个示例中,我们使用预训练的ResNet18模型对CIFAR10数据集进行预测,并将预测结果可视化出来。通过观察可视化结果,我们可以判断模型的分类效果。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch分类模型绘制混淆矩阵以及可视化详解 - Python技术站

(0)
上一篇 2023年5月16日
下一篇 2023年5月16日

相关文章

  • PyTorch安装及试用 基于Anaconda3

      设置Torch国内镜像 conda config –add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/   安装PyTorch和TorchVision conda install pytorch torchvision   测试pytorch版本 impor…

    PyTorch 2023年4月8日
    00
  • pytorch normal_(), fill_()

    比如有个张量a,那么a.normal_()就表示用标准正态分布填充a,是in_place操作,如下图所示: 比如有个张量b,那么b.fill_(0)就表示用0填充b,是in_place操作,如下图所示:   这两个函数常常用在神经网络模型参数的初始化中,例如 import torch.nn as nn net = nn.Linear(16, 2) for m…

    2023年4月7日
    00
  • pytorch 实现张量tensor,图片,CPU,GPU,数组等的转换

    在PyTorch中,我们可以使用torch.Tensor类来创建张量。张量是PyTorch中最基本的数据结构,它可以表示任意维度的数组。在本文中,我们将深入探讨如何在PyTorch中实现张量、图片、CPU、GPU、数组等的转换。 实现张量的转换 在PyTorch中,我们可以使用torch.Tensor类来创建张量。我们可以使用torch.Tensor()函数…

    PyTorch 2023年5月15日
    00
  • pytorch 入门指南

    两类深度学习框架的优缺点 动态图(PyTorch) 计算图的进行与代码的运行时同时进行的。 静态图(Tensorflow <2.0) 自建命名体系 自建时序控制 难以介入 使用深度学习框架的优点 GPU 加速 (cuda) 自动求导 常用网络层的API PyTorch 的特点 支持 GPU 动态神经网络 Python 优先 命令式体验 轻松扩展 1.P…

    PyTorch 2023年4月8日
    00
  • [pytorch]pytorch loss function 总结

    原文: http://www.voidcn.com/article/p-rtzqgqkz-bpg.html 最近看了下 PyTorch 的损失函数文档,整理了下自己的理解,重新格式化了公式如下,以便以后查阅。 注意下面的损失函数都是在单个样本上计算的,粗体表示向量,否则是标量。向量的维度用 表示。 nn.L1Loss nn.SmoothL1Loss 也叫作 …

    PyTorch 2023年4月8日
    00
  • 动手学pytorch-注意力机制和Seq2Seq模型

    注意力机制和Seq2Seq模型 1.基本概念 2.两种常用的attention层 3.带注意力机制的Seq2Seq模型 4.实验 1. 基本概念 Attention 是一种通用的带权池化方法,输入由两部分构成:询问(query)和键值对(key-value pairs)。(????_????∈ℝ^{????_????}, ????_????∈ℝ^{????_…

    2023年4月6日
    00
  • 3、pytorch实现最基础的MLP网络

    %matplotlib inline import numpy as np import torch from torch import nn import matplotlib.pyplot as plt d = 1 n = 200 X = torch.rand(n,d) #200*1, batch * feature_dim #y = 3*torch.s…

    PyTorch 2023年4月7日
    00
  • 详解anaconda离线安装pytorchGPU版

    详解Anaconda离线安装PyTorch GPU版 本文将介绍如何使用Anaconda离线安装PyTorch GPU版。我们将提供两个示例,分别是使用conda和pip安装PyTorch GPU版。 1. 下载PyTorch GPU版 首先,我们需要下载PyTorch GPU版的安装包。我们可以从PyTorch官网下载对应版本的安装包,也可以使用以下命令从…

    PyTorch 2023年5月15日
    00
合作推广
合作推广
分享本页
返回顶部