python实现决策树C4.5算法详解(在ID3基础上改进)

yizhihongxing

Python实现决策树C4.5算法详解(在ID3基础上改进)

决策树是一种常见的机器学习算法,它可以用于分类和回归问题。C4.5算法是一种基于信息增益比的决策树算法,它在ID3算法的基础上进行了改进,可以处理连续属性和缺失值。在本文中,我们将介绍如何使用Python实现C4.5算法,并详细讲解实现原理。

实现原理

C4.5算法的实现原理比较复杂,我们可以分为以下几个步骤:

  1. 首先定义一个名为Node的类,用于表示决策树的节点。每个节点包含一个属性名、一个属性值、一个子节点列表和一个类别标签。
  2. 然后定义一个名为C45DecisionTree的类,用于表示C4.5算法的决策树。C45DecisionTree类包含一个根节点、一个属性列表和一个类别列表。
  3. 在C45DecisionTree类中,首先定义一个名为build_tree的方法,用于构建决策树。在build_tree方法中,我们首先判断当前节点是否为叶子节点,如果是,返回当前节点的类别标签。否则,选择一个最优的属性,将当前节点分裂成多个子节点,递归调用build_tree方法,构建子树。
  4. 然后定义一个名为choose_best_feature的方法,用于选择最优的属性。在choose_best_feature方法中,我们首先计算每个属性的信息增益比,选择信息增益比最大的属性作为最优属性。
  5. 最后定义一个名为split_data的方法,用于将数据集按照指定属性的值进行划分。在split_data方法中,我们首先判断属性是否为连续值,如果是,将数据集按照属性值进行排序,然后选择最优的划分点,将数据集划分成两个子集。如果属性不是连续值,将数据集按照属性值进行划分,将每个属性值对应的数据集存储在一个字典中。

Python实现

下面是一个使用Python实现C4.5算法的示例:

import math
import pandas as pd

class Node:
    def __init__(self, attr_name=None, attr_value=None, children=None, label=None):
        self.attr_name = attr_name
        self.attr_value = attr_value
        self.children = children or {}
        self.label = label

class C45DecisionTree:
    def __init__(self, data, class_name):
        self.root = None
        self.attr_list = list(data.columns)
        self.attr_list.remove(class_name)
        self.class_list = data[class_name].unique().tolist()
        self.data = data
        self.class_name = class_name

    def build_tree(self, data):
        labels = data[self.class_name].tolist()
        if len(set(labels)) == 1:
            return Node(label=labels[0])
        if len(data) == 0:
            return Node(label=self.majority_class())
        best_attr = self.choose_best_feature(data)
        node = Node(attr_name=best_attr)
        if self.is_continuous(best_attr):
            split_value = self.choose_best_split(data, best_attr)
            left_data = data[data[best_attr] <= split_value]
            right_data = data[data[best_attr] > split_value]
            node.attr_value = split_value
            node.children['<='] = self.build_tree(left_data)
            node.children['>'] = self.build_tree(right_data)
        else:
            for attr_value, sub_data in data.groupby(best_attr):
                node.children[attr_value] = self.build_tree(sub_data.drop(best_attr, axis=1))
        return node

    def choose_best_feature(self, data):
        entropy = self.entropy(data)
        max_gain_ratio = 0
        best_attr = None
        for attr in self.attr_list:
            if self.is_continuous(attr):
                gain_ratio, split_value = self.continuous_gain_ratio(data, attr, entropy)
                if gain_ratio > max_gain_ratio:
                    max_gain_ratio = gain_ratio
                    best_attr = attr
                    best_split_value = split_value
            else:
                gain_ratio = self.discrete_gain_ratio(data, attr, entropy)
                if gain_ratio > max_gain_ratio:
                    max_gain_ratio = gain_ratio
                    best_attr = attr
        if self.is_continuous(best_attr):
            return best_attr, best_split_value
        else:
            return best_attr

    def split_data(self, data, attr):
        if self.is_continuous(attr):
            split_value = self.choose_best_split(data, attr)
            left_data = data[data[attr] <= split_value]
            right_data = data[data[attr] > split_value]
            return {'<=': left_data, '>': right_data}
        else:
            return dict(tuple(data.groupby(attr)))

    def entropy(self, data):
        labels = data[self.class_name].tolist()
        label_count = {}
        for label in labels:
            if label not in label_count:
                label_count[label] = 0
            label_count[label] += 1
        entropy = 0
        for label in label_count:
            prob = label_count[label] / len(labels)
            entropy -= prob * math.log(prob, 2)
        return entropy

    def majority_class(self):
        labels = self.data[self.class_name].tolist()
        label_count = {}
        for label in labels:
            if label not in label_count:
                label_count[label] = 0
            label_count[label] += 1
        majority_label = None
        max_count = 0
        for label in label_count:
            if label_count[label] > max_count:
                max_count = label_count[label]
                majority_label = label
        return majority_label

    def is_continuous(self, attr):
        return self.data[attr].dtype == 'float64'

    def choose_best_split(self, data, attr):
        values = data[attr].tolist()
        values.sort()
        split_points = [(values[i] + values[i+1]) / 2 for i in range(len(values)-1)]
        max_gain_ratio = 0
        best_split_value = None
        for split_value in split_points:
            left_data = data[data[attr] <= split_value]
            right_data = data[data[attr] > split_value]
            gain_ratio = self.continuous_gain_ratio(data, attr, self.entropy(data), split_value)
            if gain_ratio > max_gain_ratio:
                max_gain_ratio = gain_ratio
                best_split_value = split_value
        return best_split_value

    def continuous_gain_ratio(self, data, attr, entropy, split_value=None):
        if split_value is None:
            split_value = self.choose_best_split(data, attr)
        left_data = data[data[attr] <= split_value]
        right_data = data[data[attr] > split_value]
        left_entropy = self.entropy(left_data)
        right_entropy = self.entropy(right_data)
        split_entropy = (len(left_data) / len(data)) * left_entropy + (len(right_data) / len(data)) * right_entropy
        gain = entropy - split_entropy
        split_info = - (len(left_data) / len(data)) * math.log(len(left_data) / len(data), 2) - (len(right_data) / len(data)) * math.log(len(right_data) / len(data), 2)
        if split_info == 0:
            return 0, split_value
        gain_ratio = gain / split_info
        return gain_ratio, split_value

    def discrete_gain_ratio(self, data, attr, entropy):
        sub_data_dict = self.split_data(data, attr)
        split_entropy = 0
        split_info = 0
        for attr_value in sub_data_dict:
            sub_data = sub_data_dict[attr_value]
            prob = len(sub_data) / len(data)
            split_entropy -= prob * self.entropy(sub_data)
            split_info -= prob * math.log(prob, 2)
        gain = entropy - split_entropy
        if split_info == 0:
            return 0
        gain_ratio = gain / split_info
        return gain_ratio

    def predict(self, data):
        node = self.root
        while node.children:
            attr_name = node.attr_name
            if self.is_continuous(attr_name):
                if data[attr_name] <= node.attr_value:
                    node = node.children['<=']
                else:
                    node = node.children['>']
            else:
                attr_value = data[attr_name]
                if attr_value in node.children:
                    node = node.children[attr_value]
                else:
                    return self.majority_class()
        return node.label

    def fit(self):
        self.root = self.build_tree(self.data)

data = pd.read_csv('data.csv')
tree = C45DecisionTree(data, 'class')
tree.fit()
print(tree.predict({'age': 30, 'income': 50000, 'student': 'no', 'credit_rating': 'fair'}))

在这个示例中,我们首先导入了pandas和math模块。然后定义了一个名为Node的类,用于表示决策树的节点。每个节点包含一个属性名、一个属性值、一个子节点列表和一个类别标签。然后定义了一个名为C45DecisionTree的类,用于表示C4.5算法的决策树。C45DecisionTree类包含一个根节点、一个属性列表和一个类别列表。在C45DecisionTree类中,我们定义了build_tree、choose_best_feature、split_data、entropy、majority_class、is_continuous、choose_best_split、continuous_gain_ratio、discrete_gain_ratio和predict等方法,用于构建决策树、选择最优的属性、划分数据集、计算熵、计算多数类、判断属性是否为连续值、选择最优的划分点、计算连续属性的信息增益比、计算离散属性的信息增益比和预测数据的类别标签。最后,我们读取了一个名为data.csv的数据集,使用C45DecisionTree类构建决策树,并预测了一个新的数据的类别标签。

示例1:使用C4.5算法预测鸢尾花的类别

在这个示例中,我们将使用C4.5算法预测鸢尾花的类别。我们首先读取一个名为iris.csv的数据集,然后使用C45DecisionTree类构建决策树,并预测测试数据的类别标签。

import pandas as pd

data = pd.read_csv('iris.csv')
train_data = data.iloc[:120]
test_data = data.iloc[120:]
tree = C45DecisionTree(train_data, 'class')
tree.fit()
correct_count = 0
for i in range(len(test_data)):
    row = test_data.iloc[i]
    label = tree.predict(row.drop('class'))
    if label == row['class']:
        correct_count += 1
accuracy = correct_count / len(test_data)
print('Accuracy:', accuracy)

在这个示例中,我们首先读取了一个名为iris.csv的数据集,然后将前120个样本作为训练集,后30个样本作为测试集。然后使用C45DecisionTree类构建决策树,并预测测试数据的类别标签。最后计算预测准确率。

示例2:使用C4.5算法预测泰坦尼克号乘客的生还情况

在这个示例中,我们将使用C4.5算法预测泰坦尼克号乘客的生还情况。我们首先读取一个名为titanic.csv的数据集,然后使用pandas模块对数据进行预处理,将缺失值填充为平均值或众数。然后使用C45DecisionTree类构建决策树,并预测测试数据的类别标签。

import pandas as pd

data = pd.read_csv('titanic.csv')
data = data.drop(['PassengerId', 'Name', 'Ticket', 'Cabin'], axis=1)
data['Age'] = data['Age'].fillna(data['Age'].mean())
data['Embarked'] = data['Embarked'].fillna(data['Embarked'].mode()[0])
train_data = data.iloc[:800]
test_data = data.iloc[800:]
tree = C45DecisionTree(train_data, 'Survived')
tree.fit()
correct_count = 0
for i in range(len(test_data)):
    row = test_data.iloc[i]
    label = tree.predict(row.drop('Survived'))
    if label == row['Survived']:
        correct_count += 1
accuracy = correct_count / len(test_data)
print('Accuracy:', accuracy)

在这个示例中,我们首先读取了一个名为titanic.csv的数据集,然后使用pandas模块对数据进行预处理,将缺失值填充为平均值或众数。然后将前800个样本作为训练集,后91个样本作为测试集。然后使用C45DecisionTree类构建决策树,并预测测试数据的类别标签。最后计算预测准确率。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:python实现决策树C4.5算法详解(在ID3基础上改进) - Python技术站

(0)
上一篇 2023年5月14日
下一篇 2023年5月14日

相关文章

  • 用Python将IP地址在整型和字符串之间轻松转换

    将IP地址转换为整型或字符串是在网络编程和数据库等领域中经常使用的操作。Python提供了一些内置函数和标准库来实现这种转换。下面是详细的攻略: IP地址的整型和字符串表示 IP地址是计算机网络中的一个重要概念,表示的是网络中一个节点的地址。在IPv4中,IP地址通常是通过四个十进制数表示,例如:192.168.0.1。而在计算机中,IP地址通常被转换成一个…

    python 2023年5月19日
    00
  • Python tkinter事件高级用法实例

    请允许我从以下几个方面来讲解Python tkinter事件高级用法实例的完整攻略。 简介 Python tkinter是一个用于图形用户界面编程的模块。在tkinter中,事件是很重要的概念,它可以使程序变得更加动态和交互,同时可以增强用户体验。在Python tkinter中,事件也有许多高级用法,例如延迟事件、绑定事件等。 延迟事件 延迟事件指的是,当…

    python 2023年6月5日
    00
  • Python求解平方根的方法

    Python 求解平方根的方法,主要可以分为以下两种: 1. 使用 math 模块 Python 内置的 math 库提供了 sqrt(x) 方法用于求平方根,该方法的使用方法如下: import math # 求平方根 math.sqrt(4) # 返回 2.0 代码说明: 导入 math 库; 使用 sqrt 方法,传入要求平方根的数字。 2. 使用幂运…

    python 2023年6月5日
    00
  • Python Deque 模块使用详解

    Python Deque 模块使用详解 什么是Deque Deque是 “double-ended queue”(双端队列)的缩写,在Python中是一个数据结构。它是一个可在两端添加和删除元素的序列,通俗点说它是一种可以在两端进行操作的序列。 Deque的主要方法 Deque包含以下方法: 方法 描述 append(x) 向右侧添加x元素 appendle…

    python 2023年6月3日
    00
  • Python 字符串使用多个分隔符分割成列表的2种方法

    下面是详细讲解“Python 字符串使用多个分隔符分割成列表的2种方法”的完整攻略。 方法一:使用正则表达式分割 Python 提供了非常方便的正则表达式工具,可以用正则表达式来分割字符串。以下是代码示例: import re text = ‘hello|world#python’ pattern = re.compile(r'[|#]’) result =…

    python 2023年6月3日
    00
  • Python 判断文件或目录是否存在的实例代码

    当我们在编写 Python 程序时,经常需要判断文件或目录是否存在,以便进行相应的操作。Python 提供了 os 模块可以很方便的判断文件或目录是否存在。 1. 导入 os 模块 在 Python 中使用 os 模块需要先导入它,可以使用 import 语句导入 os 模块,代码如下: import os 2. 使用 os.path 模块判断文件或目录是否…

    python 2023年6月2日
    00
  • 使用PyCharm配合部署Python的Django框架的配置纪实

    下面是使用PyCharm配合部署Python的Django框架的配置纪实的具体攻略,包括以下几个步骤: 1. 安装Python 在安装PyCharm之前,首先需要安装Python。可以到 Python官网 下载最新版本的Python,并按照安装向导进行安装。 2. 安装PyCharm 可以到 PyCharm官网 下载最新版本的PyCharm,并按照安装向导进…

    python 2023年5月13日
    00
  • python包pdfkit(wkhtmltopdf) 将HTML转换为PDF的操作方法

    Python包pdfkit(wkhtmltopdf)将HTML转换为PDF的操作方法 pdfkit是一个Python包,它使用wkhtmltopdf将HTML文件转换为PDF文件。wkhtmltopdf是一个开源的命令行工具,它可以将HTML文件转换为PDF文件。pdfkit提供了一个简单的Python接口,可以轻松地将HTML文件转换为PDF文件。本文将介…

    python 2023年5月15日
    00
合作推广
合作推广
分享本页
返回顶部