python实现决策树C4.5算法详解(在ID3基础上改进)

Python实现决策树C4.5算法详解(在ID3基础上改进)

决策树是一种常见的机器学习算法,它可以用于分类和回归问题。C4.5算法是一种基于信息增益比的决策树算法,它在ID3算法的基础上进行了改进,可以处理连续属性和缺失值。在本文中,我们将介绍如何使用Python实现C4.5算法,并详细讲解实现原理。

实现原理

C4.5算法的实现原理比较复杂,我们可以分为以下几个步骤:

  1. 首先定义一个名为Node的类,用于表示决策树的节点。每个节点包含一个属性名、一个属性值、一个子节点列表和一个类别标签。
  2. 然后定义一个名为C45DecisionTree的类,用于表示C4.5算法的决策树。C45DecisionTree类包含一个根节点、一个属性列表和一个类别列表。
  3. 在C45DecisionTree类中,首先定义一个名为build_tree的方法,用于构建决策树。在build_tree方法中,我们首先判断当前节点是否为叶子节点,如果是,返回当前节点的类别标签。否则,选择一个最优的属性,将当前节点分裂成多个子节点,递归调用build_tree方法,构建子树。
  4. 然后定义一个名为choose_best_feature的方法,用于选择最优的属性。在choose_best_feature方法中,我们首先计算每个属性的信息增益比,选择信息增益比最大的属性作为最优属性。
  5. 最后定义一个名为split_data的方法,用于将数据集按照指定属性的值进行划分。在split_data方法中,我们首先判断属性是否为连续值,如果是,将数据集按照属性值进行排序,然后选择最优的划分点,将数据集划分成两个子集。如果属性不是连续值,将数据集按照属性值进行划分,将每个属性值对应的数据集存储在一个字典中。

Python实现

下面是一个使用Python实现C4.5算法的示例:

import math
import pandas as pd

class Node:
    def __init__(self, attr_name=None, attr_value=None, children=None, label=None):
        self.attr_name = attr_name
        self.attr_value = attr_value
        self.children = children or {}
        self.label = label

class C45DecisionTree:
    def __init__(self, data, class_name):
        self.root = None
        self.attr_list = list(data.columns)
        self.attr_list.remove(class_name)
        self.class_list = data[class_name].unique().tolist()
        self.data = data
        self.class_name = class_name

    def build_tree(self, data):
        labels = data[self.class_name].tolist()
        if len(set(labels)) == 1:
            return Node(label=labels[0])
        if len(data) == 0:
            return Node(label=self.majority_class())
        best_attr = self.choose_best_feature(data)
        node = Node(attr_name=best_attr)
        if self.is_continuous(best_attr):
            split_value = self.choose_best_split(data, best_attr)
            left_data = data[data[best_attr] <= split_value]
            right_data = data[data[best_attr] > split_value]
            node.attr_value = split_value
            node.children['<='] = self.build_tree(left_data)
            node.children['>'] = self.build_tree(right_data)
        else:
            for attr_value, sub_data in data.groupby(best_attr):
                node.children[attr_value] = self.build_tree(sub_data.drop(best_attr, axis=1))
        return node

    def choose_best_feature(self, data):
        entropy = self.entropy(data)
        max_gain_ratio = 0
        best_attr = None
        for attr in self.attr_list:
            if self.is_continuous(attr):
                gain_ratio, split_value = self.continuous_gain_ratio(data, attr, entropy)
                if gain_ratio > max_gain_ratio:
                    max_gain_ratio = gain_ratio
                    best_attr = attr
                    best_split_value = split_value
            else:
                gain_ratio = self.discrete_gain_ratio(data, attr, entropy)
                if gain_ratio > max_gain_ratio:
                    max_gain_ratio = gain_ratio
                    best_attr = attr
        if self.is_continuous(best_attr):
            return best_attr, best_split_value
        else:
            return best_attr

    def split_data(self, data, attr):
        if self.is_continuous(attr):
            split_value = self.choose_best_split(data, attr)
            left_data = data[data[attr] <= split_value]
            right_data = data[data[attr] > split_value]
            return {'<=': left_data, '>': right_data}
        else:
            return dict(tuple(data.groupby(attr)))

    def entropy(self, data):
        labels = data[self.class_name].tolist()
        label_count = {}
        for label in labels:
            if label not in label_count:
                label_count[label] = 0
            label_count[label] += 1
        entropy = 0
        for label in label_count:
            prob = label_count[label] / len(labels)
            entropy -= prob * math.log(prob, 2)
        return entropy

    def majority_class(self):
        labels = self.data[self.class_name].tolist()
        label_count = {}
        for label in labels:
            if label not in label_count:
                label_count[label] = 0
            label_count[label] += 1
        majority_label = None
        max_count = 0
        for label in label_count:
            if label_count[label] > max_count:
                max_count = label_count[label]
                majority_label = label
        return majority_label

    def is_continuous(self, attr):
        return self.data[attr].dtype == 'float64'

    def choose_best_split(self, data, attr):
        values = data[attr].tolist()
        values.sort()
        split_points = [(values[i] + values[i+1]) / 2 for i in range(len(values)-1)]
        max_gain_ratio = 0
        best_split_value = None
        for split_value in split_points:
            left_data = data[data[attr] <= split_value]
            right_data = data[data[attr] > split_value]
            gain_ratio = self.continuous_gain_ratio(data, attr, self.entropy(data), split_value)
            if gain_ratio > max_gain_ratio:
                max_gain_ratio = gain_ratio
                best_split_value = split_value
        return best_split_value

    def continuous_gain_ratio(self, data, attr, entropy, split_value=None):
        if split_value is None:
            split_value = self.choose_best_split(data, attr)
        left_data = data[data[attr] <= split_value]
        right_data = data[data[attr] > split_value]
        left_entropy = self.entropy(left_data)
        right_entropy = self.entropy(right_data)
        split_entropy = (len(left_data) / len(data)) * left_entropy + (len(right_data) / len(data)) * right_entropy
        gain = entropy - split_entropy
        split_info = - (len(left_data) / len(data)) * math.log(len(left_data) / len(data), 2) - (len(right_data) / len(data)) * math.log(len(right_data) / len(data), 2)
        if split_info == 0:
            return 0, split_value
        gain_ratio = gain / split_info
        return gain_ratio, split_value

    def discrete_gain_ratio(self, data, attr, entropy):
        sub_data_dict = self.split_data(data, attr)
        split_entropy = 0
        split_info = 0
        for attr_value in sub_data_dict:
            sub_data = sub_data_dict[attr_value]
            prob = len(sub_data) / len(data)
            split_entropy -= prob * self.entropy(sub_data)
            split_info -= prob * math.log(prob, 2)
        gain = entropy - split_entropy
        if split_info == 0:
            return 0
        gain_ratio = gain / split_info
        return gain_ratio

    def predict(self, data):
        node = self.root
        while node.children:
            attr_name = node.attr_name
            if self.is_continuous(attr_name):
                if data[attr_name] <= node.attr_value:
                    node = node.children['<=']
                else:
                    node = node.children['>']
            else:
                attr_value = data[attr_name]
                if attr_value in node.children:
                    node = node.children[attr_value]
                else:
                    return self.majority_class()
        return node.label

    def fit(self):
        self.root = self.build_tree(self.data)

data = pd.read_csv('data.csv')
tree = C45DecisionTree(data, 'class')
tree.fit()
print(tree.predict({'age': 30, 'income': 50000, 'student': 'no', 'credit_rating': 'fair'}))

在这个示例中,我们首先导入了pandas和math模块。然后定义了一个名为Node的类,用于表示决策树的节点。每个节点包含一个属性名、一个属性值、一个子节点列表和一个类别标签。然后定义了一个名为C45DecisionTree的类,用于表示C4.5算法的决策树。C45DecisionTree类包含一个根节点、一个属性列表和一个类别列表。在C45DecisionTree类中,我们定义了build_tree、choose_best_feature、split_data、entropy、majority_class、is_continuous、choose_best_split、continuous_gain_ratio、discrete_gain_ratio和predict等方法,用于构建决策树、选择最优的属性、划分数据集、计算熵、计算多数类、判断属性是否为连续值、选择最优的划分点、计算连续属性的信息增益比、计算离散属性的信息增益比和预测数据的类别标签。最后,我们读取了一个名为data.csv的数据集,使用C45DecisionTree类构建决策树,并预测了一个新的数据的类别标签。

示例1:使用C4.5算法预测鸢尾花的类别

在这个示例中,我们将使用C4.5算法预测鸢尾花的类别。我们首先读取一个名为iris.csv的数据集,然后使用C45DecisionTree类构建决策树,并预测测试数据的类别标签。

import pandas as pd

data = pd.read_csv('iris.csv')
train_data = data.iloc[:120]
test_data = data.iloc[120:]
tree = C45DecisionTree(train_data, 'class')
tree.fit()
correct_count = 0
for i in range(len(test_data)):
    row = test_data.iloc[i]
    label = tree.predict(row.drop('class'))
    if label == row['class']:
        correct_count += 1
accuracy = correct_count / len(test_data)
print('Accuracy:', accuracy)

在这个示例中,我们首先读取了一个名为iris.csv的数据集,然后将前120个样本作为训练集,后30个样本作为测试集。然后使用C45DecisionTree类构建决策树,并预测测试数据的类别标签。最后计算预测准确率。

示例2:使用C4.5算法预测泰坦尼克号乘客的生还情况

在这个示例中,我们将使用C4.5算法预测泰坦尼克号乘客的生还情况。我们首先读取一个名为titanic.csv的数据集,然后使用pandas模块对数据进行预处理,将缺失值填充为平均值或众数。然后使用C45DecisionTree类构建决策树,并预测测试数据的类别标签。

import pandas as pd

data = pd.read_csv('titanic.csv')
data = data.drop(['PassengerId', 'Name', 'Ticket', 'Cabin'], axis=1)
data['Age'] = data['Age'].fillna(data['Age'].mean())
data['Embarked'] = data['Embarked'].fillna(data['Embarked'].mode()[0])
train_data = data.iloc[:800]
test_data = data.iloc[800:]
tree = C45DecisionTree(train_data, 'Survived')
tree.fit()
correct_count = 0
for i in range(len(test_data)):
    row = test_data.iloc[i]
    label = tree.predict(row.drop('Survived'))
    if label == row['Survived']:
        correct_count += 1
accuracy = correct_count / len(test_data)
print('Accuracy:', accuracy)

在这个示例中,我们首先读取了一个名为titanic.csv的数据集,然后使用pandas模块对数据进行预处理,将缺失值填充为平均值或众数。然后将前800个样本作为训练集,后91个样本作为测试集。然后使用C45DecisionTree类构建决策树,并预测测试数据的类别标签。最后计算预测准确率。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:python实现决策树C4.5算法详解(在ID3基础上改进) - Python技术站

(1)
上一篇 2023年5月14日
下一篇 2023年5月14日

相关文章

  • Python实现输出某区间范围内全部素数的方法

    要实现输出某区间范围内全部素数的方法,可以按照以下步骤进行: 1. 确认素数的定义 素数是指除了1和它本身以外没有其他因数的整数,比如2、3、5、7等。由此可知,在判断素数时只需要判断这个数能否被2到sqrt(num)之间的整数整除即可。如果存在能够整除的数,那么这个数就不是素数。 2. 从输入中获取区间范围 首先,需要从输入中获取待求的区间范围,即起始数值…

    python 2023年6月5日
    00
  • Python统计学一数据的概括性度量详解

    Python统计学一数据的概括性度量详解 在统计学中,我们需要使用概括性度量来描述数据的特征,这样可以让我们更好地理解数据分布和变异性。Python中有丰富的函数库来管理数据,所以也有很多可用于计算概括性度量的函数。 1. 数据的基本概括性度量 1.1 均值 均值是最常见的区分数据集中趋势的量。均值是数据点的和除以数据点的数量。 在Python中,我们可以使…

    python 2023年6月5日
    00
  • python如何快速生成时间戳

    想要快速生成时间戳,我们可以使用 Python 中的 time 模块和 datetime 模块。下面是具体步骤: 1. 导入模块 import time import datetime 2. 使用 time 模块生成时间戳 使用 time 模块中的 time() 函数可以获取当前时间的时间戳。时间戳是一个浮点数,表示自 Epoch(1970 年 1 月 1 …

    python 2023年6月2日
    00
  • 如何利用Python和matplotlib更改纵横坐标刻度颜色

    我会详细讲解如何利用Python和matplotlib更改纵横坐标刻度颜色。 准备工作 在开始说明如何更改坐标刻度颜色前,我们需要准备一些工作: 安装Python和Matplotlib:在开始之前需要确保你已经成功安装了Python和matplotlib。如果没有安装,可以前往Python官网和Matplotlib官网进行下载和安装。 导入matplotli…

    python 2023年5月18日
    00
  • Python写脚本常用模块OS基础用法详解

    Python写脚本常用模块OS基础用法详解 随着Python在日常工作中的应用越来越广泛,越来越多的人开始使用Python来编写脚本进行自动化操作。而在编写Python脚本的过程中,常常会用到OS模块。本篇攻略将详细讲解Python中OS模块的基础用法。 OS模块的基本介绍 Python中的OS模块是一个用来访问操作系统服务的模块,它提供了许多访问操作系统底…

    python 2023年5月31日
    00
  • jenkins+python自动化测试持续集成教程

    以下是“Jenkins+Python自动化测试持续集成教程”的完整攻略: 什么是Jenkins? Jenkins是一款非常流行的开源自动化部署工具,它可以自动编译、测试和部署软件项目。 什么是Python自动化测试? Python自动化测试是使用Python语言编写的自动化测试脚本,可以自动完成软件测试过程。 Jenkins+Python自动化测试持续集成流…

    python 2023年6月6日
    00
  • Python爬虫实战之使用Scrapy爬取豆瓣图片

    下面我将为您详细讲解“Python爬虫实战之使用Scrapy爬取豆瓣图片”的完整攻略,包括如何使用Scrapy在豆瓣网站上爬取图片。 Scrapy爬虫实战:使用Scrapy爬取豆瓣图片 本次爬虫实战使用的主要工具是Scrapy框架,Scrapy是一个用于爬取网站数据的高级Python框架,它使用了Twisted异步网络框架来处理网络通讯,在性能上有着不错的表…

    python 2023年5月14日
    00
  • Python高效处理大文件的方法详解

    Python高效处理大文件的方法详解 处理大文件是Python程序中常见的任务之一。在处理大文件时,需要注意内存使用情况,以避免程序运行过程中出现内存溢出等问题。下面介绍一些Python高效处理大文件的方法。 读取大文件 读取大文件时,可以使用Python自带的文件读取方法。但是,如果一次读入整个文件,会占用大量的内存,因此需要一行一行地读取文件内容。下面是…

    python 2023年6月3日
    00
合作推广
合作推广
分享本页
返回顶部