Python手写回归树的实现

yizhihongxing

Python手写回归树的实现攻略

简介

回归树是一种常用的回归挖掘技术,其基本思想是通过对样本数据的递归划分来建立模型,对于每一次的划分都是基于当前样本集中的某一个特征,根据该特征分裂为若干子集,使得每个子集的目标值尽可能的接近,最终达到建立决策树模型的目的。在本文中,我们将使用 Python 语言手写一个回归树模型,并使用两个实例来说明其基本使用方法和实现效果。

实现步骤

1. 数据准备

首先需要准备好一份样本数据,样本数据至少含有一个特征和一个目标变量。这里我们使用以 2 个自变量和 1 个因变量组成的示例数据,具体如下:

import numpy as np

X = np.array([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8], [8, 9]])
y = np.array([3, 6, 9, 12, 15, 18, 21, 24])

2. 定义节点类

对于回归树模型,每个树节点都需要保存一些基本信息,例如特征、特征值、目标值等等,因此我们需要定义一个树结点类来存储这些信息。其中,我们需要实现两个主要方法:

  • __init__(self, data):构造器方法,用于初始化节点类,输入参数为数据集。
  • choose_best_split(self, min_sample_split, min_impurity):选择当前数据集上最优的划分点,并返回最优划分点的信息。

树结点类的代码实现如下所示:

class Node:
    def __init__(self, data):
        self.data = data
        self.feature = None  # 分裂特征
        self.threshold = None  # 分裂特征值
        self.left = None  # 左子树
        self.right = None  # 右子树
        self.value = np.mean(data[:, -1])  # 叶子结点预测值

    def choose_best_split(self, min_sample_split, min_impurity):
        """
        选择最优的分裂特征及其对应的阈值
        :param min_sample_split: 样本最小分裂数
        :param min_impurity: 结点最小不纯度
        :return: 划分特征名,划分特征值
        """
        m, n = self.data.shape

        # 结点样本数小于最小分裂数,直接返回
        if m < min_sample_split:
            return None, None

        # 计算当前结点的不纯度,作为最终不纯度的阈值
        impurity = np.var(self.data[:, -1])
        if impurity < min_impurity:
            return None, None

        # 初始化最优划分信息
        best_feature, best_threshold, best_impurity = None, None, float('inf')

        # 对每个特征进行遍历,找出最优划分点
        for col in range(n - 1):
            for row in range(m):
                left = self.data[self.data[:, col] < self.data[row, col]]
                right = self.data[self.data[:, col] >= self.data[row, col]]

                if len(left) < min_sample_split or len(right) < min_sample_split:
                    continue

                impurity = np.var(left[:, -1]) + np.var(right[:, -1])
                if impurity < best_impurity:
                    best_feature = col
                    best_threshold = self.data[row, col]
                    best_impurity = impurity

        # 完成最优划分信息的更新
        if best_feature is not None and best_impurity < np.var(self.data[:, -1]):
            self.feature = best_feature
            self.threshold = best_threshold
            return best_feature, best_threshold
        else:
            return None, None

3. 定义回归树类

完成树结点类的定义之后,我们需要定义一个回归树类来统一管理整个树的构建过程。回归树类需要实现以下方法:

  • __init__(self, min_sample_split, min_impurity):构造器方法,用于初始化树类,输入参数为最小分裂样本数与最小不纯度阈值。
  • build_tree(self, node):递归构建子树的方法,输入参数为当前节点,输出参数为构建好的树。
  • predict(self, X):预测数据的方法,输入参数为待预测数据集,输出参数为预测结果。

注意,在 build_tree 函数中,如果当前节点数据集为空,将会停止递归。同时,在 constructor 中我们需要定义一个节点列表,用于保存回归树的所有叶子节点。回归树的代码实现如下所示:

class RegressionTree:
    def __init__(self, min_sample_split=2, min_impurity=1e-7):
        self.min_sample_split = min_sample_split
        self.min_impurity = min_impurity
        self.root = None  # 回归树的根节点
        self.leaves = []  # 叶子节点列表

    def build_tree(self, node):
        """
        构建子树
        """
        feature, thresh = node.choose_best_split(self.min_sample_split, self.min_impurity)

        # 如果当前结点是叶子结点,将该结点添加到叶子结点列表中
        if feature is None:
            self.leaves.append(node)
            return

        left_indices = node.data[:, feature] < thresh
        left_node = Node(node.data[left_indices, :])
        node.left = left_node

        right_indices = node.data[:, feature] >= thresh
        right_node = Node(node.data[right_indices, :])
        node.right = right_node

        self.build_tree(left_node)
        self.build_tree(right_node)

    def predict(self, X):
        """
        使用训练好的模型对新数据进行预测
        """
        results = []
        for data in X:
            node = self.root

            while node.left:
                if data[node.feature] < node.threshold:
                    node = node.left
                else:
                    node = node.right

            results.append(node.value)

        return results

4. 定义模型训练函数

树的构建过程已经完成,接下来我们需要定义一个模型训练函数,用于对输入数据进行训练,并输出构建好的回归树模型。模型训练函数的代码如下所示:

def train(X, y, min_sample_split=2, min_impurity=1e-7):
    # 初始化回归树
    regression_tree = RegressionTree(min_sample_split=min_sample_split, min_impurity=min_impurity)

    # 构建根节点
    root_node = Node(np.column_stack((X, y)))
    regression_tree.root = root_node

    # 构建树
    regression_tree.build_tree(root_node)

    return regression_tree

5. 测试模型

模型训练函数已经定义完成,我们可以通过以下代码来测试构建出的回归树模型的性能:

regression_tree = train(X, y)
y_pred = regression_tree.predict(X)
print(y_pred)  # 打印预测结果

6. 完成一个更复杂的示例

上述过程已经很好的讲解了手写回归树的基本知识点,下面我们来完成一个更复杂的示例,用于进一步学习回归树的应用。由于现实中的数据集往往比较复杂,因此我们需要使用一个实际数据集。

我们选用 sklearn 内置的波士顿房价数据集(Boston Housing Dataset)来作为我们的数据集,波士顿房价数据集包含了 13 个不同的特征,如 CRIM(人均犯罪率)、ZN(住宅用地占比)等等,目标值为该地区的房屋价格中位数。

使用以下代码导入数据集:

from sklearn.datasets import load_boston

boston = load_boston()
X = boston.data
y = boston.target

接下来,我们可以使用前述的模型训练函数 train 来训练回归树模型,并对其性能进行评估:

regression_tree = train(X, y)
y_pred = regression_tree.predict(X)
print(y_pred)  # 打印预测结果

最后,我们可以使用以下代码显示构建出来的回归树:

from sklearn.tree import export_graphviz
import graphviz

dot_data = export_graphviz(regression_tree.root, out_file=None, filled=True, rounded=True, special_characters=True)
graph = graphviz.Source(dot_data)
graph.render('regression_tree')

7. 小结

通过本文的介绍,我们了解了回归树的基本概念和基本实现方法,同时也学会了 Python 语言中的回归树手写实现方法,并使用了两个例子进行了说明。需要指出的是,本文中的实现并不是最优的,仅供学习和参考。如果需要进行真正的回归分析,应该使用更为专业和更加普遍应用的回归算法库。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python手写回归树的实现 - Python技术站

(0)
上一篇 2023年6月3日
下一篇 2023年6月3日

相关文章

  • selenium3+python3环境搭建教程图解

    本文将详细讲解如何搭建selenium3+python3环境,并结合两个示例进行说明。 环境要求 在开始搭建之前,请确保您已经安装以下环境: Python3(建议使用3.6以上版本) pip3(Python包管理器) Chrome浏览器(建议使用最新版本) Chrome Driver(用于控制Chrome浏览器,需要和Chrome浏览器版本对应) 安装sel…

    python 2023年5月19日
    00
  • Python requests发送post请求的一些疑点

    以下是关于Python requests发送POST请求的一些疑点的攻略: Python requests发送POST请求的一些疑点 在使用Python requests发送POST请求时,可能会遇到一些疑点。以下是Python requests发送POST请求的一些疑点的攻略。 POST请求的请求体 在发送POST请求时,需要设置请求体。以下是设置POST…

    python 2023年5月14日
    00
  • Python如何用str.format()批量生成网址(豆瓣读书为例)

    要批量生成网址,我们可以使用Python中的 str.format() 方法。该方法可以让我们轻松生成一个字符串,其中可以插入一些占位符,以便我们在后面再填充数据。 下面我们以豆瓣读书为例,详细介绍如何使用 str.format() 方法来批量生成豆瓣读书的书籍网址。 第一步:定义网址模板 在生成网址之前,我们需要定义一个网址模板,用于指定网址的格式。以豆瓣…

    python 2023年5月18日
    00
  • Virtualenv 搭建 Py项目运行环境的教程详解

    Virtualenv搭建Py项目运行环境的教程详解 在本攻略中,我们将介绍如何使用Virtualenv搭建Python项目的运行环境。Virtualenv是一个用于创建Python虚拟环境的工具,它可以帮助我们在同一台机器上管理多个Python项目,并且可以避免不同项目之间的依赖冲突。 步骤1:安装Virtualenv 在使用Virtualenv之前,我们需…

    python 2023年5月15日
    00
  • Python实现的对一个数进行因式分解操作示例

    对一个数进行因式分解是数学中的一个重要问题,Python可以很方便地实现这个操作。本文将介绍Python实现对一个数进行因式分解完整攻略,包括两个示例说明。 1. 基本思路 对一个数进行因式分解的基本思路是,从2开始,不断地将这个数除以最小的质因数,直到这个数变成1为止。具体实现如下: def factorize(n): factors = [] i = 2…

    python 2023年5月14日
    00
  • python创建进程fork用法

    Python创建进程可以使用fork()方法,该方法可以复制主进程,生成新的进程,并让主进程和新进程同时运行。下面是Python创建进程fork用法的完整攻略,包含以下内容: fork()的使用方法 父子进程的区别 示例说明 1. fork()的使用方法 使用fork()方法需要先导入os模块。Python中的fork()函数会复制当前进程,父进程和子进程都…

    python 2023年5月30日
    00
  • python实现图书管理系统

    Python实现图书管理系统攻略 一、概述 图书管理系统是一个常见的管理软件,它可以用来管理图书信息,包括图书的编号、名称、作者、出版社、价格等信息。本文将介绍如何使用Python语言实现一个简单的图书管理系统。 图书管理系统主要有以下功能: 添加图书 删除图书 修改图书信息 查询图书信息 显示所有图书信息 二、程序设计 1. 数据结构设计 使用Python…

    python 2023年5月30日
    00
  • 如何使用Python实现数据库中数据的聚合查询?

    以下是使用Python实现数据库中数据的聚合查询的完整攻略。 数据库中数据的聚合查询简介 在数据库中,数据的聚合查询是指对数据进行统计分析,如计算平均值、最大值、最小值、总和等。在Python中可以使用pymysql库实现数据库中数据的聚合查询。 步骤1:连接到数据库 在Python中使用pymysql库连接到MySQL。以下是连接到MySQL数据库的基本语…

    python 2023年5月12日
    00
合作推广
合作推广
分享本页
返回顶部