Python sklearn预测评估指标混淆矩阵计算示例详解

Python sklearn预测评估指标混淆矩阵计算示例详解

本文主要介绍如何使用Python中的sklearn库来计算模型的混淆矩阵,从而评估模型的预测性能。

混淆矩阵

混淆矩阵是模型性能评估的常用指标之一,以二分类问题为例,混淆矩阵通常包含4个元素:

  • 真实值为正例,模型预测结果为正例的数量(True Positive,TP)
  • 真实值为正例,模型预测结果为负例的数量(False Negative,FN)
  • 真实值为负例,模型预测结果为负例的数量(True Negative,TN)
  • 真实值为负例,模型预测结果为正例的数量(False Positive,FP)

下面我们以一个具体的例子来说明:

假设我们要预测一个人是否患有某种疾病,疾病存在为正例,疾病不存在为负例,同时我们从医院收集到了100个样本数据,其中50个样本为正例,50个样本为负例。我们使用某个模型进行预测,得到的结果如下表所示:

真实值/预测结果 预测为正例 预测为负例
真实值为正例 30 20
真实值为负例 10 40

通过上表,我们可以得到以下信息:

  • TP = 30,即模型正确预测出了30个疾病患者
  • FN = 20,即模型将20个疾病患者误判为非疾病
  • TN = 40,即模型正确预测出了40个非疾病的样本
  • FP = 10,即模型将10个非疾病的样本误判为疾病

这些信息可以通过统计真实值和预测结果的对应关系来计算得出,进而构成混淆矩阵。

计算混淆矩阵

Python中的sklearn库提供了计算混淆矩阵的函数。下面是一个示例:

from sklearn.metrics import confusion_matrix

y_true = [1, 1, 0, 1, 0, 0, 1, 0, 0, 0]
y_pred = [1, 0, 0, 1, 1, 0, 1, 1, 0, 0]

matrix = confusion_matrix(y_true, y_pred)
print(matrix)

在上述示例中,我们构造了两个列表y_truey_pred,分别存储了10个样本的真实值和模型的预测结果。然后使用confusion_matrix函数来计算混淆矩阵,并将结果输出。

这段代码的输出结果为:

[[3 2]
 [2 3]]

上述输出结果代表了混淆矩阵。矩阵的行表示真实值,列表示预测结果。在本例中,矩阵的左上角元素表示真实值为0且模型预测结果为0的数量,即TN = 3;右下角元素表示真实值为1且模型预测结果为1的数量,即TP = 3;左下角和右上角分别为FN = 2和FP = 2。

混淆矩阵相关指标的计算

在得到混淆矩阵之后,我们可以通过计算各项指标来评估模型的预测性能。下面介绍两个常用的指标:准确率和召回率。

准确率(Accuracy)

准确率是模型的预测结果与实际结果相同的样本数占总样本数的比例。计算公式为:

$$
\text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN}
$$

在sklearn库中,准确率可以通过accuracy_score函数来计算。下面是一个示例:

from sklearn.metrics import accuracy_score

y_true = [1, 1, 0, 1, 0, 0, 1, 0, 0, 0]
y_pred = [1, 0, 0, 1, 1, 0, 1, 1, 0, 0]

acc = accuracy_score(y_true, y_pred)
print(acc)

运行上述代码,会得到准确率为0.6,即模型的预测结果与实际结果相同的样本数占总样本数的比例为60%。

召回率(Recall)

召回率是指模型正确预测出的正例样本数占所有正例样本数的比例。计算公式为:

$$
\text{Recall} = \frac{TP}{TP + FN}
$$

在sklearn库中,召回率可以通过recall_score函数来计算。下面是一个示例:

from sklearn.metrics import recall_score

y_true = [1, 1, 0, 1, 0, 0, 1, 0, 0, 0]
y_pred = [1, 0, 0, 1, 1, 0, 1, 1, 0, 0]

rec = recall_score(y_true, y_pred)
print(rec)

运行上述代码,会得到召回率为0.6,即模型正确预测出的正例样本数占所有正例样本数的比例为60%。

示例说明

下面结合两个具体的示例说明混淆矩阵和相关指标的计算。

示例1:鸢尾花数据分类

鸢尾花数据是机器学习中经典的数据集之一,主要任务是根据花萼和花瓣的长度和宽度等特征,将鸢尾花的三个品种进行分类。这是一个典型的多分类问题。下面我们使用逻辑回归模型对数据集进行分类,并评估模型性能。

首先,我们从sklearn库中读取鸢尾花数据:

from sklearn.datasets import load_iris

data = load_iris()
X = data.data
y = data.target

然后将数据集分割成训练集和测试集:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

接着,我们使用逻辑回归模型对数据进行训练和预测:

from sklearn.linear_model import LogisticRegression

model = LogisticRegression(random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

预测结果y_pred是一个包含了测试集中所有样本的分类结果的列表。接下来,我们使用sklearn库中的confusion_matrix函数计算混淆矩阵:

from sklearn.metrics import confusion_matrix

matrix = confusion_matrix(y_test, y_pred)
print(matrix)

计算得到的混淆矩阵如下所示:

[[16  0  0]
 [ 0 18  1]
 [ 0  0 10]]

矩阵中的行表示真实值,列表示预测结果。例如,左上角的元素16表示真实值为0且模型预测为0的数量,即TN;中间的元素1表示真实值为1但是模型将其预测为了2的数量,即FN。我们可以通过混淆矩阵的元素值来计算模型的性能指标。

比如,计算准确率:

from sklearn.metrics import accuracy_score

acc = accuracy_score(y_test, y_pred)
print(acc)

计算得到的准确率为0.98,即模型的预测结果准确率为98%。

同理,计算召回率:

from sklearn.metrics import recall_score

rec = recall_score(y_test, y_pred, average='macro')
print(rec)

由于是多分类问题,所以召回率的计算需要指定average='macro'参数,通过计算得到的召回率为0.98。

示例2:自行车租赁量预测

自行车租赁问题是一个回归问题,我们可以使用Random Forest模型对数据进行建模,预测出租赁量,并使用混淆矩阵进行模型性能评估。

首先,我们从sklearn库中读取自行车租赁数据集:

from sklearn.datasets import load_boston

data = load_boston()
X = data.data
y = data.target

然后将数据集分割成训练集和测试集:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

接着,我们使用随机森林回归模型对数据进行训练和预测:

from sklearn.ensemble import RandomForestRegressor

model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

预测结果y_pred是一个包含了测试集中所有样本的预测值的列表。接下来,我们将预测结果和真实值进行比较并计算混淆矩阵:

threshold = 20
y_test_class = [1 if i > threshold else 0 for i in y_test]
y_pred_class = [1 if i > threshold else 0 for i in y_pred]

matrix = confusion_matrix(y_test_class, y_pred_class)
print(matrix)

在这里,我们根据租赁量是否超过阈值来将回归问题转换为分类问题。设定阈值threshold=20,如果预测的租赁量大于20,则将其设为1,否则设为0。然后使用上述方法计算混淆矩阵。

计算得到的混淆矩阵如下:

[[47 13]
 [ 2  6]]

矩阵中的行表示真实值,列表示预测结果。例如,左上角的元素47表示真实值为0且模型预测为0的数量,即TN;右下角的元素6表示真实值为1且模型预测为1的数量,即TP。我们可以通过混淆矩阵的元素值来计算模型的性能指标。

比如,计算准确率:

from sklearn.metrics import accuracy_score

acc = accuracy_score(y_test_class, y_pred_class)
print(acc)

计算得到的准确率为0.77,即模型的预测结果准确率为77%。

同理,计算召回率:

from sklearn.metrics import recall_score

rec = recall_score(y_test_class, y_pred_class)
print(rec)

计算得到的召回率为0.75,即模型正确预测出的租赁量超过阈值的样本数占所有租赁量超过阈值的样本数的比例为75%。

总结

通过本文的讲解,我们了解了混淆矩阵的概念及其计算方法,以及如何使用sklearn库中的函数计算混淆矩阵和性能指标。混淆矩阵是评估模型性能的重要工具之一,在模型开发过程中大有用处。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python sklearn预测评估指标混淆矩阵计算示例详解 - Python技术站

(0)
上一篇 2023年6月5日
下一篇 2023年6月5日

相关文章

  • Python制作Windows系统服务

    对于Python制作Windows系统服务的完整攻略,可以按照以下步骤进行: 1. 安装pywin32模块 在制作Windows系统服务之前,需要先安装pywin32模块,它是Python在Windows操作系统下的一个扩展库,可以实现操作Windows系统的一些功能,包括服务的创建和管理。 可以使用pip安装pywin32模块,命令如下: pip inst…

    python 2023年5月30日
    00
  • python实现简单银行管理系统

    如何实现简单银行管理系统 简介 Python是一种高级编程语言,它可以用来开发各种应用程序,包括银行管理系统。本文将介绍如何使用Python编写一个简单的银行管理系统。 功能特点 简单的银行管理系统需要具备以下功能: 用户注册:用户可以注册一个帐户进行存款和取款操作。 存款:用户可以存入钱到自己的帐户。 取款:用户可以从自己的帐户中取出钱。 查询余额:用户可…

    python 2023年5月30日
    00
  • python获取指定网页上所有超链接的方法

    获取指定网页上所有超链接的方法可以通过使用Python中的第三方库BeautifulSoup和requests来实现。具体步骤如下: 使用requests库获取网页的HTML源代码 代码示例: import requests url = ‘https://example.com’ response = requests.get(url) html = res…

    python 2023年6月3日
    00
  • python os.path模块使用方法介绍

    Python的os.path模块使用方法介绍 os.path模块是Python标准库中与路径相关操作的模块之一,它提供了许多用于处理文件路径的函数。本文将详细讲解os.path模块的各种方法及其用法。 获取路径信息: os.path.abspath(path) 返回path的绝对路径,如果path不存在,则抛出FileNotFoundError。 >&…

    python 2023年6月2日
    00
  • python复制列表时[:]和[::]之间有什么区别

    当我们想要复制一个列表时,通常使用切片操作来实现。在使用切片时,可以使用两个冒号开始和结束索引之间添加步长来决定生成子列表的步长。Python中表示复制列表的切片语法是用开始和结束索引之间添加“:”的形式,这个语法也有其他的变体。 具体来说,切片语法格式为list[start:end],其中start是开始索引(包含),end是结束索引(不包含)。如果省略开…

    python 2023年6月6日
    00
  • python数据结构之图的实现方法

    以下是关于“Python数据结构之图的实现方法”的完整攻略: 简介 图是一种常用的数据结构,用于表示对象之间的关系。在本教程中,我们将介绍如何使用Python实现图,包括邻接矩阵和邻接表两种实现方法。 邻接矩阵 邻接矩阵是一种常用的图的实现方法,它使用二维数组表示图中的节点和边。在邻接矩阵中,每个节点都对应数组中的一行和一列,如果两个节点之间有边相连,则在对…

    python 2023年5月14日
    00
  • python 通过xml获取测试节点和属性的实例

    当我们进行软件测试时,常常需要读取XML文件中的测试节点和属性。Python提供了多种库来处理XML文件,其中最常用的是ElementTree库。接下来,我将为您提供一个完整的攻略来使用Python通过XML获取测试节点和属性。 第一步:导入ElementTree库 使用Python处理XML文件的第一步是导入ElementTree库。可以通过以下代码来导入…

    python 2023年5月14日
    00
  • Jmeter如何使用BeanShell取样器调用Python脚本

    JMeter是一个性能测试工具,也可以扩展以支持其他类型的测试。它支持Java编写的插件,其中就包括BeanShell取样器。通过BeanShell取样器,我们可以调用Python脚本来实现更复杂的测试场景。 下面是使用JMeter和BeanShell取样器调用Python脚本的完整攻略: 首先,在JMeter中添加BeanShell取样器。在测试计划中添加…

    python 2023年6月2日
    00
合作推广
合作推广
分享本页
返回顶部