Python sklearn预测评估指标混淆矩阵计算示例详解

Python sklearn预测评估指标混淆矩阵计算示例详解

本文主要介绍如何使用Python中的sklearn库来计算模型的混淆矩阵,从而评估模型的预测性能。

混淆矩阵

混淆矩阵是模型性能评估的常用指标之一,以二分类问题为例,混淆矩阵通常包含4个元素:

  • 真实值为正例,模型预测结果为正例的数量(True Positive,TP)
  • 真实值为正例,模型预测结果为负例的数量(False Negative,FN)
  • 真实值为负例,模型预测结果为负例的数量(True Negative,TN)
  • 真实值为负例,模型预测结果为正例的数量(False Positive,FP)

下面我们以一个具体的例子来说明:

假设我们要预测一个人是否患有某种疾病,疾病存在为正例,疾病不存在为负例,同时我们从医院收集到了100个样本数据,其中50个样本为正例,50个样本为负例。我们使用某个模型进行预测,得到的结果如下表所示:

真实值/预测结果 预测为正例 预测为负例
真实值为正例 30 20
真实值为负例 10 40

通过上表,我们可以得到以下信息:

  • TP = 30,即模型正确预测出了30个疾病患者
  • FN = 20,即模型将20个疾病患者误判为非疾病
  • TN = 40,即模型正确预测出了40个非疾病的样本
  • FP = 10,即模型将10个非疾病的样本误判为疾病

这些信息可以通过统计真实值和预测结果的对应关系来计算得出,进而构成混淆矩阵。

计算混淆矩阵

Python中的sklearn库提供了计算混淆矩阵的函数。下面是一个示例:

from sklearn.metrics import confusion_matrix

y_true = [1, 1, 0, 1, 0, 0, 1, 0, 0, 0]
y_pred = [1, 0, 0, 1, 1, 0, 1, 1, 0, 0]

matrix = confusion_matrix(y_true, y_pred)
print(matrix)

在上述示例中,我们构造了两个列表y_truey_pred,分别存储了10个样本的真实值和模型的预测结果。然后使用confusion_matrix函数来计算混淆矩阵,并将结果输出。

这段代码的输出结果为:

[[3 2]
 [2 3]]

上述输出结果代表了混淆矩阵。矩阵的行表示真实值,列表示预测结果。在本例中,矩阵的左上角元素表示真实值为0且模型预测结果为0的数量,即TN = 3;右下角元素表示真实值为1且模型预测结果为1的数量,即TP = 3;左下角和右上角分别为FN = 2和FP = 2。

混淆矩阵相关指标的计算

在得到混淆矩阵之后,我们可以通过计算各项指标来评估模型的预测性能。下面介绍两个常用的指标:准确率和召回率。

准确率(Accuracy)

准确率是模型的预测结果与实际结果相同的样本数占总样本数的比例。计算公式为:

$$
\text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN}
$$

在sklearn库中,准确率可以通过accuracy_score函数来计算。下面是一个示例:

from sklearn.metrics import accuracy_score

y_true = [1, 1, 0, 1, 0, 0, 1, 0, 0, 0]
y_pred = [1, 0, 0, 1, 1, 0, 1, 1, 0, 0]

acc = accuracy_score(y_true, y_pred)
print(acc)

运行上述代码,会得到准确率为0.6,即模型的预测结果与实际结果相同的样本数占总样本数的比例为60%。

召回率(Recall)

召回率是指模型正确预测出的正例样本数占所有正例样本数的比例。计算公式为:

$$
\text{Recall} = \frac{TP}{TP + FN}
$$

在sklearn库中,召回率可以通过recall_score函数来计算。下面是一个示例:

from sklearn.metrics import recall_score

y_true = [1, 1, 0, 1, 0, 0, 1, 0, 0, 0]
y_pred = [1, 0, 0, 1, 1, 0, 1, 1, 0, 0]

rec = recall_score(y_true, y_pred)
print(rec)

运行上述代码,会得到召回率为0.6,即模型正确预测出的正例样本数占所有正例样本数的比例为60%。

示例说明

下面结合两个具体的示例说明混淆矩阵和相关指标的计算。

示例1:鸢尾花数据分类

鸢尾花数据是机器学习中经典的数据集之一,主要任务是根据花萼和花瓣的长度和宽度等特征,将鸢尾花的三个品种进行分类。这是一个典型的多分类问题。下面我们使用逻辑回归模型对数据集进行分类,并评估模型性能。

首先,我们从sklearn库中读取鸢尾花数据:

from sklearn.datasets import load_iris

data = load_iris()
X = data.data
y = data.target

然后将数据集分割成训练集和测试集:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

接着,我们使用逻辑回归模型对数据进行训练和预测:

from sklearn.linear_model import LogisticRegression

model = LogisticRegression(random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

预测结果y_pred是一个包含了测试集中所有样本的分类结果的列表。接下来,我们使用sklearn库中的confusion_matrix函数计算混淆矩阵:

from sklearn.metrics import confusion_matrix

matrix = confusion_matrix(y_test, y_pred)
print(matrix)

计算得到的混淆矩阵如下所示:

[[16  0  0]
 [ 0 18  1]
 [ 0  0 10]]

矩阵中的行表示真实值,列表示预测结果。例如,左上角的元素16表示真实值为0且模型预测为0的数量,即TN;中间的元素1表示真实值为1但是模型将其预测为了2的数量,即FN。我们可以通过混淆矩阵的元素值来计算模型的性能指标。

比如,计算准确率:

from sklearn.metrics import accuracy_score

acc = accuracy_score(y_test, y_pred)
print(acc)

计算得到的准确率为0.98,即模型的预测结果准确率为98%。

同理,计算召回率:

from sklearn.metrics import recall_score

rec = recall_score(y_test, y_pred, average='macro')
print(rec)

由于是多分类问题,所以召回率的计算需要指定average='macro'参数,通过计算得到的召回率为0.98。

示例2:自行车租赁量预测

自行车租赁问题是一个回归问题,我们可以使用Random Forest模型对数据进行建模,预测出租赁量,并使用混淆矩阵进行模型性能评估。

首先,我们从sklearn库中读取自行车租赁数据集:

from sklearn.datasets import load_boston

data = load_boston()
X = data.data
y = data.target

然后将数据集分割成训练集和测试集:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

接着,我们使用随机森林回归模型对数据进行训练和预测:

from sklearn.ensemble import RandomForestRegressor

model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

预测结果y_pred是一个包含了测试集中所有样本的预测值的列表。接下来,我们将预测结果和真实值进行比较并计算混淆矩阵:

threshold = 20
y_test_class = [1 if i > threshold else 0 for i in y_test]
y_pred_class = [1 if i > threshold else 0 for i in y_pred]

matrix = confusion_matrix(y_test_class, y_pred_class)
print(matrix)

在这里,我们根据租赁量是否超过阈值来将回归问题转换为分类问题。设定阈值threshold=20,如果预测的租赁量大于20,则将其设为1,否则设为0。然后使用上述方法计算混淆矩阵。

计算得到的混淆矩阵如下:

[[47 13]
 [ 2  6]]

矩阵中的行表示真实值,列表示预测结果。例如,左上角的元素47表示真实值为0且模型预测为0的数量,即TN;右下角的元素6表示真实值为1且模型预测为1的数量,即TP。我们可以通过混淆矩阵的元素值来计算模型的性能指标。

比如,计算准确率:

from sklearn.metrics import accuracy_score

acc = accuracy_score(y_test_class, y_pred_class)
print(acc)

计算得到的准确率为0.77,即模型的预测结果准确率为77%。

同理,计算召回率:

from sklearn.metrics import recall_score

rec = recall_score(y_test_class, y_pred_class)
print(rec)

计算得到的召回率为0.75,即模型正确预测出的租赁量超过阈值的样本数占所有租赁量超过阈值的样本数的比例为75%。

总结

通过本文的讲解,我们了解了混淆矩阵的概念及其计算方法,以及如何使用sklearn库中的函数计算混淆矩阵和性能指标。混淆矩阵是评估模型性能的重要工具之一,在模型开发过程中大有用处。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python sklearn预测评估指标混淆矩阵计算示例详解 - Python技术站

(0)
上一篇 2023年6月5日
下一篇 2023年6月5日

相关文章

  • 跟老齐学Python之啰嗦的除法

    在Python中,除法运算符/的结果可能会出现小数,这是因为Python默认使用浮点数进行除法运算。但是在某些情况下,我们需要使用整数进行除法运算,这时候就需要使用Python中的整除运算符//。 下面是“跟老齐学Python之啰嗦的除法”的完整攻略: 1. Python中的除法运算符 在Python中,除法运算符/的结果可能会出现小数,例如: >&g…

    python 2023年5月14日
    00
  • 一劳永逸彻底解决pip install慢的办法

    下面是一份详细的攻略,希望可以帮助您解决pip install慢的问题。 解决pip install慢的办法 问题描述 在使用Python时,我们经常需要用到pip安装第三方包。但是,在某些情况下,由于网络速度慢或者其他各种原因,pip install会非常慢,甚至可能无法完成。为了解决这个问题,我们提供以下几种方法。 方法一:更换pip源 一般来说,我们使…

    python 2023年5月14日
    00
  • Python 常见的配置文件写法梳理汇总

    使用Markdown格式,以下是Python常见配置文件的写法梳理汇总完整攻略。 Python常见配置文件写法梳理汇总 1. INI 文件 INI 文件是最常用的配置文件之一,它通常被用于Windows操作系统的应用程序中。INI 文件本质上是一个键值对集合,由多个节组成,每个节下面可以有多个键值对。(示例代码见下) ; Python配置文件示例 [data…

    python 2023年6月3日
    00
  • 检查字节是否在 Python 中生成有效的 ISO 8859-15(拉丁文)

    【问题标题】:Check if bytes result in valid ISO 8859-15 (Latin) in Python检查字节是否在 Python 中生成有效的 ISO 8859-15(拉丁文) 【发布时间】:2023-04-07 07:03:01 【问题描述】: 我想测试我从文件中提取的一串字节是否产生有效的ISO-8859-15 编码文本…

    Python开发 2023年4月8日
    00
  • python多线程共享变量的使用和效率方法

    关于“python多线程共享变量的使用和效率方法”的完整攻略,我们可以分为以下几个方面进行讲解: 1. 多线程共享变量的基本概念 在Python多线程编程中,当多个线程同时访问同一个变量时,就需要考虑多线程共享变量的问题。多线程共享变量是一个非常重要的问题,因为不正确的共享变量会导致程序出现竞态条件,从而导致程序出现不可预料的错误。 多线程共享变量的基本概念…

    python 2023年5月18日
    00
  • Python实现把json格式转换成文本或sql文件

    要把Json格式转换成文本或Sql文件,可以通过Python自带的json库来实现。 1. Json转文本 将Json格式转换成文本,主要是通过序列化Json数据为Python的字符串格式,然后再将字符串输出到文件中,代码如下: import json # 读取Json文件中的数据 with open(‘data.json’) as f: data = js…

    python 2023年6月3日
    00
  • 解决pip install 卡住不动的问题

    使用pip安装Python包时,有时候会遇到卡住不动的情况,这可能是由于网络问题,服务器过载或其他问题引起的。以下是解决pipinstall卡住不动的问题的完整攻略: 检查网络连接:使用命令行或通过浏览器访问网站,以确保网络连接正常。如果有其他人在同一网络环境中下载或上传大量数据,可能会影响pip安装过程,请等待他们完成或更换网络环境。 检查pip版本:如果…

    python 2023年5月14日
    00
  • python中常用的内置模块汇总

    让我来给你详细介绍一下Python中常用的内置模块。 什么是Python内置模块 Python内置模块是指Python语言之中自带的标准库。Python标准库是Python语言的核心组成部分,提供了诸多常用的功能模块,如IO操作、字符串处理、正则表达式、数学运算、日期时间以及网络通信等各种工具。Python内置模块可以直接导入使用,不需要额外安装其他第三方模…

    python 2023年5月30日
    00
合作推广
合作推广
分享本页
返回顶部