Python利用Pytorch实现绘制ROC与PR曲线图

当我们需要评估二分类模型的性能时,ROC曲线和PR曲线是两个常用的工具。在Python中,我们可以使用PyTorch库来绘制这些曲线。下面是绘制ROC曲线和PR曲线的完整攻略,包括两个示例说明。

1. 绘制ROC曲线

ROC曲线是一种用于评估二分类模型性能的工具,它显示了真阳性率(TPR)与假阳性率(FPR)之间的关系。以下是使用PyTorch绘制ROC曲线的步骤:

  1. 导入必要的库

python
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
import torch

  1. 定义模型和数据

```python
# 定义模型
class Net(torch.nn.Module):
def init(self):
super(Net, self).init()
self.fc1 = torch.nn.Linear(2, 10)
self.fc2 = torch.nn.Linear(10, 1)
self.sigmoid = torch.nn.Sigmoid()

   def forward(self, x):
       x = torch.relu(self.fc1(x))
       x = self.sigmoid(self.fc2(x))
       return x

# 定义数据
x = torch.randn(1000, 2)
y = torch.randint(0, 2, (1000, 1)).float()
```

  1. 训练模型并预测

```python
# 训练模型
net = Net()
criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
for epoch in range(100):
optimizer.zero_grad()
output = net(x)
loss = criterion(output, y)
loss.backward()
optimizer.step()

# 预测
y_pred = net(x).detach().numpy()
```

  1. 计算FPR和TPR

python
fpr, tpr, _ = roc_curve(y, y_pred)
roc_auc = auc(fpr, tpr)

  1. 绘制ROC曲线

python
plt.figure()
lw = 2
plt.plot(fpr, tpr, color='darkorange',
lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()

运行上述代码,即可绘制ROC曲线。

2. 绘制PR曲线

PR曲线是一种用于评估二分类模型性能的工具,它显示了精确率(Precision)与召回率(Recall)之间的关系。以下是使用PyTorch绘制PR曲线的步骤:

  1. 导入必要的库

python
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, auc
import torch

  1. 定义模型和数据

```python
# 定义模型
class Net(torch.nn.Module):
def init(self):
super(Net, self).init()
self.fc1 = torch.nn.Linear(2, 10)
self.fc2 = torch.nn.Linear(10, 1)
self.sigmoid = torch.nn.Sigmoid()

   def forward(self, x):
       x = torch.relu(self.fc1(x))
       x = self.sigmoid(self.fc2(x))
       return x

# 定义数据
x = torch.randn(1000, 2)
y = torch.randint(0, 2, (1000, 1)).float()
```

  1. 训练模型并预测

```python
# 训练模型
net = Net()
criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
for epoch in range(100):
optimizer.zero_grad()
output = net(x)
loss = criterion(output, y)
loss.backward()
optimizer.step()

# 预测
y_pred = net(x).detach().numpy()
```

  1. 计算Precision和Recall

python
precision, recall, _ = precision_recall_curve(y, y_pred)
pr_auc = auc(recall, precision)

  1. 绘制PR曲线

python
plt.figure()
lw = 2
plt.plot(recall, precision, color='darkorange',
lw=lw, label='PR curve (area = %0.2f)' % pr_auc)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall example')
plt.legend(loc="lower right")
plt.show()

运行上述代码,即可绘制PR曲线。

以上就是使用PyTorch绘制ROC曲线和PR曲线的完整攻略,包括两个示例说明。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python利用Pytorch实现绘制ROC与PR曲线图 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • Pytorch基础-张量基本操作

    Pytorch 中,张量的操作分为结构操作和数学运算,其理解就如字面意思。结构操作就是改变张量本身的结构,数学运算就是对张量的元素值完成数学运算。 一,张量的基本操作 二,维度变换 2.1,squeeze vs unsqueeze 维度增减 2.2,transpose vs permute 维度交换 三,索引切片 3.1,规则索引切片方式 3.2,gathe…

    2023年4月6日
    00
  • [PyTorch] Facebook Research – Mask R-CNN Benchmark 的安装与测试

    Github项目链接:https://github.com/facebookresearch/maskrcnn-benchmark maskrcnn_benchmark 安装步骤: 安装Anaconda3,创建虚拟环境。 conda activate maskrcnn conda create -n maskrcnn python=3 conda activ…

    2023年4月8日
    00
  • pytorch 多gpu训练

    pytorch 多gpu训练 用nn.DataParallel重新包装一下 数据并行有三种情况 前向过程 device_ids=[0, 1, 2] model = model.cuda(device_ids[0]) model = nn.DataParallel(model, device_ids=device_ids) 只要将model重新包装一下就可以。…

    PyTorch 2023年4月6日
    00
  • Pytorch可视化的几种实现方法

    PyTorch是一个非常流行的深度学习框架,它提供了许多工具来帮助我们可视化模型和数据。在本文中,我们将介绍PyTorch可视化的几种实现方法,包括使用TensorBoard、使用Visdom和使用Matplotlib等。同时,我们还提供了两个示例说明。 使用TensorBoard TensorBoard是TensorFlow提供的一个可视化工具,但是它也可…

    PyTorch 2023年5月16日
    00
  • Ubuntu下安装pytorch(GPU版)

    我这里主要参考了:https://blog.csdn.net/yimingsilence/article/details/79631567 并根据自己在安装中遇到的情况做了一些改动。   先说明一下我的Ubuntu和GPU版本: Ubuntu 16.04 GPU:GEFORCE GTX 1060   1. 查看显卡型号 使用命令:lspci | grep -…

    PyTorch 2023年4月8日
    00
  • pytorch之torchvision.transforms图像变换实例

    在PyTorch中,torchvision.transforms模块提供了一系列用于图像变换的函数。本文将提供两个示例说明,以展示如何使用torchvision.transforms模块进行图像变换。 示例1:使用torchvision.transforms进行图像旋转 在这个示例中,我们将使用torchvision.transforms模块对图像进行旋转操…

    PyTorch 2023年5月15日
    00
  • pytorch的topk()函数

    pytorch.topk()用于返回Tensor中的前k个元素以及元素对应的索引值。例: import torch item=torch.IntTensor([1,2,4,7,3,2]) value,indices=torch.topk(item,3) print(“value:”,value) print(“indices:”,indices) 输出结果为…

    2023年4月8日
    00
  • pytorch conditional GAN 调试笔记

    推荐的几个开源实现 znxlwm 使用InfoGAN的结构,卷积反卷积 eriklindernoren 把mnist转成1维,label用了embedding wiseodd 直接从tensorflow代码转换过来的,数据集居然还用tf的数据集。。 Yangyangii 转1维向量,全连接 FangYang970206 提供了多标签作为条件的实现思路 znx…

    2023年4月8日
    00
合作推广
合作推广
分享本页
返回顶部