Python手写回归树的实现攻略
简介
回归树是一种常用的回归挖掘技术,其基本思想是通过对样本数据的递归划分来建立模型,对于每一次的划分都是基于当前样本集中的某一个特征,根据该特征分裂为若干子集,使得每个子集的目标值尽可能的接近,最终达到建立决策树模型的目的。在本文中,我们将使用 Python 语言手写一个回归树模型,并使用两个实例来说明其基本使用方法和实现效果。
实现步骤
1. 数据准备
首先需要准备好一份样本数据,样本数据至少含有一个特征和一个目标变量。这里我们使用以 2 个自变量和 1 个因变量组成的示例数据,具体如下:
import numpy as np
X = np.array([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8], [8, 9]])
y = np.array([3, 6, 9, 12, 15, 18, 21, 24])
2. 定义节点类
对于回归树模型,每个树节点都需要保存一些基本信息,例如特征、特征值、目标值等等,因此我们需要定义一个树结点类来存储这些信息。其中,我们需要实现两个主要方法:
__init__(self, data)
:构造器方法,用于初始化节点类,输入参数为数据集。choose_best_split(self, min_sample_split, min_impurity)
:选择当前数据集上最优的划分点,并返回最优划分点的信息。
树结点类的代码实现如下所示:
class Node:
def __init__(self, data):
self.data = data
self.feature = None # 分裂特征
self.threshold = None # 分裂特征值
self.left = None # 左子树
self.right = None # 右子树
self.value = np.mean(data[:, -1]) # 叶子结点预测值
def choose_best_split(self, min_sample_split, min_impurity):
"""
选择最优的分裂特征及其对应的阈值
:param min_sample_split: 样本最小分裂数
:param min_impurity: 结点最小不纯度
:return: 划分特征名,划分特征值
"""
m, n = self.data.shape
# 结点样本数小于最小分裂数,直接返回
if m < min_sample_split:
return None, None
# 计算当前结点的不纯度,作为最终不纯度的阈值
impurity = np.var(self.data[:, -1])
if impurity < min_impurity:
return None, None
# 初始化最优划分信息
best_feature, best_threshold, best_impurity = None, None, float('inf')
# 对每个特征进行遍历,找出最优划分点
for col in range(n - 1):
for row in range(m):
left = self.data[self.data[:, col] < self.data[row, col]]
right = self.data[self.data[:, col] >= self.data[row, col]]
if len(left) < min_sample_split or len(right) < min_sample_split:
continue
impurity = np.var(left[:, -1]) + np.var(right[:, -1])
if impurity < best_impurity:
best_feature = col
best_threshold = self.data[row, col]
best_impurity = impurity
# 完成最优划分信息的更新
if best_feature is not None and best_impurity < np.var(self.data[:, -1]):
self.feature = best_feature
self.threshold = best_threshold
return best_feature, best_threshold
else:
return None, None
3. 定义回归树类
完成树结点类的定义之后,我们需要定义一个回归树类来统一管理整个树的构建过程。回归树类需要实现以下方法:
__init__(self, min_sample_split, min_impurity)
:构造器方法,用于初始化树类,输入参数为最小分裂样本数与最小不纯度阈值。build_tree(self, node)
:递归构建子树的方法,输入参数为当前节点,输出参数为构建好的树。predict(self, X)
:预测数据的方法,输入参数为待预测数据集,输出参数为预测结果。
注意,在 build_tree 函数中,如果当前节点数据集为空,将会停止递归。同时,在 constructor 中我们需要定义一个节点列表,用于保存回归树的所有叶子节点。回归树的代码实现如下所示:
class RegressionTree:
def __init__(self, min_sample_split=2, min_impurity=1e-7):
self.min_sample_split = min_sample_split
self.min_impurity = min_impurity
self.root = None # 回归树的根节点
self.leaves = [] # 叶子节点列表
def build_tree(self, node):
"""
构建子树
"""
feature, thresh = node.choose_best_split(self.min_sample_split, self.min_impurity)
# 如果当前结点是叶子结点,将该结点添加到叶子结点列表中
if feature is None:
self.leaves.append(node)
return
left_indices = node.data[:, feature] < thresh
left_node = Node(node.data[left_indices, :])
node.left = left_node
right_indices = node.data[:, feature] >= thresh
right_node = Node(node.data[right_indices, :])
node.right = right_node
self.build_tree(left_node)
self.build_tree(right_node)
def predict(self, X):
"""
使用训练好的模型对新数据进行预测
"""
results = []
for data in X:
node = self.root
while node.left:
if data[node.feature] < node.threshold:
node = node.left
else:
node = node.right
results.append(node.value)
return results
4. 定义模型训练函数
树的构建过程已经完成,接下来我们需要定义一个模型训练函数,用于对输入数据进行训练,并输出构建好的回归树模型。模型训练函数的代码如下所示:
def train(X, y, min_sample_split=2, min_impurity=1e-7):
# 初始化回归树
regression_tree = RegressionTree(min_sample_split=min_sample_split, min_impurity=min_impurity)
# 构建根节点
root_node = Node(np.column_stack((X, y)))
regression_tree.root = root_node
# 构建树
regression_tree.build_tree(root_node)
return regression_tree
5. 测试模型
模型训练函数已经定义完成,我们可以通过以下代码来测试构建出的回归树模型的性能:
regression_tree = train(X, y)
y_pred = regression_tree.predict(X)
print(y_pred) # 打印预测结果
6. 完成一个更复杂的示例
上述过程已经很好的讲解了手写回归树的基本知识点,下面我们来完成一个更复杂的示例,用于进一步学习回归树的应用。由于现实中的数据集往往比较复杂,因此我们需要使用一个实际数据集。
我们选用 sklearn 内置的波士顿房价数据集(Boston Housing Dataset)来作为我们的数据集,波士顿房价数据集包含了 13 个不同的特征,如 CRIM(人均犯罪率)、ZN(住宅用地占比)等等,目标值为该地区的房屋价格中位数。
使用以下代码导入数据集:
from sklearn.datasets import load_boston
boston = load_boston()
X = boston.data
y = boston.target
接下来,我们可以使用前述的模型训练函数 train
来训练回归树模型,并对其性能进行评估:
regression_tree = train(X, y)
y_pred = regression_tree.predict(X)
print(y_pred) # 打印预测结果
最后,我们可以使用以下代码显示构建出来的回归树:
from sklearn.tree import export_graphviz
import graphviz
dot_data = export_graphviz(regression_tree.root, out_file=None, filled=True, rounded=True, special_characters=True)
graph = graphviz.Source(dot_data)
graph.render('regression_tree')
7. 小结
通过本文的介绍,我们了解了回归树的基本概念和基本实现方法,同时也学会了 Python 语言中的回归树手写实现方法,并使用了两个例子进行了说明。需要指出的是,本文中的实现并不是最优的,仅供学习和参考。如果需要进行真正的回归分析,应该使用更为专业和更加普遍应用的回归算法库。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python手写回归树的实现 - Python技术站