Python手写回归树的实现

Python手写回归树的实现攻略

简介

回归树是一种常用的回归挖掘技术,其基本思想是通过对样本数据的递归划分来建立模型,对于每一次的划分都是基于当前样本集中的某一个特征,根据该特征分裂为若干子集,使得每个子集的目标值尽可能的接近,最终达到建立决策树模型的目的。在本文中,我们将使用 Python 语言手写一个回归树模型,并使用两个实例来说明其基本使用方法和实现效果。

实现步骤

1. 数据准备

首先需要准备好一份样本数据,样本数据至少含有一个特征和一个目标变量。这里我们使用以 2 个自变量和 1 个因变量组成的示例数据,具体如下:

import numpy as np

X = np.array([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8], [8, 9]])
y = np.array([3, 6, 9, 12, 15, 18, 21, 24])

2. 定义节点类

对于回归树模型,每个树节点都需要保存一些基本信息,例如特征、特征值、目标值等等,因此我们需要定义一个树结点类来存储这些信息。其中,我们需要实现两个主要方法:

  • __init__(self, data):构造器方法,用于初始化节点类,输入参数为数据集。
  • choose_best_split(self, min_sample_split, min_impurity):选择当前数据集上最优的划分点,并返回最优划分点的信息。

树结点类的代码实现如下所示:

class Node:
    def __init__(self, data):
        self.data = data
        self.feature = None  # 分裂特征
        self.threshold = None  # 分裂特征值
        self.left = None  # 左子树
        self.right = None  # 右子树
        self.value = np.mean(data[:, -1])  # 叶子结点预测值

    def choose_best_split(self, min_sample_split, min_impurity):
        """
        选择最优的分裂特征及其对应的阈值
        :param min_sample_split: 样本最小分裂数
        :param min_impurity: 结点最小不纯度
        :return: 划分特征名,划分特征值
        """
        m, n = self.data.shape

        # 结点样本数小于最小分裂数,直接返回
        if m < min_sample_split:
            return None, None

        # 计算当前结点的不纯度,作为最终不纯度的阈值
        impurity = np.var(self.data[:, -1])
        if impurity < min_impurity:
            return None, None

        # 初始化最优划分信息
        best_feature, best_threshold, best_impurity = None, None, float('inf')

        # 对每个特征进行遍历,找出最优划分点
        for col in range(n - 1):
            for row in range(m):
                left = self.data[self.data[:, col] < self.data[row, col]]
                right = self.data[self.data[:, col] >= self.data[row, col]]

                if len(left) < min_sample_split or len(right) < min_sample_split:
                    continue

                impurity = np.var(left[:, -1]) + np.var(right[:, -1])
                if impurity < best_impurity:
                    best_feature = col
                    best_threshold = self.data[row, col]
                    best_impurity = impurity

        # 完成最优划分信息的更新
        if best_feature is not None and best_impurity < np.var(self.data[:, -1]):
            self.feature = best_feature
            self.threshold = best_threshold
            return best_feature, best_threshold
        else:
            return None, None

3. 定义回归树类

完成树结点类的定义之后,我们需要定义一个回归树类来统一管理整个树的构建过程。回归树类需要实现以下方法:

  • __init__(self, min_sample_split, min_impurity):构造器方法,用于初始化树类,输入参数为最小分裂样本数与最小不纯度阈值。
  • build_tree(self, node):递归构建子树的方法,输入参数为当前节点,输出参数为构建好的树。
  • predict(self, X):预测数据的方法,输入参数为待预测数据集,输出参数为预测结果。

注意,在 build_tree 函数中,如果当前节点数据集为空,将会停止递归。同时,在 constructor 中我们需要定义一个节点列表,用于保存回归树的所有叶子节点。回归树的代码实现如下所示:

class RegressionTree:
    def __init__(self, min_sample_split=2, min_impurity=1e-7):
        self.min_sample_split = min_sample_split
        self.min_impurity = min_impurity
        self.root = None  # 回归树的根节点
        self.leaves = []  # 叶子节点列表

    def build_tree(self, node):
        """
        构建子树
        """
        feature, thresh = node.choose_best_split(self.min_sample_split, self.min_impurity)

        # 如果当前结点是叶子结点,将该结点添加到叶子结点列表中
        if feature is None:
            self.leaves.append(node)
            return

        left_indices = node.data[:, feature] < thresh
        left_node = Node(node.data[left_indices, :])
        node.left = left_node

        right_indices = node.data[:, feature] >= thresh
        right_node = Node(node.data[right_indices, :])
        node.right = right_node

        self.build_tree(left_node)
        self.build_tree(right_node)

    def predict(self, X):
        """
        使用训练好的模型对新数据进行预测
        """
        results = []
        for data in X:
            node = self.root

            while node.left:
                if data[node.feature] < node.threshold:
                    node = node.left
                else:
                    node = node.right

            results.append(node.value)

        return results

4. 定义模型训练函数

树的构建过程已经完成,接下来我们需要定义一个模型训练函数,用于对输入数据进行训练,并输出构建好的回归树模型。模型训练函数的代码如下所示:

def train(X, y, min_sample_split=2, min_impurity=1e-7):
    # 初始化回归树
    regression_tree = RegressionTree(min_sample_split=min_sample_split, min_impurity=min_impurity)

    # 构建根节点
    root_node = Node(np.column_stack((X, y)))
    regression_tree.root = root_node

    # 构建树
    regression_tree.build_tree(root_node)

    return regression_tree

5. 测试模型

模型训练函数已经定义完成,我们可以通过以下代码来测试构建出的回归树模型的性能:

regression_tree = train(X, y)
y_pred = regression_tree.predict(X)
print(y_pred)  # 打印预测结果

6. 完成一个更复杂的示例

上述过程已经很好的讲解了手写回归树的基本知识点,下面我们来完成一个更复杂的示例,用于进一步学习回归树的应用。由于现实中的数据集往往比较复杂,因此我们需要使用一个实际数据集。

我们选用 sklearn 内置的波士顿房价数据集(Boston Housing Dataset)来作为我们的数据集,波士顿房价数据集包含了 13 个不同的特征,如 CRIM(人均犯罪率)、ZN(住宅用地占比)等等,目标值为该地区的房屋价格中位数。

使用以下代码导入数据集:

from sklearn.datasets import load_boston

boston = load_boston()
X = boston.data
y = boston.target

接下来,我们可以使用前述的模型训练函数 train 来训练回归树模型,并对其性能进行评估:

regression_tree = train(X, y)
y_pred = regression_tree.predict(X)
print(y_pred)  # 打印预测结果

最后,我们可以使用以下代码显示构建出来的回归树:

from sklearn.tree import export_graphviz
import graphviz

dot_data = export_graphviz(regression_tree.root, out_file=None, filled=True, rounded=True, special_characters=True)
graph = graphviz.Source(dot_data)
graph.render('regression_tree')

7. 小结

通过本文的介绍,我们了解了回归树的基本概念和基本实现方法,同时也学会了 Python 语言中的回归树手写实现方法,并使用了两个例子进行了说明。需要指出的是,本文中的实现并不是最优的,仅供学习和参考。如果需要进行真正的回归分析,应该使用更为专业和更加普遍应用的回归算法库。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python手写回归树的实现 - Python技术站

(0)
上一篇 2023年6月3日
下一篇 2023年6月3日

相关文章

  • 详解Windows下PyCharm安装Numpy包及无法安装问题解决方案

    详解Windows下PyCharm安装Numpy包及无法安装问题解决方案 介绍 在使用Python开发过程中,Numpy作为一个重要的科学计算库不可或缺。然而,在安装Numpy的过程中,有时会遇到各种问题,导致无法成功安装。本文将针对Windows下使用PyCharm的情况,详细讲解Numpy包的安装及解决无法安装的问题。 安装步骤 1. 配置pip源 使用…

    python 2023年5月13日
    00
  • Python算法应用实战之栈详解

    Python算法应用实战之栈详解 什么是栈? 栈是一种常用的数据结构,它具有后进先出(LIFO)的特点。栈的基本操作包括入栈、出栈、获取栈元素和判断栈是否为空。 Python实现栈的过程 在Python中,可以使用列表来实现栈。以下是使用列表实现栈的示例代码: class Stack: def __init__(self): self.items = [] …

    python 2023年5月13日
    00
  • 基于Python实现实时监控CPU使用率

    我来为你详细讲解“基于Python实现实时监控CPU使用率”的完整攻略。 1. 准备工作 在开始实现之前,需要做好一些准备工作。具体包括: 安装Python:在官网上下载Python的安装包,按照安装向导一步步安装即可。 安装psutil模块:在命令行中输入pip install psutil,安装psutil模块。 2. 实现过程 接下来就开始实现了。具体…

    python 2023年6月3日
    00
  • Python 2.7 发布,并从网站获取结果

    【问题标题】:Python 2.7 posting, and getting result from web sitePython 2.7 发布,并从网站获取结果 【发布时间】:2023-04-06 05:29:01 【问题描述】: 提前感谢您的帮助。我正在尝试编写一个 python 脚本,将 IP 地址发布到下面引用的站点,并在终端或文件中打印出结果,然后…

    Python开发 2023年4月7日
    00
  • python使用百度翻译进行中翻英示例

    这里是Python使用百度翻译进行中翻英示例的攻略。 1. 百度翻译API准备 首先,我们需要去百度翻译API的官网注册一个账号,然后创建一个应用,获取到对应的APP_ID和SECRET_KEY,这两个参数在后续的接口调用中会用到。 2. Python设置 在Python中,我们需要引入requests库进行HTTP请求,引入json库用于将返回的JSON字…

    python 2023年6月5日
    00
  • 使用python的chardet库获得文件编码并修改编码

    使用Python的chardet库可以方便地获取文件编码信息,接着我们可以根据需要进行编码转换。以下是使用chardet库获取文件编码并修改编码的完整攻略。 第一步:安装 chardet 库 在使用chardet库之前,我们需要先安装它。可以通过以下命令在终端或命令提示符中安装: pip install chardet 第二步:获取文件编码 使用charde…

    python 2023年5月31日
    00
  • 解决python3.5 正常安装 却不能直接使用Tkinter包的问题

    针对 Python3.5 正常安装却不能直接使用 Tkinter 包的问题,可以按照以下步骤进行解决: 问题分析 在 Python3.5 中,Tkinter 包已经默认安装,但在某些情况下可能无法正常使用,这是因为 Tkinter 包本身依赖于 Tcl/Tk 库,如果 Tcl/Tk 库没有正确安装或者环境变量没有配置好,Tkinter 包就无法直接使用。 解…

    python 2023年6月13日
    00
  • 超级好用的4个Python命令行可视化库

    下面是关于“超级好用的4个Python命令行可视化库”的完整攻略。 简介 命令行可视化是指在终端中使用图形或者其他方式将数据可视化。在Python中,有很多开源工具可以用于命令行可视化。下面介绍了4个超级好用的Python命令行可视化库,每个库都提供了不同的绘图类型和样式,可根据需求选择合适的库进行使用。 这4个库分别是: curses:一个Python内置…

    python 2023年5月18日
    00
合作推广
合作推广
分享本页
返回顶部