Python利用Pytorch实现绘制ROC与PR曲线图

yizhihongxing

当我们需要评估二分类模型的性能时,ROC曲线和PR曲线是两个常用的工具。在Python中,我们可以使用PyTorch库来绘制这些曲线。下面是绘制ROC曲线和PR曲线的完整攻略,包括两个示例说明。

1. 绘制ROC曲线

ROC曲线是一种用于评估二分类模型性能的工具,它显示了真阳性率(TPR)与假阳性率(FPR)之间的关系。以下是使用PyTorch绘制ROC曲线的步骤:

  1. 导入必要的库

python
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
import torch

  1. 定义模型和数据

```python
# 定义模型
class Net(torch.nn.Module):
def init(self):
super(Net, self).init()
self.fc1 = torch.nn.Linear(2, 10)
self.fc2 = torch.nn.Linear(10, 1)
self.sigmoid = torch.nn.Sigmoid()

   def forward(self, x):
       x = torch.relu(self.fc1(x))
       x = self.sigmoid(self.fc2(x))
       return x

# 定义数据
x = torch.randn(1000, 2)
y = torch.randint(0, 2, (1000, 1)).float()
```

  1. 训练模型并预测

```python
# 训练模型
net = Net()
criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
for epoch in range(100):
optimizer.zero_grad()
output = net(x)
loss = criterion(output, y)
loss.backward()
optimizer.step()

# 预测
y_pred = net(x).detach().numpy()
```

  1. 计算FPR和TPR

python
fpr, tpr, _ = roc_curve(y, y_pred)
roc_auc = auc(fpr, tpr)

  1. 绘制ROC曲线

python
plt.figure()
lw = 2
plt.plot(fpr, tpr, color='darkorange',
lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()

运行上述代码,即可绘制ROC曲线。

2. 绘制PR曲线

PR曲线是一种用于评估二分类模型性能的工具,它显示了精确率(Precision)与召回率(Recall)之间的关系。以下是使用PyTorch绘制PR曲线的步骤:

  1. 导入必要的库

python
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, auc
import torch

  1. 定义模型和数据

```python
# 定义模型
class Net(torch.nn.Module):
def init(self):
super(Net, self).init()
self.fc1 = torch.nn.Linear(2, 10)
self.fc2 = torch.nn.Linear(10, 1)
self.sigmoid = torch.nn.Sigmoid()

   def forward(self, x):
       x = torch.relu(self.fc1(x))
       x = self.sigmoid(self.fc2(x))
       return x

# 定义数据
x = torch.randn(1000, 2)
y = torch.randint(0, 2, (1000, 1)).float()
```

  1. 训练模型并预测

```python
# 训练模型
net = Net()
criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
for epoch in range(100):
optimizer.zero_grad()
output = net(x)
loss = criterion(output, y)
loss.backward()
optimizer.step()

# 预测
y_pred = net(x).detach().numpy()
```

  1. 计算Precision和Recall

python
precision, recall, _ = precision_recall_curve(y, y_pred)
pr_auc = auc(recall, precision)

  1. 绘制PR曲线

python
plt.figure()
lw = 2
plt.plot(recall, precision, color='darkorange',
lw=lw, label='PR curve (area = %0.2f)' % pr_auc)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall example')
plt.legend(loc="lower right")
plt.show()

运行上述代码,即可绘制PR曲线。

以上就是使用PyTorch绘制ROC曲线和PR曲线的完整攻略,包括两个示例说明。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python利用Pytorch实现绘制ROC与PR曲线图 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • pytorch 在sequential中使用view来reshape的例子

    在PyTorch中,我们可以使用Sequential模块来构建神经网络。Sequential模块允许我们按照顺序添加一系列的层,从而构建一个完整的神经网络。在Sequential模块中,我们可以使用view函数来对张量进行reshape操作,以适应不同的层的输入和输出形状。 以下是两个使用Sequential模块和view函数的示例: 示例1:使用Seque…

    PyTorch 2023年5月15日
    00
  • Pytorch mask_select 函数的用法详解

    PyTorch mask_select 函数的用法详解 在 PyTorch 中,mask_select 函数是一种常见的选择操作,它可以根据给定的掩码(mask)从输入张量中选择元素。本文将详细讲解 PyTorch 中 mask_select 函数的用法,并提供两个示例说明。 1. mask_select 函数的基本用法 在 PyTorch 中,我们可以使用…

    PyTorch 2023年5月16日
    00
  • new_zeros() pytorch版本的转换方式

    PyTorch中new_zeros()函数的用法 new_zeros()是PyTorch中的一个函数,用于创建一个指定形状的全零张量。以下是new_zeros()函数的用法: torch.Tensor.new_zeros(size, dtype=None, device=None, requires_grad=False) 其中,size是张量的形状,dty…

    PyTorch 2023年5月15日
    00
  • pytorch使用tensorboardX进行loss可视化实例

    PyTorch使用TensorboardX进行Loss可视化实例 在PyTorch中,我们可以使用TensorboardX库将训练过程中的Loss可视化。本文将介绍如何使用TensorboardX库进行Loss可视化,并提供两个示例说明。 1. 安装TensorboardX 要使用TensorboardX库,我们需要先安装它。可以使用以下命令在终端中安装Te…

    PyTorch 2023年5月15日
    00
  • PyTorch自定义数据集

    数据传递机制 我们首先回顾识别手写数字的程序: … Dataset = torchvision.datasets.MNIST(root=’./mnist/’, train=True, transform=transform, download=True,) dataloader = torch.utils.data.DataLoader(dataset=…

    2023年4月7日
    00
  • pytorch–(MisMatch in shape & invalid index of a 0-dim tensor)

    在尝试运行CVPR2019一篇行为识别论文的代码时,遇到了两个问题,记录如下。但是,原因没懂,如果看此文章的你了解原理,欢迎留言交流吖。 github代码链接: 方法1: 根据定位的错误位置,我的是215行,将criticD_real.bachward(mone)改为criticD_real.bachward(mone.mean())上一行注释。保存后运行,…

    PyTorch 2023年4月6日
    00
  • pytorch 如何自定义卷积核权值参数

    PyTorch自定义卷积核权值参数 在PyTorch中,我们可以自定义卷积核权值参数。本文将介绍如何自定义卷积核权值参数,并提供两个示例。 示例一:自定义卷积核权值参数 我们可以使用nn.Parameter()函数创建可训练的权值参数。可以使用以下代码创建自定义卷积核权值参数: import torch import torch.nn as nn class…

    PyTorch 2023年5月15日
    00
  • PyTorch环境配置及安装过程

    以下是PyTorch环境配置及安装过程的完整攻略,包括Windows、macOS和Linux三个平台的安装步骤。同时,还提供了两个示例说明。 Windows平台 1. 安装Anaconda 在Windows平台上,我们可以使用Anaconda来安装PyTorch。首先,我们需要下载并安装Anaconda。可以在官网上下载对应的安装包,然后按照提示进行安装。 …

    PyTorch 2023年5月16日
    00
合作推广
合作推广
分享本页
返回顶部