教你用c++从头开始实现决策树
决策树介绍
决策树是一种树形结构,它可以用于分类和回归分析。在这个树结构中,叶子节点表示分类或回归结果,而其他结点表示基于属性值对数据集进行分割的条件。决策树可以理解为一个类似流程图的结构,在这个流程图中,每次判断输入数据的属性值,然后根据属性值分支到不同的子结点,直到达到某个叶子结点得到分类或回归结果。
构建决策树的过程
决策树的构建过程如下:
- 选择最好的数据集划分方式(生成算法);
- 创建一个树结点;
- 将数据集按照选定的划分方式划分为子数据集;
- 对子数据集按照1到3的过程递归地创建结点;
- 停止条件:递归建树,直到所有数据被分类或者没有属性可用于分类为止。
实现过程
步骤1:定义数据格式
我们定义一个数据结构来表示样本:
// A single training/test example: one feature vector plus a numeric label
// (class id for classification, target value for regression).
class Sample {
private:
    vector<double> features;  // one value per attribute
    double label;             // class label / target
public:
    // Takes the feature vector by const reference to avoid an extra copy
    // (the original by-value parameter copied the vector twice).
    Sample(const vector<double>& _features, double _label) {
        features = _features;
        label = _label;
    }
    // Const accessor returning a reference: callers such as
    // `sample.getFeatures()[i]` no longer copy the whole vector per call.
    const vector<double>& getFeatures() const {
        return features;
    }
    double getLabel() const {
        return label;
    }
};
步骤2:定义计算信息熵的函数
首先我们需要实现计算信息熵的函数,信息熵越小,样本的纯度越高。因此,选择最小熵的划分方式就等价于选择最优划分方式。
double entropy(vector<Sample> samples) {
map<double, int> labelCount;
for (auto sample : samples) {
if (labelCount.find(sample.getLabel()) == labelCount.end()) {
labelCount[sample.getLabel()] = 0;
}
labelCount[sample.getLabel()] += 1;
}
double ent = 0.0;
for (auto item : labelCount) {
double prob = item.second / (double)samples.size();
ent -= prob * log2(prob);
}
return ent;
}
步骤3:定义寻找最优属性的函数
我们需要寻找最优属性,对数据集进行划分。通过计算每个属性的信息增益,可以得出最优属性。
// Partitions `samples` on feature `featIndex`: samples with feature value
// <= threshold are appended to leftSamples, the rest to rightSamples.
// Pass the input by const reference (the original copied it on every call).
void splitSamples(const vector<Sample>& samples, int featIndex, double threshold,
                  vector<Sample>& leftSamples, vector<Sample>& rightSamples) {
    for (auto sample : samples) {
        if (sample.getFeatures()[featIndex] <= threshold) {
            leftSamples.push_back(sample);
        } else {
            rightSamples.push_back(sample);
        }
    }
}
double gain(vector<Sample> samples, int featIndex, double threshold) {
vector<Sample> leftSamples, rightSamples;
splitSamples(samples, featIndex, threshold, leftSamples, rightSamples);
double leftRatio = leftSamples.size() / (double)samples.size();
double rightRatio = rightSamples.size() / (double)samples.size();
double gain = entropy(samples) - leftRatio * entropy(leftSamples) - rightRatio * entropy(rightSamples);
return gain;
}
// Result of a split search: which feature to split on, at which threshold,
// and the information gain achieved. The members were previously left
// uninitialized, so a "no split found" result carried garbage featIndex
// and threshold; default initializers make that state well-defined.
struct Split {
    int featIndex = -1;    // sentinel: no feature chosen yet
    double threshold = 0.0;
    double gain = -1.0;    // findSplit() uses -1.0 to mean "no valid split"
};
// Searches for the best (feature, threshold) split over all features,
// trying 11 evenly spaced thresholds between each feature's min and max.
// Returns a Split with gain == -1.0 when no split exceeds minGain.
// Fixes vs. original: explicit initialization of all result fields, an
// empty-input guard, pass-by-const-reference, and — critically — skipping
// constant features: previously maxFeat == minFeat made step == 0, so
// `threshold += step` never advanced and the loop spun forever.
Split findSplit(const vector<Sample>& samples, double minGain) {
    Split bestSplit;
    bestSplit.featIndex = -1;
    bestSplit.threshold = 0.0;
    bestSplit.gain = -1.0;
    if (samples.empty()) {
        return bestSplit;
    }
    int featCount = (int)samples[0].getFeatures().size();
    for (int i = 0; i < featCount; i++) {
        double maxFeat = -1e9, minFeat = 1e9;
        for (auto sample : samples) {
            double feat = sample.getFeatures()[i];
            if (feat > maxFeat) maxFeat = feat;
            if (feat < minFeat) minFeat = feat;
        }
        double step = (maxFeat - minFeat) / 10.0;
        // A constant feature cannot separate the samples; skipping it also
        // prevents the step == 0 infinite loop.
        if (step <= 0.0) {
            continue;
        }
        for (double threshold = minFeat; threshold <= maxFeat + step; threshold += step) {
            double currentGain = gain(samples, i, threshold);
            if (currentGain > bestSplit.gain && currentGain > minGain) {
                bestSplit.gain = currentGain;
                bestSplit.featIndex = i;
                bestSplit.threshold = threshold;
            }
        }
    }
    return bestSplit;
}
步骤4:定义建树的函数
现在我们来实现建树的函数,采用递归的方法,对样本进行分割。每个节点包含四个数据成员:样本、阈值(切分标准)、左子节点和右子节点。
// A decision-tree node. Internal nodes hold the split threshold; leaf nodes
// (both children null) reuse the `threshold` field to store the predicted
// label — see DecisionTree::buildTreeHelper, which constructs pure leaves as
// Node(samples, majorityLabel).
// Members are public: the original made them private with only DecisionTree
// as a friend, yet the free test functions read node->threshold /
// node->samples / node->leftChild directly, which could not compile.
class Node {
public:
    vector<Sample> samples;   // samples that reached this node
    double threshold;         // split threshold, or the label for a leaf
    Node* leftChild;          // subtree with feature value <= threshold
    Node* rightChild;         // subtree with feature value >  threshold
    Node(vector<Sample> _samples, double _threshold)
        : samples(_samples),
          threshold(_threshold),
          leftChild(nullptr),
          rightChild(nullptr) {}
};
class DecisionTree {
private:
double minGain;
int maxHeight;
public:
DecisionTree(double _minGain, int _maxHeight) {
minGain = _minGain;
maxHeight = _maxHeight;
}
Node* buildTree(vector<Sample> samples) {
return buildTreeHelper(samples, 0);
}
Node* buildTreeHelper(vector<Sample> samples, int height) {
if (samples.empty() || (maxHeight > 0 && height > maxHeight)) {
return nullptr;
}
double majorityLabel;
if (entropy(samples) == 0) {
majorityLabel = samples[0].getLabel();
return new Node(samples, majorityLabel);
}
Split bestSplit = findSplit(samples, minGain);
if (bestSplit.gain == -1.0) {
return nullptr;
}
vector<Sample> leftSamples, rightSamples;
splitSamples(samples, bestSplit.featIndex, bestSplit.threshold, leftSamples, rightSamples);
Node* node = new Node(samples, bestSplit.threshold);
node->leftChild = buildTreeHelper(leftSamples, height + 1);
node->rightChild = buildTreeHelper(rightSamples, height + 1);
return node;
}
};
步骤5:用示例检验实现
首先,我们准备一些人造数据集,然后进行训练,最后把测试集的样本输入到训练的决策树,得到分类结果。
// Smoke test on a tiny hand-made OR-like data set: trains a tree, prints the
// root split and the sample partition, then classifies one test sample.
// NOTE(review): Node's members are declared private with only DecisionTree as
// a friend, so the direct accesses to rootNode->threshold / samples /
// children below should not compile as the file stands — verify Node's
// access specifiers.
void trainTestDecisionTree() {
vector<Sample> samples = {{vector<double>({0, 0}), 0}, {vector<double>({1, 0}), 1},
{vector<double>({0, 1}), 1}, {vector<double>({1, 1}), 1}};
// minGain = 0.0, maxHeight = 0 (0 is treated as "no depth limit" upstream).
DecisionTree tree(0.0, 0);
Node* rootNode = tree.buildTree(samples);
cout << "Root threshold: " << rootNode->threshold << endl;
cout << "Left samples: " << endl;
for (auto sample : rootNode->leftChild->samples) {
cout << "(" << sample.getFeatures()[0] << ", " << sample.getFeatures()[1] << "): " << sample.getLabel() << endl;
}
cout << "Right samples: " << endl;
for (auto sample : rootNode->rightChild->samples) {
cout << "(" << sample.getFeatures()[0] << ", " << sample.getFeatures()[1] << "): " << sample.getLabel() << endl;
}
Sample testSample(vector<double>({0, 0}), 0);
Node* node = rootNode;
while (node) {
// A leaf is identified by having no children; its first sample's label is
// used as the prediction.
if (node->leftChild == nullptr && node->rightChild == nullptr) {
cout << "test sample label: " << node->samples[0].getLabel() << endl;
break;
}
// NOTE(review): this indexing looks wrong — `node->threshold` is a double
// being used as a subscript, and Node never records which feature the
// split was made on, so the correct feature index is not recoverable here.
// Fixing it requires storing the split's featIndex in Node.
double featValue = testSample.getFeatures()[node->samples[0].getFeatures()[node->threshold]];
if (featValue <= node->threshold) {
node = node->leftChild;
} else {
node = node->rightChild;
}
}
}
另一个实例:使用Iris数据集进行分类
为进一步检验决策树的效果,我们使用一组真实的数据集进行分类。我们选择Iris数据集,这个数据集包含150个样本,其中每个样本包含4个特征和1个标签。特征分别为花萼长度(sepal length)、花萼宽度(sepal width)、花瓣长度(petal length)、花瓣宽度(petal width),标签有三种,分别是:山鸢尾(Iris Setosa)、变色鸢尾(Iris Versicolour)、维吉尼亚鸢尾(Iris Virginica)。我们将数据集的前100个样本作为训练集,后50个样本作为测试集,使用决策树对其进行分类。
// Trains and evaluates the tree on the UCI Iris data set: reads
// ../data/iris.data, shuffles, takes 100 samples for training and the rest
// for testing, then prints the test accuracy.
void irisClassification() {
vector<Sample> samples;
ifstream fin("../data/iris.data");
// NOTE(review): `while (!fin.eof())` runs one extra iteration after the last
// record (guarded only by the label check below); testing the stream state
// after the reads is the usual fix. Also, the canonical iris.data is
// comma-separated, so `fin >> double` would stop at the first comma —
// confirm the local copy is whitespace-separated.
while (!fin.eof()) {
vector<double> features(4, 0.0);
double label = -1;
fin >> features[0] >> features[1] >> features[2] >> features[3];
string labelStr;
fin >> labelStr;
if (labelStr == "Iris-setosa") {
label = 0;
} else if (labelStr == "Iris-versicolor") {
label = 1;
} else if (labelStr == "Iris-virginica") {
label = 2;
}
// label stays -1 on a failed/partial read, so junk rows are skipped.
if (label >= 0) {
samples.push_back({features, label});
}
}
fin.close();
// NOTE(review): std::random_shuffle is deprecated in C++14 and removed in
// C++17; std::shuffle with an explicit engine is the replacement.
random_shuffle(samples.begin(), samples.end());
// split train and test samples
int trainCnt = 100;
vector<Sample> trainSamples(samples.begin(), samples.begin() + trainCnt);
vector<Sample> testSamples(samples.begin() + trainCnt, samples.end());
DecisionTree tree(0.0, 5);
Node* rootNode = tree.buildTree(trainSamples);
int correctCnt = 0;
for (auto sample : testSamples) {
Node* node = rootNode;
while (node) {
// Leaf (no children): predict with its first sample's label.
if (node->leftChild == nullptr && node->rightChild == nullptr) {
if (node->samples[0].getLabel() == sample.getLabel()) {
correctCnt += 1;
}
break;
}
// NOTE(review): same defect as in trainTestDecisionTree — `node->threshold`
// (a double) is used as a subscript and Node does not store the split's
// feature index, so this traversal cannot select the correct feature.
double featValue = sample.getFeatures()[node->samples[0].getFeatures()[node->threshold]];
if (featValue <= node->threshold) {
node = node->leftChild;
} else {
node = node->rightChild;
}
}
}
cout << "accuracy: " << correctCnt / (double)testSamples.size() << endl;
}
这里我们使用trainSamples
对决策树进行训练,使用testSamples
进行测试,最后输出分类准确率。
总结
上述介绍了使用C++从头实现决策树的完整过程,包括数据结构、信息熵计算、寻找最优属性、建树等具体步骤,在样本集较小时能够正常分类,但样本集过大时易趋于复杂,进而泛化能力降低。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:教你用c++从头开始实现决策树 - Python技术站