Python sklearn预测评估指标混淆矩阵计算示例详解

Python sklearn预测评估指标混淆矩阵计算示例详解

本文主要介绍如何使用Python中的sklearn库来计算模型的混淆矩阵,从而评估模型的预测性能。

混淆矩阵

混淆矩阵是模型性能评估的常用指标之一,以二分类问题为例,混淆矩阵通常包含4个元素:

  • 真实值为正例,模型预测结果为正例的数量(True Positive,TP)
  • 真实值为正例,模型预测结果为负例的数量(False Negative,FN)
  • 真实值为负例,模型预测结果为负例的数量(True Negative,TN)
  • 真实值为负例,模型预测结果为正例的数量(False Positive,FP)

下面我们以一个具体的例子来说明:

假设我们要预测一个人是否患有某种疾病,疾病存在为正例,疾病不存在为负例,同时我们从医院收集到了100个样本数据,其中50个样本为正例,50个样本为负例。我们使用某个模型进行预测,得到的结果如下表所示:

真实值/预测结果 预测为正例 预测为负例
真实值为正例 30 20
真实值为负例 10 40

通过上表,我们可以得到以下信息:

  • TP = 30,即模型正确预测出了30个疾病患者
  • FN = 20,即模型将20个疾病患者误判为非疾病
  • TN = 40,即模型正确预测出了40个非疾病的样本
  • FP = 10,即模型将10个非疾病的样本误判为疾病

这些信息可以通过统计真实值和预测结果的对应关系来计算得出,进而构成混淆矩阵。

计算混淆矩阵

Python中的sklearn库提供了计算混淆矩阵的函数。下面是一个示例:

from sklearn.metrics import confusion_matrix

y_true = [1, 1, 0, 1, 0, 0, 1, 0, 0, 0]
y_pred = [1, 0, 0, 1, 1, 0, 1, 1, 0, 0]

matrix = confusion_matrix(y_true, y_pred)
print(matrix)

在上述示例中,我们构造了两个列表y_truey_pred,分别存储了10个样本的真实值和模型的预测结果。然后使用confusion_matrix函数来计算混淆矩阵,并将结果输出。

这段代码的输出结果为:

[[3 2]
 [2 3]]

上述输出结果代表了混淆矩阵。矩阵的行表示真实值,列表示预测结果。在本例中,矩阵的左上角元素表示真实值为0且模型预测结果为0的数量,即TN = 3;右下角元素表示真实值为1且模型预测结果为1的数量,即TP = 3;左下角和右上角分别为FN = 2和FP = 2。

混淆矩阵相关指标的计算

在得到混淆矩阵之后,我们可以通过计算各项指标来评估模型的预测性能。下面介绍两个常用的指标:准确率和召回率。

准确率(Accuracy)

准确率是模型的预测结果与实际结果相同的样本数占总样本数的比例。计算公式为:

$$
\text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN}
$$

在sklearn库中,准确率可以通过accuracy_score函数来计算。下面是一个示例:

from sklearn.metrics import accuracy_score

y_true = [1, 1, 0, 1, 0, 0, 1, 0, 0, 0]
y_pred = [1, 0, 0, 1, 1, 0, 1, 1, 0, 0]

acc = accuracy_score(y_true, y_pred)
print(acc)

运行上述代码,会得到准确率为0.6,即模型的预测结果与实际结果相同的样本数占总样本数的比例为60%。

召回率(Recall)

召回率是指模型正确预测出的正例样本数占所有正例样本数的比例。计算公式为:

$$
\text{Recall} = \frac{TP}{TP + FN}
$$

在sklearn库中,召回率可以通过recall_score函数来计算。下面是一个示例:

from sklearn.metrics import recall_score

y_true = [1, 1, 0, 1, 0, 0, 1, 0, 0, 0]
y_pred = [1, 0, 0, 1, 1, 0, 1, 1, 0, 0]

rec = recall_score(y_true, y_pred)
print(rec)

运行上述代码,会得到召回率为0.6,即模型正确预测出的正例样本数占所有正例样本数的比例为60%。

示例说明

下面结合两个具体的示例说明混淆矩阵和相关指标的计算。

示例1:鸢尾花数据分类

鸢尾花数据是机器学习中经典的数据集之一,主要任务是根据花萼和花瓣的长度和宽度等特征,将鸢尾花的三个品种进行分类。这是一个典型的多分类问题。下面我们使用逻辑回归模型对数据集进行分类,并评估模型性能。

首先,我们从sklearn库中读取鸢尾花数据:

from sklearn.datasets import load_iris

data = load_iris()
X = data.data
y = data.target

然后将数据集分割成训练集和测试集:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

接着,我们使用逻辑回归模型对数据进行训练和预测:

from sklearn.linear_model import LogisticRegression

model = LogisticRegression(random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

预测结果y_pred是一个包含了测试集中所有样本的分类结果的列表。接下来,我们使用sklearn库中的confusion_matrix函数计算混淆矩阵:

from sklearn.metrics import confusion_matrix

matrix = confusion_matrix(y_test, y_pred)
print(matrix)

计算得到的混淆矩阵如下所示:

[[16  0  0]
 [ 0 18  1]
 [ 0  0 10]]

矩阵中的行表示真实值,列表示预测结果。例如,左上角的元素16表示真实值为0且模型预测为0的数量,即TN;中间的元素1表示真实值为1但是模型将其预测为了2的数量,即FN。我们可以通过混淆矩阵的元素值来计算模型的性能指标。

比如,计算准确率:

from sklearn.metrics import accuracy_score

acc = accuracy_score(y_test, y_pred)
print(acc)

计算得到的准确率为0.98,即模型的预测结果准确率为98%。

同理,计算召回率:

from sklearn.metrics import recall_score

rec = recall_score(y_test, y_pred, average='macro')
print(rec)

由于是多分类问题,所以召回率的计算需要指定average='macro'参数,通过计算得到的召回率为0.98。

示例2:自行车租赁量预测

自行车租赁问题是一个回归问题,我们可以使用Random Forest模型对数据进行建模,预测出租赁量,并使用混淆矩阵进行模型性能评估。

首先,我们从sklearn库中读取自行车租赁数据集:

from sklearn.datasets import load_boston

data = load_boston()
X = data.data
y = data.target

然后将数据集分割成训练集和测试集:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

接着,我们使用随机森林回归模型对数据进行训练和预测:

from sklearn.ensemble import RandomForestRegressor

model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

预测结果y_pred是一个包含了测试集中所有样本的预测值的列表。接下来,我们将预测结果和真实值进行比较并计算混淆矩阵:

threshold = 20
y_test_class = [1 if i > threshold else 0 for i in y_test]
y_pred_class = [1 if i > threshold else 0 for i in y_pred]

matrix = confusion_matrix(y_test_class, y_pred_class)
print(matrix)

在这里,我们根据租赁量是否超过阈值来将回归问题转换为分类问题。设定阈值threshold=20,如果预测的租赁量大于20,则将其设为1,否则设为0。然后使用上述方法计算混淆矩阵。

计算得到的混淆矩阵如下:

[[47 13]
 [ 2  6]]

矩阵中的行表示真实值,列表示预测结果。例如,左上角的元素47表示真实值为0且模型预测为0的数量,即TN;右下角的元素6表示真实值为1且模型预测为1的数量,即TP。我们可以通过混淆矩阵的元素值来计算模型的性能指标。

比如,计算准确率:

from sklearn.metrics import accuracy_score

acc = accuracy_score(y_test_class, y_pred_class)
print(acc)

计算得到的准确率为0.77,即模型的预测结果准确率为77%。

同理,计算召回率:

from sklearn.metrics import recall_score

rec = recall_score(y_test_class, y_pred_class)
print(rec)

计算得到的召回率为0.75,即模型正确预测出的租赁量超过阈值的样本数占所有租赁量超过阈值的样本数的比例为75%。

总结

通过本文的讲解,我们了解了混淆矩阵的概念及其计算方法,以及如何使用sklearn库中的函数计算混淆矩阵和性能指标。混淆矩阵是评估模型性能的重要工具之一,在模型开发过程中大有用处。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python sklearn预测评估指标混淆矩阵计算示例详解 - Python技术站

(0)
上一篇 2023年6月5日
下一篇 2023年6月5日

相关文章

  • Python3的正则表达式详解

    Python3的正则表达式详解 正则表达式是一种用于描述字符串模式的语言,它可以用于匹配、查找、替换和割字符串。Python中的re模块供了对正则表达式的支持,可以方便进行字符串的处理。本文将详细讲解Python3中正则表达式的语法和re模块的常用函数以及两个常用的匹配实例。 正则表达式语法 正则表达式由一些特殊字符和普通字符组成,用于字符串模式。下面是一些…

    python 2023年5月14日
    00
  • Python try except else使用详解

    Python的try-except-else语句是用于捕捉异常的一种方法。它的常见用法是在一个try语句块中包含有可能会抛出异常的代码,对于不同的异常类型使用不同的except语句块来处理异常,并且使用else语句块来处理正常执行的代码。 使用try-except-else的基本语法 try: # 可能会抛出异常的代码 except ExceptionTyp…

    python 2023年5月13日
    00
  • 基于Python实现商场抽奖小系统

    下面是基于Python实现商场抽奖小系统的完整攻略: 1. 确定系统需求 在开始编写代码前,我们需要先明确这个抽奖小系统需要具备哪些功能,例如: 能够生成一定数量的奖品,并将奖品存储在数据库中 能够在数据库中添加、删除、修改奖品的信息 能够在抽奖时从数据库中获取奖品信息,并展示给用户 能够实现抽奖过程,并在最终抽中奖品后将相关信息存储在数据库中 能够展示抽奖…

    python 2023年6月13日
    00
  • Django笔记二十七之数据库函数之文本函数

    本文首发于公众号:Hunter后端原文链接:Django笔记二十七之数据库函数之文本函数 这篇笔记将介绍如何使用数据库函数里的文本函数。 顾名思义,文本函数,就是针对文本字段进行操作的函数,如下是目录汇总: Concat() —— 合并 Left() —— 从左边开始截取 Length() —— 获取字符串长度 Lower() —— 小写处理 LPad() …

    python 2023年4月22日
    00
  • Python必备技巧之字符数据操作详解

    Python必备技巧之字符数据操作详解 字符数据类型 在Python中,字符串是一种常见的数据类型。字符串是一个由字符序列组成的不可变序列。因为字符串不可变,因此不能像列表一样进行就地修改。字符串可以使用单引号或双引号来表示。 字符串连接和重复 字符串可以连接起来形成新的字符串。连接操作可以使用+运算符或通过字符串插值完成。例如: str1 = "…

    python 2023年5月14日
    00
  • python利用appium实现手机APP自动化的示例

    针对这个话题,我将给出以下完整攻略: 准备工作 安装 Python3 环境 安装 appium-python-client 库 pip install Appium-Python-Client 安装 Android SDK, 并配置 ANDROID_HOME 环境变量 安装 JDK, 并配置 JAVA_HOME 环境变量 在手机上安装待测试的 APP 在电脑…

    python 2023年5月19日
    00
  • python beautifulsoup在标签之间查找

    【问题标题】:python beautifulsoup find between tagspython beautifulsoup在标签之间查找 【发布时间】:2023-04-04 20:26:01 【问题描述】: 我正在尝试从网站获取数据。我设法获得了我想要的数据子集 sections = rows.findAll(‘p’) for section in …

    Python开发 2023年4月6日
    00
  • python使用socket高效传输视频数据帧(连续发送图片)

    下面我将为您详细讲解“python使用socket高效传输视频数据帧(连续发送图片)”的完整实例教程,包括示例说明: 1. 简介 在本教程中,我们将使用Python中的socket库实现高效的视频数据帧传输,特别是连续发送图片。实现这种数据流的目标是传输即时视频,并尽可能地减小延迟。 2. 实现 2.1 导入库 我们首先要导入需要的Python库: impo…

    python 2023年5月13日
    00
合作推广
合作推广
分享本页
返回顶部