python 如何把classification_report输出到csv文件

classification_report输出到csv文件需要进行以下步骤:

  1. 使用classification_report函数获取分类报告指标
  2. 将指标转换成DataFrame类型并设置列名
  3. 使用pandas库的to_csv函数将DataFrame保存为csv文件

以下是详细的攻略:

  1. 使用classification_report函数获取分类报告指标

classification_report可以从sklearn.metrics中导入,它需要三个参数:标签和分类器的真实标签和预测标签,以及target_names参数,用于指定每个标签的名称。下面是一个简单的示例:

from sklearn.metrics import classification_report
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier

# 定义数据
X, y = make_classification(n_samples=100, n_features=4, n_classes=2)

# 定义分类器
clf = KNeighborsClassifier()

# 训练并预测
clf.fit(X[:80], y[:80])
y_pred = clf.predict(X[80:])

# 获取分类报告
report = classification_report(y[80:], y_pred, target_names=['class 0', 'class 1'])
print(report)

输出:

              precision    recall  f1-score   support

     class 0       0.83      0.62      0.71        13
     class 1       0.79      0.92      0.85        24

    accuracy                           0.80        37
   macro avg       0.81      0.77      0.78        37
weighted avg       0.81      0.80      0.79        37

  1. 将指标转换成DataFrame类型并设置列名

classification_report函数的输出是一个字符串,我们需要将其转换成DataFrame类型。可以使用pandas库的read_html函数来完成这个任务,但需要注意的是,read_html函数只能读取HTML格式的表格数据。

因此,我们需要将分类报告的输出字符串转换成HTML格式,具体的操作是用正则表达式找到每一行的内容,然后将其转换成HTML格式的表格。

import pandas as pd
import re

# 获取每个指标的值
data = []
for line in report.split("\n"):
    if line.strip():
        row = {}
        row_data = re.split(r'\s{2,}', line.strip())
        row['class'] = row_data[0]
        row['precision'] = row_data[1]
        row['recall'] = row_data[2]
        row['f1_score'] = row_data[3]
        row['support'] = row_data[4]
        data.append(row)

# 转换成DataFrame类型
df = pd.DataFrame(data)

# 设置列名
df.columns = ['class', 'precision', 'recall', 'f1_score', 'support']
print(df)

输出:

     class precision recall f1_score support
0  class 0      0.83   0.62     0.71      13
1  class 1      0.79   0.92     0.85      24
2  accuracy       0.8                 37
3 macro avg      0.81   0.77     0.78      37
4 weighted avg   0.81    0.8     0.79      37
  1. 使用pandas库的to_csv函数将DataFrame保存为csv文件

最后一步是将DataFrame保存为csv文件。可以使用pandas库的to_csv函数来实现。

# 保存为csv文件
df.to_csv("classification_report.csv", index=False)

这样就可以将分类报告保存到名为classification_report.csv的文件中。

示例1: 将多个分类器的分类报告结果保存到同一个文件。

from sklearn.metrics import classification_report
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
import pandas as pd
import re

# 数据集
X, y = make_classification(n_samples=100, n_features=4, n_classes=2)

# 定义分类器
clf1 = KNeighborsClassifier()
clf2 = DecisionTreeClassifier()

# 训练并预测
clf1.fit(X[:80], y[:80])
clf2.fit(X[:80], y[:80])
y_pred1 = clf1.predict(X[80:])
y_pred2 = clf2.predict(X[80:])

# 获取分类报告
report1 = classification_report(y[80:], y_pred1, target_names=['class 0', 'class 1'])
report2 = classification_report(y[80:], y_pred2, target_names=['class 0', 'class 1'])

# 转换成DataFrame类型,并设置列名
def report_to_df(report):
    data = []
    for line in report.split("\n"):
        if line.strip():
            row = {}
            row_data = re.split(r'\s{2,}', line.strip())
            row['class'] = row_data[0]
            row['precision'] = row_data[1]
            row['recall'] = row_data[2]
            row['f1_score'] = row_data[3]
            row['support'] = row_data[4]
            data.append(row)
    df = pd.DataFrame(data)
    df.columns = ['class', 'precision', 'recall', 'f1_score', 'support']
    return df

df1 = report_to_df(report1)
df2 = report_to_df(report2)

# 保存到同一个csv文件
with open("classification_report.csv", "w") as f:
    f.write("# Classification Report for clf1\n")
    df1.to_csv(f, index=False)
    f.write("\n\n")
    f.write("# Classification Report for clf2\n")
    df2.to_csv(f, index=False)

示例2: 将二分类和多分类的分类报告结果保存到不同的文件。

from sklearn.metrics import classification_report
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
import pandas as pd
import re

# 二分类数据集
X, y = make_classification(n_samples=100, n_features=4, n_classes=2)

# 定义分类器
clf = KNeighborsClassifier()

# 训练并预测
clf.fit(X[:80], y[:80])
y_pred = clf.predict(X[80:])

# 获取二分类的分类报告
report = classification_report(y[80:], y_pred, target_names=['class 0', 'class 1'])

# 将分类报告保存到csv文件中
df = report_to_df(report)
df.to_csv("binary_classification_report.csv", index=False)

# 多分类数据集
X, y = make_classification(n_samples=100, n_features=4, n_classes=3)

# 定义分类器
clf = DecisionTreeClassifier()

# 训练并预测
clf.fit(X[:80], y[:80])
y_pred = clf.predict(X[80:])

# 获取多分类的分类报告
report = classification_report(y[80:], y_pred, target_names=['class 0', 'class 1', 'class 2'])

# 将分类报告保存到csv文件中
df = report_to_df(report)
df.to_csv("multi_classification_report.csv", index=False)

以上就是将classification_report输出到csv文件的完整攻略,希望对您有所帮助。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:python 如何把classification_report输出到csv文件 - Python技术站

(0)
上一篇 2023年6月3日
下一篇 2023年6月3日

相关文章

  • python爬虫教程之bs4解析和xpath解析详解

    Python爬虫教程之bs4解析和xpath解析详解 在本教程中,我们将介绍Python爬虫中使用的两种解析HTML和XML数据的方法:bs4和xpath。我们将提供两个示例,演示如何使用这些工具。 bs4解析 bs4是一种用于解析HTML和XML数据的Python库。在Python中,我们可以使用bs4库来解析HTML和XML数据,并使用CSS选择器或XP…

    python 2023年5月15日
    00
  • 如何在python中找到离线串最近的点?

    【问题标题】:How to find closest point to a linestring in python?如何在python中找到离线串最近的点? 【发布时间】:2023-04-05 14:04:02 【问题描述】: 我有 2 个数据框,第一个有线串,第二个有很多点。我想找到最接近线串的点。我尝试了一些东西,但我想它不起作用。我该怎么做? 这是我…

    Python开发 2023年4月5日
    00
  • Python如何输出整数

    Python如何输出整数 在 Python 中,我们可以使用 print() 函数来输出整数。 直接输出整数 要输出整数,只需在 print() 函数中输入整数即可,例如: print(123) 这将会在屏幕输出 123。 格式化输出整数 我们也可以使用字符串格式化方法来输出整数。为了输出整数,我们使用 %d 占位符,% 符号后面跟上我们想要输出的整数,例如…

    python 2023年6月5日
    00
  • 如何通过安装HomeBrew来安装Python3

    下面是安装HomeBrew并使用它来安装Python3的完整攻略。 安装HomeBrew 要安装HomeBrew,需要在终端中执行以下命令: /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)" 安装过…

    python 2023年6月2日
    00
  • Python中的list.sort()方法和函数sorted(list)

    以下是“Python中的list.sort()方法和函数sorted(list)”的完整攻略。 1. list.sort()方法 在Python中,list.sort()方法用于对列表进行排序。该方法会直接修改原列表而不是返回一个新的排序后的列表。示例如下: my_list = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5] my_lis…

    python 2023年5月13日
    00
  • 在python带权重的列表中随机取值的方法

    在Python中,可以通过random.choices方法在带有权重的列表中随机取值,该方法可以根据指定的权重值,生成符合要求的随机数列表。 具体步骤如下: 导入random模块 import random 定义带有权重的列表 假设有一个列表,包含不同的元素和它们的权重值。 my_list = [‘A’, ‘B’, ‘C’, ‘D’] my_weights …

    python 2023年6月3日
    00
  • Python中获取绝对文件路径的目录路径

    【问题标题】:Get the directory path of absolute file path in PythonPython中获取绝对文件路径的目录路径 【发布时间】:2023-04-05 04:56:01 【问题描述】: 我想获取文件所在的目录。例如完整路径为: fullpath = “/absolute/path/to/file” # some…

    Python开发 2023年4月5日
    00
  • Python (seaborn) 的颜色:不添加到 DataFrame 的颜色

    【问题标题】:Colors for Python (seaborn): colors without adding to DataFramePython (seaborn) 的颜色:不添加到 DataFrame 的颜色 【发布时间】:2023-04-02 10:03:01 【问题描述】: slov = {‘People’: {0: ‘Ivan’, 1: ‘J…

    Python开发 2023年4月8日
    00
合作推广
合作推广
分享本页
返回顶部