Python sklearn预测评估指标混淆矩阵计算示例详解

yizhihongxing

Python sklearn预测评估指标混淆矩阵计算示例详解

本文主要介绍如何使用Python中的sklearn库来计算模型的混淆矩阵,从而评估模型的预测性能。

混淆矩阵

混淆矩阵是模型性能评估的常用指标之一,以二分类问题为例,混淆矩阵通常包含4个元素:

  • 真实值为正例,模型预测结果为正例的数量(True Positive,TP)
  • 真实值为正例,模型预测结果为负例的数量(False Negative,FN)
  • 真实值为负例,模型预测结果为负例的数量(True Negative,TN)
  • 真实值为负例,模型预测结果为正例的数量(False Positive,FP)

下面我们以一个具体的例子来说明:

假设我们要预测一个人是否患有某种疾病,疾病存在为正例,疾病不存在为负例,同时我们从医院收集到了100个样本数据,其中50个样本为正例,50个样本为负例。我们使用某个模型进行预测,得到的结果如下表所示:

真实值/预测结果 预测为正例 预测为负例
真实值为正例 30 20
真实值为负例 10 40

通过上表,我们可以得到以下信息:

  • TP = 30,即模型正确预测出了30个疾病患者
  • FN = 20,即模型将20个疾病患者误判为非疾病
  • TN = 40,即模型正确预测出了40个非疾病的样本
  • FP = 10,即模型将10个非疾病的样本误判为疾病

这些信息可以通过统计真实值和预测结果的对应关系来计算得出,进而构成混淆矩阵。

计算混淆矩阵

Python中的sklearn库提供了计算混淆矩阵的函数。下面是一个示例:

from sklearn.metrics import confusion_matrix

y_true = [1, 1, 0, 1, 0, 0, 1, 0, 0, 0]
y_pred = [1, 0, 0, 1, 1, 0, 1, 1, 0, 0]

matrix = confusion_matrix(y_true, y_pred)
print(matrix)

在上述示例中,我们构造了两个列表y_truey_pred,分别存储了10个样本的真实值和模型的预测结果。然后使用confusion_matrix函数来计算混淆矩阵,并将结果输出。

这段代码的输出结果为:

[[3 2]
 [2 3]]

上述输出结果代表了混淆矩阵。矩阵的行表示真实值,列表示预测结果。在本例中,矩阵的左上角元素表示真实值为0且模型预测结果为0的数量,即TN = 3;右下角元素表示真实值为1且模型预测结果为1的数量,即TP = 3;左下角和右上角分别为FN = 2和FP = 2。

混淆矩阵相关指标的计算

在得到混淆矩阵之后,我们可以通过计算各项指标来评估模型的预测性能。下面介绍两个常用的指标:准确率和召回率。

准确率(Accuracy)

准确率是模型的预测结果与实际结果相同的样本数占总样本数的比例。计算公式为:

$$
\text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN}
$$

在sklearn库中,准确率可以通过accuracy_score函数来计算。下面是一个示例:

from sklearn.metrics import accuracy_score

y_true = [1, 1, 0, 1, 0, 0, 1, 0, 0, 0]
y_pred = [1, 0, 0, 1, 1, 0, 1, 1, 0, 0]

acc = accuracy_score(y_true, y_pred)
print(acc)

运行上述代码,会得到准确率为0.6,即模型的预测结果与实际结果相同的样本数占总样本数的比例为60%。

召回率(Recall)

召回率是指模型正确预测出的正例样本数占所有正例样本数的比例。计算公式为:

$$
\text{Recall} = \frac{TP}{TP + FN}
$$

在sklearn库中,召回率可以通过recall_score函数来计算。下面是一个示例:

from sklearn.metrics import recall_score

y_true = [1, 1, 0, 1, 0, 0, 1, 0, 0, 0]
y_pred = [1, 0, 0, 1, 1, 0, 1, 1, 0, 0]

rec = recall_score(y_true, y_pred)
print(rec)

运行上述代码,会得到召回率为0.6,即模型正确预测出的正例样本数占所有正例样本数的比例为60%。

示例说明

下面结合两个具体的示例说明混淆矩阵和相关指标的计算。

示例1:鸢尾花数据分类

鸢尾花数据是机器学习中经典的数据集之一,主要任务是根据花萼和花瓣的长度和宽度等特征,将鸢尾花的三个品种进行分类。这是一个典型的多分类问题。下面我们使用逻辑回归模型对数据集进行分类,并评估模型性能。

首先,我们从sklearn库中读取鸢尾花数据:

from sklearn.datasets import load_iris

data = load_iris()
X = data.data
y = data.target

然后将数据集分割成训练集和测试集:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

接着,我们使用逻辑回归模型对数据进行训练和预测:

from sklearn.linear_model import LogisticRegression

model = LogisticRegression(random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

预测结果y_pred是一个包含了测试集中所有样本的分类结果的列表。接下来,我们使用sklearn库中的confusion_matrix函数计算混淆矩阵:

from sklearn.metrics import confusion_matrix

matrix = confusion_matrix(y_test, y_pred)
print(matrix)

计算得到的混淆矩阵如下所示:

[[16  0  0]
 [ 0 18  1]
 [ 0  0 10]]

矩阵中的行表示真实值,列表示预测结果。例如,左上角的元素16表示真实值为0且模型预测为0的数量,即TN;中间的元素1表示真实值为1但是模型将其预测为了2的数量,即FN。我们可以通过混淆矩阵的元素值来计算模型的性能指标。

比如,计算准确率:

from sklearn.metrics import accuracy_score

acc = accuracy_score(y_test, y_pred)
print(acc)

计算得到的准确率为0.98,即模型的预测结果准确率为98%。

同理,计算召回率:

from sklearn.metrics import recall_score

rec = recall_score(y_test, y_pred, average='macro')
print(rec)

由于是多分类问题,所以召回率的计算需要指定average='macro'参数,通过计算得到的召回率为0.98。

示例2:自行车租赁量预测

自行车租赁问题是一个回归问题,我们可以使用Random Forest模型对数据进行建模,预测出租赁量,并使用混淆矩阵进行模型性能评估。

首先,我们从sklearn库中读取自行车租赁数据集:

from sklearn.datasets import load_boston

data = load_boston()
X = data.data
y = data.target

然后将数据集分割成训练集和测试集:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

接着,我们使用随机森林回归模型对数据进行训练和预测:

from sklearn.ensemble import RandomForestRegressor

model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

预测结果y_pred是一个包含了测试集中所有样本的预测值的列表。接下来,我们将预测结果和真实值进行比较并计算混淆矩阵:

threshold = 20
y_test_class = [1 if i > threshold else 0 for i in y_test]
y_pred_class = [1 if i > threshold else 0 for i in y_pred]

matrix = confusion_matrix(y_test_class, y_pred_class)
print(matrix)

在这里,我们根据租赁量是否超过阈值来将回归问题转换为分类问题。设定阈值threshold=20,如果预测的租赁量大于20,则将其设为1,否则设为0。然后使用上述方法计算混淆矩阵。

计算得到的混淆矩阵如下:

[[47 13]
 [ 2  6]]

矩阵中的行表示真实值,列表示预测结果。例如,左上角的元素47表示真实值为0且模型预测为0的数量,即TN;右下角的元素6表示真实值为1且模型预测为1的数量,即TP。我们可以通过混淆矩阵的元素值来计算模型的性能指标。

比如,计算准确率:

from sklearn.metrics import accuracy_score

acc = accuracy_score(y_test_class, y_pred_class)
print(acc)

计算得到的准确率为0.77,即模型的预测结果准确率为77%。

同理,计算召回率:

from sklearn.metrics import recall_score

rec = recall_score(y_test_class, y_pred_class)
print(rec)

计算得到的召回率为0.75,即模型正确预测出的租赁量超过阈值的样本数占所有租赁量超过阈值的样本数的比例为75%。

总结

通过本文的讲解,我们了解了混淆矩阵的概念及其计算方法,以及如何使用sklearn库中的函数计算混淆矩阵和性能指标。混淆矩阵是评估模型性能的重要工具之一,在模型开发过程中大有用处。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python sklearn预测评估指标混淆矩阵计算示例详解 - Python技术站

(0)
上一篇 2023年6月5日
下一篇 2023年6月5日

相关文章

  • 保姆级python教程写个贪吃蛇大冒险

    “保姆级python教程写个贪吃蛇大冒险” 完整攻略 1. 准备工作 在开始写代码之前,我们需要确定游戏的规则以及所需的素材资源。所以在开始编写贪吃蛇游戏之前,需要先进行以下准备工作: 确定游戏规则,包括贪吃蛇的运动规律,障碍物的设置,得分计算等。 准备游戏所需资源,如背景音乐、美术资源等。 需要选择一个合适的游戏引擎,如Pygame。 2. 编写贪吃蛇游戏…

    python 2023年6月13日
    00
  • 在 Python 中,如何在另一个 py 文件的 [if __name__ == ‘__main__’] 中调用子程序?

    【问题标题】:In Python, how to invoke subroutine inside [if __name__ == ‘__main__’] of another py file?在 Python 中,如何在另一个 py 文件的 [if __name__ == ‘__main__’] 中调用子程序? 【发布时间】:2023-04-01 11:2…

    Python开发 2023年4月8日
    00
  • Python爬取京东商品信息评论存并进MySQL

    Python爬取京东商品信息评论存并进MySQL 本攻略将介绍如何使用Python爬取京东商品信息评论,并将其存储到MySQL数据库中。我们将使用Python的requests库和BeautifulSoup库来获取和解析京东商品信息评论,使用pymysql库来连接和操作MySQL数据库。 获取京东商品信息评论 我们可以使用Python的requests库来获…

    python 2023年5月15日
    00
  • Python实现视频转换为字符画详解

    下面是“Python实现视频转换为字符画”攻略: 准备 首先确保你已经安装好了Python语言、FFmpeg和ImageMagick这三个软件。 然后在命令行输入以下命令来安装Python第三方库: pip install opencv-python pillow numpy Python代码 下面是Python代码的流程: 1. 导入需要的库 import…

    python 2023年6月3日
    00
  • Python教程使用Chord包实现炫彩弦图示例

    接下来我将详细讲解“Python教程使用Chord包实现炫彩弦图示例”的完整攻略。 准备工作 在开始使用Chord包实现炫彩弦图之前,我们需要先安装必要的依赖,其中包括: Python 3.5 及以上版本 matplotlib numpy pandas chord 其中,matplotlib、numpy和pandas可通过pip命令进行安装,而chord需要…

    python 2023年5月18日
    00
  • Python基于select实现的socket服务器

    本攻略将介绍如何使用Python基于select实现一个socket服务器。select是一种多路复用的I/O模型,可以同时监视多个文件描述符,当其中任意一个文件描述符就绪时,select函数就会返回。使用select可以实现高效的I/O操作,避免了阻塞和轮询的问题。 实现socket服务器 以下是一个示例代码,用于实现一个基于select的socket服务…

    python 2023年5月15日
    00
  • Django打印出在数据库中执行的语句问题

    一、简介 Django提供了一个非常好用的ORM,可以方便的操作数据库,但是有时候我们需要查看ORM生成的SQL语句,以便优化ORM的使用。本攻略将详细介绍如何在Django中打印执行的SQL语句。 二、打印SQL语句的方法 在Django中,打印出在数据库中执行的SQL语句非常简单,我们只需要在settings.py中设置DEBUG=True,然后在执行O…

    python 2023年5月13日
    00
  • Python中JsonPath提取器和正则提取器

    以下是“Python中JsonPath提取器和正则提取器”的完整攻略: 一、问题描述 在Python中,我们经常需要从文本数据中提取特定的信息。JsonPath提取器和正则提取器是两种常见的提取工具,它们可以帮助我们快速、准确地提取所需的信息。本文将详细讲解Python中JsonPath提取器和正则提取器的使用方法,以及如何在实际开发中应用。 二、解决方案 …

    python 2023年5月14日
    00
合作推广
合作推广
分享本页
返回顶部