当我们需要评估二分类模型的性能时,ROC曲线和PR曲线是两个常用的工具。在Python中,我们可以使用PyTorch库来绘制这些曲线。下面是绘制ROC曲线和PR曲线的完整攻略,包括两个示例说明。
1. 绘制ROC曲线
ROC曲线是一种用于评估二分类模型性能的工具,它显示了真阳性率(TPR)与假阳性率(FPR)之间的关系。以下是使用PyTorch绘制ROC曲线的步骤:
- 导入必要的库
python
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
import torch
- 定义模型和数据
```python
# 定义模型
class Net(torch.nn.Module):
def init(self):
super(Net, self).init()
self.fc1 = torch.nn.Linear(2, 10)
self.fc2 = torch.nn.Linear(10, 1)
self.sigmoid = torch.nn.Sigmoid()
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.sigmoid(self.fc2(x))
return x
# 定义数据
x = torch.randn(1000, 2)
y = torch.randint(0, 2, (1000, 1)).float()
```
- 训练模型并预测
```python
# 训练模型
net = Net()
criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
for epoch in range(100):
optimizer.zero_grad()
output = net(x)
loss = criterion(output, y)
loss.backward()
optimizer.step()
# 预测
y_pred = net(x).detach().numpy()
```
- 计算FPR和TPR
python
fpr, tpr, _ = roc_curve(y, y_pred)
roc_auc = auc(fpr, tpr)
- 绘制ROC曲线
python
plt.figure()
lw = 2
plt.plot(fpr, tpr, color='darkorange',
lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()
运行上述代码,即可绘制ROC曲线。
2. 绘制PR曲线
PR曲线是一种用于评估二分类模型性能的工具,它显示了精确率(Precision)与召回率(Recall)之间的关系。以下是使用PyTorch绘制PR曲线的步骤:
- 导入必要的库
python
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, auc
import torch
- 定义模型和数据
```python
# 定义模型
class Net(torch.nn.Module):
def init(self):
super(Net, self).init()
self.fc1 = torch.nn.Linear(2, 10)
self.fc2 = torch.nn.Linear(10, 1)
self.sigmoid = torch.nn.Sigmoid()
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.sigmoid(self.fc2(x))
return x
# 定义数据
x = torch.randn(1000, 2)
y = torch.randint(0, 2, (1000, 1)).float()
```
- 训练模型并预测
```python
# 训练模型
net = Net()
criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
for epoch in range(100):
optimizer.zero_grad()
output = net(x)
loss = criterion(output, y)
loss.backward()
optimizer.step()
# 预测
y_pred = net(x).detach().numpy()
```
- 计算Precision和Recall
python
precision, recall, _ = precision_recall_curve(y, y_pred)
pr_auc = auc(recall, precision)
- 绘制PR曲线
python
plt.figure()
lw = 2
plt.plot(recall, precision, color='darkorange',
lw=lw, label='PR curve (area = %0.2f)' % pr_auc)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall example')
plt.legend(loc="lower right")
plt.show()
运行上述代码,即可绘制PR曲线。
以上就是使用PyTorch绘制ROC曲线和PR曲线的完整攻略,包括两个示例说明。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python利用Pytorch实现绘制ROC与PR曲线图 - Python技术站