Python机器学习之决策树算法

下面是关于“Python机器学习之决策树算法”的完整攻略。

1. 决策树算法的基本原理

决策树算法是一种基于树形结构的分类算法,它通过对数据集进行递归分割,生成一棵树形结构,用于对新数据进行分类。决策树算法的基本流程如下:

  1. 选择最优特征:根据某种评估指标,选择最优的特征作为当前节点的分裂特征。
  2. 分裂节点:根据分裂特征的取值,将当前节点分裂成多个子节点。
  3. 递归:对每个子节点递归执行步骤1和步骤2,直到满足终止条件。
  4. 终止条件:达到预设的终止条件,如树的深度、节点数等。

2. 决策树算法的Python实现

以下是决策树算法的Python实现示例:

import numpy as np

# 定义节点类
class Node:
    def __init__(self, feature_index=None, threshold=None, left=None, right=None, value=None):
        self.feature_index = feature_index  # 分裂特征的索引
        self.threshold = threshold  # 分裂特征的阈值
        self.left = left  # 左子节点
        self.right = right  # 右子节点
        self.value = value  # 叶节点的值

# 定义决策树类
class DecisionTree:
    def __init__(self, max_depth=None):
        self.max_depth = max_depth  # 树的最大深度
        self.root = None  # 根节点

    # 计算基尼指数
    def gini(self, y):
        _, counts = np.unique(y, return_counts=True)
        p = counts / len(y)
        return 1 - np.sum(p ** 2)

    # 计算信息熵
    def entropy(self, y):
        _, counts = np.unique(y, return_counts=True)
        p = counts / len(y)
        return -np.sum(p * np.log2(p))

    # 选择最优特征
    def choose_best_feature(self, X, y):
        best_feature_index = None
        best_threshold = None
        best_score = float('inf')
        for feature_index in range(X.shape[1]):
            thresholds = np.unique(X[:, feature_index])
            for threshold in thresholds:
                y_left = y[X[:, feature_index] < threshold]
                y_right = y[X[:, feature_index] >= threshold]
                score = len(y_left) * self.gini(y_left) + len(y_right) * self.gini(y_right)
                if score < best_score:
                    best_feature_index = feature_index
                    best_threshold = threshold
                    best_score = score
        return best_feature_index, best_threshold

    # 构建决策树
    def build_tree(self, X, y, depth=0):
        if depth == self.max_depth or len(np.unique(y)) == 1:
            return Node(value=np.bincount(y).argmax())
        feature_index, threshold = self.choose_best_feature(X, y)
        X_left, y_left = X[X[:, feature_index] < threshold], y[X[:, feature_index] < threshold]
        X_right, y_right = X[X[:, feature_index] >= threshold], y[X[:, feature_index] >= threshold]
        left = self.build_tree(X_left, y_left, depth+1)
        right = self.build_tree(X_right, y_right, depth+1)
        return Node(feature_index=feature_index, threshold=threshold, left=left, right=right)

    # 训练决策树
    def fit(self, X, y):
        self.root = self.build_tree(X, y)

    # 预测单个样本
    def predict_sample(self, x, node):
        if node.value is not None:
            return node.value
        if x[node.feature_index] < node.threshold:
            return self.predict_sample(x, node.left)
        else:
            return self.predict_sample(x, node.right)

    # 预测多个样本
    def predict(self, X):
        return np.array([self.predict_sample(x, self.root) for x in X])

在这个示例中,我们定义了一个Node类,用于表示决策树的节点。每个节点包含分裂特征的索引feature_index、分裂特征的阈值threshold、左子节点left、右子节点right和叶节点的值value。我们还定义了一个DecisionTree类,用于表示决策树。每个决策树包含树的最大深度max_depth和根节点root。我们使用gini()函数计算基尼指数,使用entropy()函数计算信息熵。我们使用choose_best_feature()函数选择最优特征,使用build_tree()函数构建决策树。最后,我们使用fit()函数训练决策树,使用predict()函数预测多个样本。

以下是使用决策树算法解决鸢尾花分类问题的Python示例:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from decision_tree import DecisionTree

# 加载数据集
iris = load_iris()
X, y = iris.data, iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 训练模型
model = DecisionTree(max_depth=3)
model.fit(X_train, y_train)

# 预测测试集
y_pred = model.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print('Accuracy:', accuracy)

在这个示例中,我们使用load_iris()函数加载鸢尾花数据集,使用train_test_split()函数划分训练集和测试集。接着,我们使用DecisionTree类训练决策树模型,并使用predict()函数预测测试集。最后,我们使用accuracy_score()函数计算准确率。

以下是使用决策树算法解决泰坦尼克号生存预测问题的Python示例:

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from decision_tree import DecisionTree

# 加载数据集
data = pd.read_csv('titanic.csv')
X = data[['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare']].values
y = data['Survived'].values

# 处理缺失值
X[:, 2][pd.isnull(X[:, 2])] = np.mean(X[:, 2][~pd.isnull(X[:, 2])])

# 处理分类变量
X[X[:, 1] == 'male', 1] = 0
X[X[:, 1] == 'female', 1] = 1

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 训练模型
model = DecisionTree(max_depth=3)
model.fit(X_train, y_train)

# 预测测试集
y_pred = model.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print('Accuracy:', accuracy)

在这个示例中,我们使用pd.read_csv()函数加载泰坦尼克号数据集,使用train_test_split()函数划分训练集和测试集。接着,我们使用DecisionTree类训练决策树模型,并使用predict()函数预测测试集。最后,我们使用accuracy_score()函数计算准确率。

3. 总结

决策树算法是一种基于树形结构的分类算法,它通过对数据集进行递归分割,生成一棵树形结构,用于对新数据进行分类。在Python中,我们可以使用类和函数等基本语言特性来实现决策树算法。决策树算法的应用非常广泛,可以用于分类、回归、特征选择等领域。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python机器学习之决策树算法 - Python技术站

(0)
上一篇 2023年5月13日
下一篇 2023年5月13日

相关文章

  • python-docx如何缩进两个字符

    要让python-docx生成的Word文档内容缩进两个字符,可以使用Python字符串的缩进操作。 首先,我们需要安装python-docx库。可以使用pip命令进行安装: pip install python-docx 然后,我们可以使用python-docx库创建一个Word文档,并添加段落和文本内容: from docx import Documen…

    python 2023年6月5日
    00
  • python的链表基础知识点

    Python的链表基础知识点 链表的定义 链表是一种常见的数据结构,它的节点包含两个部分:数据和指向下一个节点的指针。链表的最后一个节点指向None。 Python中链表的定义可以使用class来实现。例如定义一个链表节点的类: class ListNode: def __init__(self, x): self.val = x self.next = N…

    python 2023年5月14日
    00
  • Python3 pickle模块的使用方法详细介绍

    Python3 pickle模块的使用方法详细介绍 pickle模块是Python提供的一种对象序列化和反序列化的工具,能够将Python对象转换为一个可以存储到磁盘上或者进行网络传输的字符串,同时也能够将这个字符串反序列化为原来的Python对象。使用pickle模块可以方便地实现数据的持久化和传输,是Python编程中非常重要的一部分。 序列化和反序列化…

    python 2023年6月2日
    00
  • 浅析Python 引号、注释、字符串

    在本攻略中,我们将浅析Python引号、注释、字符串。这些是Python编程中非常基础的概念,但也是非常重要的。 引号 在Python中,字符串可以使用单引号、双引号或三引号来表示。以下是一个示例代码,演示了如何使用不同类型的引号来表示字符串: # 使用单引号表示字符串 str1 = ‘Hello, World!’ print(str1) # 使用双引号表示…

    python 2023年5月15日
    00
  • python使用openpyxl库修改excel表格数据方法

    下面就分享一下关于“python使用openpyxl库修改excel表格数据方法”的详细实例教程。 一、openpyxl库介绍 openpyxl是用于读写Excel xlsx/xlsm文件的Python库。它不仅支持读取操作,还支持创建、修改、合并Excel文件的操作。openpyxl库具有较高的可扩展性和稳定性,因此在Python操作Excel文件方面得到…

    python 2023年5月13日
    00
  • python数据分析之用sklearn预测糖尿病

    Python数据分析之用sklearn预测糖尿病 在Python中,可以使用sklearn库对糖尿病数据进行预测。本文将为您详细讲解Python数据分析之用sklearn预测糖尿病的完整攻略,包数据收集、数据预处理、征工程、模型训练、模型评估等。程中将提供两个示例说明。 数据收集 糖尿病数据可以从各个数据源中获取,如UCI Machine Learning …

    python 2023年5月14日
    00
  • Python实现有趣的亲戚关系计算器

    Python实现有趣的亲戚关系计算器的完整攻略如下: 1. 确定需求 首先需要确定这个亲戚关系计算器需要实现哪些功能。例如,输入两个人的姓名,计算出他们之间的关系,或者输入一个人的姓名和关系,计算出与他有这个关系的所有人。 2. 确认实现方式 在Python中实现亲戚关系计算器,可以使用字典来存储家庭结构,以姓名为键,以对应的父母、兄弟、子女等亲戚关系为值。…

    python 2023年5月14日
    00
  • python numpy和list查询其中某个数的个数及定位方法

    以下是“Python numpy和list查询其中某个数的个数及定位方法”的完整攻略。 1. Python list count方法 在Python中,list是一种常用的数据结构,可以存储任意的数据。list提供了count()方法用来统计list某个元素出现的次数。count()方法的语法如下: .count(element) 其中,list要统计的li…

    python 2023年5月13日
    00
合作推广
合作推广
分享本页
返回顶部