python实现决策树C4.5算法详解(在ID3基础上改进)

Python实现决策树C4.5算法详解(在ID3基础上改进)

决策树是一种常见的机器学习算法,它可以用于分类和回归问题。C4.5算法是一种基于信息增益比的决策树算法,它在ID3算法的基础上进行了改进,可以处理连续属性和缺失值。在本文中,我们将介绍如何使用Python实现C4.5算法,并详细讲解实现原理。

实现原理

C4.5算法的实现原理比较复杂,我们可以分为以下几个步骤:

  1. 首先定义一个名为Node的类,用于表示决策树的节点。每个节点包含一个属性名、一个属性值、一个子节点列表和一个类别标签。
  2. 然后定义一个名为C45DecisionTree的类,用于表示C4.5算法的决策树。C45DecisionTree类包含一个根节点、一个属性列表和一个类别列表。
  3. 在C45DecisionTree类中,首先定义一个名为build_tree的方法,用于构建决策树。在build_tree方法中,我们首先判断当前节点是否为叶子节点,如果是,返回当前节点的类别标签。否则,选择一个最优的属性,将当前节点分裂成多个子节点,递归调用build_tree方法,构建子树。
  4. 然后定义一个名为choose_best_feature的方法,用于选择最优的属性。在choose_best_feature方法中,我们首先计算每个属性的信息增益比,选择信息增益比最大的属性作为最优属性。
  5. 最后定义一个名为split_data的方法,用于将数据集按照指定属性的值进行划分。在split_data方法中,我们首先判断属性是否为连续值,如果是,将数据集按照属性值进行排序,然后选择最优的划分点,将数据集划分成两个子集。如果属性不是连续值,将数据集按照属性值进行划分,将每个属性值对应的数据集存储在一个字典中。

Python实现

下面是一个使用Python实现C4.5算法的示例:

import math
import pandas as pd

class Node:
    def __init__(self, attr_name=None, attr_value=None, children=None, label=None):
        self.attr_name = attr_name
        self.attr_value = attr_value
        self.children = children or {}
        self.label = label

class C45DecisionTree:
    def __init__(self, data, class_name):
        self.root = None
        self.attr_list = list(data.columns)
        self.attr_list.remove(class_name)
        self.class_list = data[class_name].unique().tolist()
        self.data = data
        self.class_name = class_name

    def build_tree(self, data):
        labels = data[self.class_name].tolist()
        if len(set(labels)) == 1:
            return Node(label=labels[0])
        if len(data) == 0:
            return Node(label=self.majority_class())
        best_attr = self.choose_best_feature(data)
        node = Node(attr_name=best_attr)
        if self.is_continuous(best_attr):
            split_value = self.choose_best_split(data, best_attr)
            left_data = data[data[best_attr] <= split_value]
            right_data = data[data[best_attr] > split_value]
            node.attr_value = split_value
            node.children['<='] = self.build_tree(left_data)
            node.children['>'] = self.build_tree(right_data)
        else:
            for attr_value, sub_data in data.groupby(best_attr):
                node.children[attr_value] = self.build_tree(sub_data.drop(best_attr, axis=1))
        return node

    def choose_best_feature(self, data):
        entropy = self.entropy(data)
        max_gain_ratio = 0
        best_attr = None
        for attr in self.attr_list:
            if self.is_continuous(attr):
                gain_ratio, split_value = self.continuous_gain_ratio(data, attr, entropy)
                if gain_ratio > max_gain_ratio:
                    max_gain_ratio = gain_ratio
                    best_attr = attr
                    best_split_value = split_value
            else:
                gain_ratio = self.discrete_gain_ratio(data, attr, entropy)
                if gain_ratio > max_gain_ratio:
                    max_gain_ratio = gain_ratio
                    best_attr = attr
        if self.is_continuous(best_attr):
            return best_attr, best_split_value
        else:
            return best_attr

    def split_data(self, data, attr):
        if self.is_continuous(attr):
            split_value = self.choose_best_split(data, attr)
            left_data = data[data[attr] <= split_value]
            right_data = data[data[attr] > split_value]
            return {'<=': left_data, '>': right_data}
        else:
            return dict(tuple(data.groupby(attr)))

    def entropy(self, data):
        labels = data[self.class_name].tolist()
        label_count = {}
        for label in labels:
            if label not in label_count:
                label_count[label] = 0
            label_count[label] += 1
        entropy = 0
        for label in label_count:
            prob = label_count[label] / len(labels)
            entropy -= prob * math.log(prob, 2)
        return entropy

    def majority_class(self):
        labels = self.data[self.class_name].tolist()
        label_count = {}
        for label in labels:
            if label not in label_count:
                label_count[label] = 0
            label_count[label] += 1
        majority_label = None
        max_count = 0
        for label in label_count:
            if label_count[label] > max_count:
                max_count = label_count[label]
                majority_label = label
        return majority_label

    def is_continuous(self, attr):
        return self.data[attr].dtype == 'float64'

    def choose_best_split(self, data, attr):
        values = data[attr].tolist()
        values.sort()
        split_points = [(values[i] + values[i+1]) / 2 for i in range(len(values)-1)]
        max_gain_ratio = 0
        best_split_value = None
        for split_value in split_points:
            left_data = data[data[attr] <= split_value]
            right_data = data[data[attr] > split_value]
            gain_ratio = self.continuous_gain_ratio(data, attr, self.entropy(data), split_value)
            if gain_ratio > max_gain_ratio:
                max_gain_ratio = gain_ratio
                best_split_value = split_value
        return best_split_value

    def continuous_gain_ratio(self, data, attr, entropy, split_value=None):
        if split_value is None:
            split_value = self.choose_best_split(data, attr)
        left_data = data[data[attr] <= split_value]
        right_data = data[data[attr] > split_value]
        left_entropy = self.entropy(left_data)
        right_entropy = self.entropy(right_data)
        split_entropy = (len(left_data) / len(data)) * left_entropy + (len(right_data) / len(data)) * right_entropy
        gain = entropy - split_entropy
        split_info = - (len(left_data) / len(data)) * math.log(len(left_data) / len(data), 2) - (len(right_data) / len(data)) * math.log(len(right_data) / len(data), 2)
        if split_info == 0:
            return 0, split_value
        gain_ratio = gain / split_info
        return gain_ratio, split_value

    def discrete_gain_ratio(self, data, attr, entropy):
        sub_data_dict = self.split_data(data, attr)
        split_entropy = 0
        split_info = 0
        for attr_value in sub_data_dict:
            sub_data = sub_data_dict[attr_value]
            prob = len(sub_data) / len(data)
            split_entropy -= prob * self.entropy(sub_data)
            split_info -= prob * math.log(prob, 2)
        gain = entropy - split_entropy
        if split_info == 0:
            return 0
        gain_ratio = gain / split_info
        return gain_ratio

    def predict(self, data):
        node = self.root
        while node.children:
            attr_name = node.attr_name
            if self.is_continuous(attr_name):
                if data[attr_name] <= node.attr_value:
                    node = node.children['<=']
                else:
                    node = node.children['>']
            else:
                attr_value = data[attr_name]
                if attr_value in node.children:
                    node = node.children[attr_value]
                else:
                    return self.majority_class()
        return node.label

    def fit(self):
        self.root = self.build_tree(self.data)

data = pd.read_csv('data.csv')
tree = C45DecisionTree(data, 'class')
tree.fit()
print(tree.predict({'age': 30, 'income': 50000, 'student': 'no', 'credit_rating': 'fair'}))

在这个示例中,我们首先导入了pandas和math模块。然后定义了一个名为Node的类,用于表示决策树的节点。每个节点包含一个属性名、一个属性值、一个子节点列表和一个类别标签。然后定义了一个名为C45DecisionTree的类,用于表示C4.5算法的决策树。C45DecisionTree类包含一个根节点、一个属性列表和一个类别列表。在C45DecisionTree类中,我们定义了build_tree、choose_best_feature、split_data、entropy、majority_class、is_continuous、choose_best_split、continuous_gain_ratio、discrete_gain_ratio和predict等方法,用于构建决策树、选择最优的属性、划分数据集、计算熵、计算多数类、判断属性是否为连续值、选择最优的划分点、计算连续属性的信息增益比、计算离散属性的信息增益比和预测数据的类别标签。最后,我们读取了一个名为data.csv的数据集,使用C45DecisionTree类构建决策树,并预测了一个新的数据的类别标签。

示例1:使用C4.5算法预测鸢尾花的类别

在这个示例中,我们将使用C4.5算法预测鸢尾花的类别。我们首先读取一个名为iris.csv的数据集,然后使用C45DecisionTree类构建决策树,并预测测试数据的类别标签。

import pandas as pd

data = pd.read_csv('iris.csv')
train_data = data.iloc[:120]
test_data = data.iloc[120:]
tree = C45DecisionTree(train_data, 'class')
tree.fit()
correct_count = 0
for i in range(len(test_data)):
    row = test_data.iloc[i]
    label = tree.predict(row.drop('class'))
    if label == row['class']:
        correct_count += 1
accuracy = correct_count / len(test_data)
print('Accuracy:', accuracy)

在这个示例中,我们首先读取了一个名为iris.csv的数据集,然后将前120个样本作为训练集,后30个样本作为测试集。然后使用C45DecisionTree类构建决策树,并预测测试数据的类别标签。最后计算预测准确率。

示例2:使用C4.5算法预测泰坦尼克号乘客的生还情况

在这个示例中,我们将使用C4.5算法预测泰坦尼克号乘客的生还情况。我们首先读取一个名为titanic.csv的数据集,然后使用pandas模块对数据进行预处理,将缺失值填充为平均值或众数。然后使用C45DecisionTree类构建决策树,并预测测试数据的类别标签。

import pandas as pd

data = pd.read_csv('titanic.csv')
data = data.drop(['PassengerId', 'Name', 'Ticket', 'Cabin'], axis=1)
data['Age'] = data['Age'].fillna(data['Age'].mean())
data['Embarked'] = data['Embarked'].fillna(data['Embarked'].mode()[0])
train_data = data.iloc[:800]
test_data = data.iloc[800:]
tree = C45DecisionTree(train_data, 'Survived')
tree.fit()
correct_count = 0
for i in range(len(test_data)):
    row = test_data.iloc[i]
    label = tree.predict(row.drop('Survived'))
    if label == row['Survived']:
        correct_count += 1
accuracy = correct_count / len(test_data)
print('Accuracy:', accuracy)

在这个示例中,我们首先读取了一个名为titanic.csv的数据集,然后使用pandas模块对数据进行预处理,将缺失值填充为平均值或众数。然后将前800个样本作为训练集,后91个样本作为测试集。然后使用C45DecisionTree类构建决策树,并预测测试数据的类别标签。最后计算预测准确率。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:python实现决策树C4.5算法详解(在ID3基础上改进) - Python技术站

(1)
上一篇 2023年5月14日
下一篇 2023年5月14日

相关文章

  • python实现selenium网络爬虫的方法小结

    Python实现Selenium网络爬虫的方法小结 什么是Selenium? Selenium是一个自动化测试工具,通过模拟真实的用户操作,例如点击、输入等,与网站进行交互,获取所需数据。 安装Selenium 在Python中安装Selenium很简单,使用pip命令安装即可: pip install selenium 下载并配置浏览器驱动 Seleniu…

    python 2023年5月14日
    00
  • Python中format()格式输出全解

    Python中format()格式输出全解 在Python中,我们可以使用format()方法对字符串进行格式化输出。使用format()方法可以方便地为字符串添加变量、数字等内容,使输出的字符串更加灵活和具有可读性。 语法 Python中format()方法的语法如下所示: string.format(args) 其中,string是要格式化输出的字符串,…

    python 2023年6月5日
    00
  • Python3爬虫学习之爬虫利器Beautiful Soup用法分析

    Python3爬虫学习之爬虫利器Beautiful Soup用法分析 介绍 在Python3中,爬虫领域有许多实用的工具,而Beautiful Soup就是其中一款非常常用的解析库。 环境配置 在使用Beautiful Soup之前,需要先安装: pip install beautifulsoup4 基本语法 在使用Beautiful Soup解析网页前,需…

    python 2023年5月14日
    00
  • Python函数关键字参数详解

    在Python函数中,关键字参数是一种通过参数名称传递值的方法,而不是按照参数在函数定义中的顺序进行传递。使用关键字参数可以使代码更具可读性,并且可以方便地忽略函数定义中的一些参数。以下是Python函数关键字参数的用法: 定义函数时使用关键字参数 在定义函数时,可以使用关键字参数来指定函数参数的默认值。这样,在调用函数时,如果没有传递参数,则使用默认值。例…

    2023年2月20日
    00
  • 修复CentOS7升级Python到3.6版本后yum不能正确使用的解决方法

    下面是修复 CentOS 7 升级 Python 到 3.6 版本后 yum 不能正确使用的解决方法的攻略过程: 问题描述 当我们在 CentOS 7 系统中升级 Python 版本到 3.6 之后,会出现 yum 不能正确使用的问题,报错信息如下: [root@centos7 ~]# yum Traceback (most recent call last…

    python 2023年5月13日
    00
  • Python 文件读写操作实例详解

    首先,我们来介绍一下Python文件读写操作中常用的函数: open(file, mode=’r’, buffering=-1, encoding=None, errors=None, newline=None, closefd=True, opener=None):打开一个文件并返回文件对象。其中参数file表示文件名(包含路径),mode表示打开文件的模…

    python 2023年5月19日
    00
  • Python模块包中__init__.py文件功能分析

    当我们创建 Python 模块包时,我们经常会创建一个名为 __init__.py 的文件,但是大多数时候,我们可能没有意识到这个文件的作用。在本文中,我将详细讲解 __init__.py 文件在 Python 模块包中的功能分析。 什么是 init.py 文件 __init__.py 是一个特殊的文件名,它告诉 Python 解释器该目录应当视为一个 Py…

    python 2023年6月6日
    00
  • python修改文件内容的3种方法详解

    Python修改文件内容的3种方法详解 在Python编程过程中,我们经常需要修改文件内容。本文将介绍Python中三种常见的修改文件内容的方法。 方法一:将整个文件读入内存,修改后再写入文件 with open(‘file.txt’, ‘r’) as f: lines = f.readlines() with open(‘file.txt’, ‘w’) a…

    python 2023年6月3日
    00
合作推广
合作推广
分享本页
返回顶部