Python手写回归树的实现

Python手写回归树的实现攻略

简介

回归树是一种常用的回归挖掘技术,其基本思想是通过对样本数据的递归划分来建立模型,对于每一次的划分都是基于当前样本集中的某一个特征,根据该特征分裂为若干子集,使得每个子集的目标值尽可能的接近,最终达到建立决策树模型的目的。在本文中,我们将使用 Python 语言手写一个回归树模型,并使用两个实例来说明其基本使用方法和实现效果。

实现步骤

1. 数据准备

首先需要准备好一份样本数据,样本数据至少含有一个特征和一个目标变量。这里我们使用以 2 个自变量和 1 个因变量组成的示例数据,具体如下:

import numpy as np

X = np.array([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8], [8, 9]])
y = np.array([3, 6, 9, 12, 15, 18, 21, 24])

2. 定义节点类

对于回归树模型,每个树节点都需要保存一些基本信息,例如特征、特征值、目标值等等,因此我们需要定义一个树结点类来存储这些信息。其中,我们需要实现两个主要方法:

  • __init__(self, data):构造器方法,用于初始化节点类,输入参数为数据集。
  • choose_best_split(self, min_sample_split, min_impurity):选择当前数据集上最优的划分点,并返回最优划分点的信息。

树结点类的代码实现如下所示:

class Node:
    def __init__(self, data):
        self.data = data
        self.feature = None  # 分裂特征
        self.threshold = None  # 分裂特征值
        self.left = None  # 左子树
        self.right = None  # 右子树
        self.value = np.mean(data[:, -1])  # 叶子结点预测值

    def choose_best_split(self, min_sample_split, min_impurity):
        """
        选择最优的分裂特征及其对应的阈值
        :param min_sample_split: 样本最小分裂数
        :param min_impurity: 结点最小不纯度
        :return: 划分特征名,划分特征值
        """
        m, n = self.data.shape

        # 结点样本数小于最小分裂数,直接返回
        if m < min_sample_split:
            return None, None

        # 计算当前结点的不纯度,作为最终不纯度的阈值
        impurity = np.var(self.data[:, -1])
        if impurity < min_impurity:
            return None, None

        # 初始化最优划分信息
        best_feature, best_threshold, best_impurity = None, None, float('inf')

        # 对每个特征进行遍历,找出最优划分点
        for col in range(n - 1):
            for row in range(m):
                left = self.data[self.data[:, col] < self.data[row, col]]
                right = self.data[self.data[:, col] >= self.data[row, col]]

                if len(left) < min_sample_split or len(right) < min_sample_split:
                    continue

                impurity = np.var(left[:, -1]) + np.var(right[:, -1])
                if impurity < best_impurity:
                    best_feature = col
                    best_threshold = self.data[row, col]
                    best_impurity = impurity

        # 完成最优划分信息的更新
        if best_feature is not None and best_impurity < np.var(self.data[:, -1]):
            self.feature = best_feature
            self.threshold = best_threshold
            return best_feature, best_threshold
        else:
            return None, None

3. 定义回归树类

完成树结点类的定义之后,我们需要定义一个回归树类来统一管理整个树的构建过程。回归树类需要实现以下方法:

  • __init__(self, min_sample_split, min_impurity):构造器方法,用于初始化树类,输入参数为最小分裂样本数与最小不纯度阈值。
  • build_tree(self, node):递归构建子树的方法,输入参数为当前节点,输出参数为构建好的树。
  • predict(self, X):预测数据的方法,输入参数为待预测数据集,输出参数为预测结果。

注意,在 build_tree 函数中,如果当前节点数据集为空,将会停止递归。同时,在 constructor 中我们需要定义一个节点列表,用于保存回归树的所有叶子节点。回归树的代码实现如下所示:

class RegressionTree:
    def __init__(self, min_sample_split=2, min_impurity=1e-7):
        self.min_sample_split = min_sample_split
        self.min_impurity = min_impurity
        self.root = None  # 回归树的根节点
        self.leaves = []  # 叶子节点列表

    def build_tree(self, node):
        """
        构建子树
        """
        feature, thresh = node.choose_best_split(self.min_sample_split, self.min_impurity)

        # 如果当前结点是叶子结点,将该结点添加到叶子结点列表中
        if feature is None:
            self.leaves.append(node)
            return

        left_indices = node.data[:, feature] < thresh
        left_node = Node(node.data[left_indices, :])
        node.left = left_node

        right_indices = node.data[:, feature] >= thresh
        right_node = Node(node.data[right_indices, :])
        node.right = right_node

        self.build_tree(left_node)
        self.build_tree(right_node)

    def predict(self, X):
        """
        使用训练好的模型对新数据进行预测
        """
        results = []
        for data in X:
            node = self.root

            while node.left:
                if data[node.feature] < node.threshold:
                    node = node.left
                else:
                    node = node.right

            results.append(node.value)

        return results

4. 定义模型训练函数

树的构建过程已经完成,接下来我们需要定义一个模型训练函数,用于对输入数据进行训练,并输出构建好的回归树模型。模型训练函数的代码如下所示:

def train(X, y, min_sample_split=2, min_impurity=1e-7):
    # 初始化回归树
    regression_tree = RegressionTree(min_sample_split=min_sample_split, min_impurity=min_impurity)

    # 构建根节点
    root_node = Node(np.column_stack((X, y)))
    regression_tree.root = root_node

    # 构建树
    regression_tree.build_tree(root_node)

    return regression_tree

5. 测试模型

模型训练函数已经定义完成,我们可以通过以下代码来测试构建出的回归树模型的性能:

regression_tree = train(X, y)
y_pred = regression_tree.predict(X)
print(y_pred)  # 打印预测结果

6. 完成一个更复杂的示例

上述过程已经很好的讲解了手写回归树的基本知识点,下面我们来完成一个更复杂的示例,用于进一步学习回归树的应用。由于现实中的数据集往往比较复杂,因此我们需要使用一个实际数据集。

我们选用 sklearn 内置的波士顿房价数据集(Boston Housing Dataset)来作为我们的数据集,波士顿房价数据集包含了 13 个不同的特征,如 CRIM(人均犯罪率)、ZN(住宅用地占比)等等,目标值为该地区的房屋价格中位数。

使用以下代码导入数据集:

from sklearn.datasets import load_boston

boston = load_boston()
X = boston.data
y = boston.target

接下来,我们可以使用前述的模型训练函数 train 来训练回归树模型,并对其性能进行评估:

regression_tree = train(X, y)
y_pred = regression_tree.predict(X)
print(y_pred)  # 打印预测结果

最后,我们可以使用以下代码显示构建出来的回归树:

from sklearn.tree import export_graphviz
import graphviz

dot_data = export_graphviz(regression_tree.root, out_file=None, filled=True, rounded=True, special_characters=True)
graph = graphviz.Source(dot_data)
graph.render('regression_tree')

7. 小结

通过本文的介绍,我们了解了回归树的基本概念和基本实现方法,同时也学会了 Python 语言中的回归树手写实现方法,并使用了两个例子进行了说明。需要指出的是,本文中的实现并不是最优的,仅供学习和参考。如果需要进行真正的回归分析,应该使用更为专业和更加普遍应用的回归算法库。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python手写回归树的实现 - Python技术站

(0)
上一篇 2023年6月3日
下一篇 2023年6月3日

相关文章

  • python 中文乱码问题深入分析

    下面是对于“Python 中文乱码问题深入分析”的完整攻略: Python 中文乱码问题深入分析 在使用 Python 进行中文编程或中文文本处理时,一旦遇到中文乱码问题,就会给开发工作带来很大的不便。本文将从字符编码和环境设置两个层面,深入分析 Python 中文乱码问题的影响原因及解决方案。 字符编码的影响 在 Python 中,文本处理涉及到两个重要的…

    python 2023年5月13日
    00
  • Python3将数据保存为txt文件的方法

    下面是Python3将数据保存为txt文件的完整攻略: 步骤一:打开并写入文件 首先,需要使用Python内置的open()函数来打开一个txt文件。open()函数的第一个参数是文件名(包括文件路径),第二个参数是打开模式(读写模式)。在这里,我们需要使用写入模式’w’来打开文件并写入数据。假设我们要将数据保存到名为data.txt的文件中,可以使用以下代…

    python 2023年6月2日
    00
  • Python的文本常量与字符串模板之string库

    Python的文本常量与字符串模板之string库 在Python中,文本处理是一个非常常见的任务。Python提供了多种处理文本的方法和库,其中包括string库。string库提供了多种文本常量和字符串模板,可以方便地处理文本。本文将总结Python的文本常量与字符串模板之string库的使用方法,并提供两个示例说明。 文本常量 string库提供了多个…

    python 2023年5月14日
    00
  • 基于Python的接口测试框架实例

    在Python中,我们可以使用接口测试框架进行接口测试。本文将介绍如何基于Python实现接口测试框架,并提供两个示例。 1. 使用unittest框架进行接口测试 我们可以使用unittest框架进行接口测试。以下是一个示例,演示如何使用unittest框架进行接口测试: import unittest import requests class Test…

    python 2023年5月15日
    00
  • 在python的嵌套循环中嵌套打印

    【问题标题】:Nested print in a nested loop in python在python的嵌套循环中嵌套打印 【发布时间】:2023-04-06 20:25:02 【问题描述】: 如何创建在两个 for 循环中创建的输出? 我想要什么: Name1 Adress1 Name2 Adress2 .. 我得到了什么: Name1 Name2 A…

    Python开发 2023年4月7日
    00
  • Python使用re模块实现正则表达式操作指南

    Python使用re模块实现正则表达式操作指南 正则表达式是一种强大的文本处理工具,可以用于各种文本处理,如数据清洗、文本分析、信息提取等。在Python中可以使用re模块来操作正则表达式。本攻略将详细讲解Python使用re模块实现正则表达式操作的指南,包括正则表达式的基本语法、常用函数和应用技巧。 正则表达式的基本语法 正则表达式由普通字符和元字符组成,…

    python 2023年5月14日
    00
  • Python 分析访问细节

    Python可以利用各种库和工具对网站的访问细节进行分析和解析,以了解有关网站性能和使用情况的详细信息。本文将介绍使用Python进行网站访问分析的完整攻略。 准备工作 在开始Python分析网站访问细节之前,需要安装并导入必要的库和工具。常用的库和工具包括: requests:发送HTTP请求以获取访问网站的响应。 Beautiful Soup:解析HTM…

    python-answer 2023年3月25日
    00
  • 解决node-sass下载不成功的问题

    下面是解决node-sass下载不成功的完整攻略: 问题分析 node-sass是一个Node.js扩展模块,用于编译Sass和Scss文件,但是在安装node-sass包时,很容易遇到下载失败的问题。这主要是因为node-sass依赖于Libsass,而Libsass是用C++编写的,需要先进行编译。 在安装node-sass时,npm会自动尝试编译Lib…

    python 2023年5月13日
    00
合作推广
合作推广
分享本页
返回顶部