教你用c++从头开始实现决策树

教你用c++从头开始实现决策树

决策树介绍

决策树是一种树形结构,它可以用于分类和回归分析。在这个树结构中,叶子节点表示分类或回归结果,而其他结点表示基于属性值对数据集进行分割的条件。决策树可以理解为一个类似流程图的结构,在这个流程图中,每次判断输入数据的属性值,然后根据属性值分支到不同的子结点,直到达到某个叶子结点得到分类或回归结果。

构建决策树的过程

决策树的构建过程如下:

  1. 选择最好的数据集划分方式(生成算法);
  2. 创建一个树结点;
  3. 将数据集按照选定的划分方式划分为子数据集;
  4. 对子数据集按照1到3的过程递归地创建结点;
  5. 停止条件:递归建树,直到所有数据被分类或者没有属性可用于分类为止。

实现过程

步骤1:定义数据格式

我们定义一个数据结构来表示样本:

class Sample {
private:
    vector<double> features;
    double label;
public:
    Sample(vector<double> _features, double _label) {
        features = _features;
        label = _label;
    }
    vector<double> getFeatures() {
        return features;
    }
    double getLabel() {
        return label;
    }
};

步骤2:定义计算信息熵的函数

首先我们需要实现计算信息熵的函数,信息熵越小,样本的纯度越高。因此,选择最小熵的划分方式就等价于选择最优化分方式。

double entropy(vector<Sample> samples) {
    map<double, int> labelCount;
    for (auto sample : samples) {
        if (labelCount.find(sample.getLabel()) == labelCount.end()) {
            labelCount[sample.getLabel()] = 0;
        }
        labelCount[sample.getLabel()] += 1;
    }
    double ent = 0.0;
    for (auto item : labelCount) {
        double prob = item.second / (double)samples.size();
        ent -= prob * log2(prob);
    }
    return ent;
}

步骤3:定义寻找最优属性的函数

我们需要寻找最优属性,对数据集进行划分。通过计算每个属性的信息增益,可以得出最优属性。

void splitSamples(vector<Sample> samples, int featIndex, double threshold,
                  vector<Sample>& leftSamples, vector<Sample>& rightSamples) {
    for (auto sample : samples) {
        if (sample.getFeatures()[featIndex] <= threshold) {
            leftSamples.push_back(sample);
        } else {
            rightSamples.push_back(sample);
        }
    }
}

double gain(vector<Sample> samples, int featIndex, double threshold) {
    vector<Sample> leftSamples, rightSamples;
    splitSamples(samples, featIndex, threshold, leftSamples, rightSamples);
    double leftRatio = leftSamples.size() / (double)samples.size();
    double rightRatio = rightSamples.size() / (double)samples.size();

    double gain = entropy(samples) - leftRatio * entropy(leftSamples) - rightRatio * entropy(rightSamples);

    return gain;
}

struct Split {
    int featIndex;
    double threshold;
    double gain;
};

Split findSplit(vector<Sample> samples, double minGain) {
    Split bestSplit;
    bestSplit.gain = -1.0;
    for (int i = 0; i < samples[0].getFeatures().size(); i++) {
        double maxFeat = -1e9, minFeat = 1e9;
        for (auto sample : samples) {
            double feat = sample.getFeatures()[i];
            if (feat > maxFeat) maxFeat = feat;
            if (feat < minFeat) minFeat = feat;
        }
        double step = (maxFeat - minFeat) / 10.0;
        for (double threshold = minFeat; threshold <= maxFeat + step; threshold += step) {
            double currentGain = gain(samples, i, threshold);
            if (currentGain > bestSplit.gain && currentGain > minGain) {
                bestSplit.gain = currentGain;
                bestSplit.featIndex = i;
                bestSplit.threshold = threshold;
            }
        }
    }
    return bestSplit;
}

步骤4:定义建树的函数

现在我们来实现建树的函数,采用递归的方法,对样本进行分割。每个节点包含三个数据成员:样本、阀值(切分标准)、左右子节点。

class Node {
private:
    vector<Sample> samples;
    double threshold;
    Node* leftChild;
    Node* rightChild;
public:
    friend class DecisionTree;
    Node(vector<Sample> _samples, double _threshold) {
        samples = _samples;
        threshold = _threshold;
        leftChild = nullptr;
        rightChild = nullptr;
    }
};

class DecisionTree {
private:
    double minGain;
    int maxHeight;
public:
    DecisionTree(double _minGain, int _maxHeight) {
        minGain = _minGain;
        maxHeight = _maxHeight;
    }

    Node* buildTree(vector<Sample> samples) {
        return buildTreeHelper(samples, 0);
    }

    Node* buildTreeHelper(vector<Sample> samples, int height) {
        if (samples.empty() || (maxHeight > 0 && height > maxHeight)) {
            return nullptr;
        }
        double majorityLabel;
        if (entropy(samples) == 0) {
            majorityLabel = samples[0].getLabel();
            return new Node(samples, majorityLabel);
        }
        Split bestSplit = findSplit(samples, minGain);
        if (bestSplit.gain == -1.0) {
            return nullptr;
        }
        vector<Sample> leftSamples, rightSamples;
        splitSamples(samples, bestSplit.featIndex, bestSplit.threshold, leftSamples, rightSamples);

        Node* node = new Node(samples, bestSplit.threshold);
        node->leftChild = buildTreeHelper(leftSamples, height + 1);
        node->rightChild = buildTreeHelper(rightSamples, height + 1);
        return node;
    }
};

步骤5:用示例检验实现

首先,我们准备一些人造数据集,然后进行训练,最后把测试集的样本输入到训练的决策树,得到分类结果。

void trainTestDecisionTree() {
    vector<Sample> samples = {{vector<double>({0, 0}), 0}, {vector<double>({1, 0}), 1},
                              {vector<double>({0, 1}), 1}, {vector<double>({1, 1}), 1}};
    DecisionTree tree(0.0, 0);
    Node* rootNode = tree.buildTree(samples);
    cout << "Root threshold: " << rootNode->threshold << endl;
    cout << "Left samples: " << endl;
    for (auto sample : rootNode->leftChild->samples) {
        cout << "(" << sample.getFeatures()[0] << ", " << sample.getFeatures()[1] << "): " << sample.getLabel() << endl;
    }
    cout << "Right samples: " << endl;
    for (auto sample : rootNode->rightChild->samples) {
        cout << "(" << sample.getFeatures()[0] << ", " << sample.getFeatures()[1] << "): " << sample.getLabel() << endl;
    }

    Sample testSample(vector<double>({0, 0}), 0);
    Node* node = rootNode;
    while (node) {
        if (node->leftChild == nullptr && node->rightChild == nullptr) {
            cout << "test sample label: " << node->samples[0].getLabel() << endl;
            break;
        }
        double featValue = testSample.getFeatures()[node->samples[0].getFeatures()[node->threshold]];
        if (featValue <= node->threshold) {
            node = node->leftChild;
        } else {
            node = node->rightChild;
        }
    }
}

另一个实例:使用Iris数据集进行分类

为进一步检验决策树的效果,我们使用一组真实的数据集进行分类。我们选择Iris数据集,这个数据集包含150个样本,其中每个样本包含4个特征和1个标签。特征分别为花萼长度(sepal length)、花萼宽度(sepal width)、花瓣长度(petal length)、花瓣宽度(petal width),标签有三种,分别是:山鸢尾(Iris Setosa)、变色鸢尾(Iris Versicolour)、维吉尼亚鸢尾(Iris Virginica)。我们将数据集的前100个样本作为训练集,后50个样本作为测试集,使用决策树对其进行分类。

void irisClassification() {
    vector<Sample> samples;
    ifstream fin("../data/iris.data");
    while (!fin.eof()) {
        vector<double> features(4, 0.0);
        double label = -1;
        fin >> features[0] >> features[1] >> features[2] >> features[3];
        string labelStr;
        fin >> labelStr;
        if (labelStr == "Iris-setosa") {
            label = 0;
        } else if (labelStr == "Iris-versicolor") {
            label = 1;
        } else if (labelStr == "Iris-virginica") {
            label = 2;
        }
        if (label >= 0) {
            samples.push_back({features, label});
        }
    }
    fin.close();
    random_shuffle(samples.begin(), samples.end());

    // split train and test samples
    int trainCnt = 100;
    vector<Sample> trainSamples(samples.begin(), samples.begin() + trainCnt);
    vector<Sample> testSamples(samples.begin() + trainCnt, samples.end());

    DecisionTree tree(0.0, 5);
    Node* rootNode = tree.buildTree(trainSamples);

    int correctCnt = 0;
    for (auto sample : testSamples) {
        Node* node = rootNode;
        while (node) {
            if (node->leftChild == nullptr && node->rightChild == nullptr) {
                if (node->samples[0].getLabel() == sample.getLabel()) {
                    correctCnt += 1;
                }
                break;
            }
            double featValue = sample.getFeatures()[node->samples[0].getFeatures()[node->threshold]];
            if (featValue <= node->threshold) {
                node = node->leftChild;
            } else {
                node = node->rightChild;
            }
        }
    }
    cout << "accuracy: " << correctCnt / (double)testSamples.size() << endl;
}

这里我们使用trainSamples对决策树进行训练,使用testSamples进行测试,最后输出分类准确率。

总结

上述介绍了使用C++从头实现决策树的完整过程,包括数据结构、信息熵计算、寻找最优属性、建树等具体步骤,在样本集较小时能够正常分类,但样本集过大时易趋于复杂,进而泛化能力降低。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:教你用c++从头开始实现决策树 - Python技术站

(0)
上一篇 2023年5月22日
下一篇 2023年5月22日

相关文章

  • C 程序 检查整数是正还是负

    C程序检查整数是正还是负 程序概述 这个程序可以检查一个输入的整数是正还是负数。如果输入的整数大于0,则会输出”Positive”,否则输出”Negative”。 程序代码 #include <stdio.h> int main() { int num; printf("Enter a number: "); scanf(&q…

    C 2023年5月9日
    00
  • C++11智能指针之weak_ptr详解

    C++11智能指针之weak_ptr详解 简介 C++11添加了4种智能指针:unique_ptr、shared_ptr、weak_ptr、auto_ptr。其中weak_ptr是一种弱引用类型的指针,它不对所指对象进行引用计数,可以防止 shared_ptr 的循环引用问题。 特点 weak_ptr 所指向的对象可能已经被删除了,因此在使用 weak_pt…

    C 2023年5月22日
    00
  • 如何修改MYSQL5.7.17数据库存储文件的路径

    以下是具体的攻略,分为以下几个步骤: 1. 关闭MySQL数据库 在开始修改MySQL数据库存储文件的路径之前,需要先关闭MySQL数据库,具体操作可以参照以下命令: sudo /etc/init.d/mysql stop 2. 复制原存储文件内容 在进行路径修改之前,需要先将原来的存储文件内容复制到将要修改的路径下,具体操作可以参照以下命令: sudo c…

    C 2023年5月23日
    00
  • Java利用Optional解决空指针异常

    当我们在编写Java代码时,常常会遇到空指针异常(NullPointerException)的情况,这会给我们的程序带来很大的不稳定性和安全性问题。而Java 8中新增的Optional类可以有效地解决这一问题。本文将详细讲解如何利用Optional解决空指针异常。 Optional的介绍 Optional类是Java 8中新增的一个类,可以用来解决空指针异…

    C 2023年5月22日
    00
  • C typedef

    当我们使用C语言开发时,我们可能会遇到一些复杂的数据类型,为了使代码更加简单易读并方便调用这些数据类型,我们可以使用C语言中的typedef关键字来定义自定义的数据类型别名。本文将详细介绍C语言中typedef的使用方法,包括定义基本类型别名和结构体别名等内容。 定义基本类型别名 我们可以使用typedef定义一些基本类型的别名,例如: typedef un…

    C 2023年5月10日
    00
  • windows警告致命错误C0000034 正在更新操作怎么办?

    Windows 警告致命错误 C0000034 正在更新操作怎么办? 如果你在更新 Windows 操作系统时遇到了警告致命错误 C0000034,不要惊慌,下面提供了一些解决方法。 1. 运行自动修复 Windows 系统提供了一个自动修复工具,可以自动修复并纠正一些常见的 Windows 更新问题。具体操作如下: 按下 Windows 键 + X 组合键…

    C 2023年5月23日
    00
  • C语言数组指针表示法

    C语言数组指针表示法是指使用指针访问数组元素的方法。在使用中,我们可以将数组名作为指针使用,指向数组的第一个元素,通过加减指针的偏移量来访问数组中的其他元素。 基本使用方法 定义数组,声明指针 c int a[5] = {1, 2, 3, 4, 5}; int *p; 将数组名作为指针使用,指向数组的第一个元素 c p = &a[0]; // 或者 …

    C 2023年5月9日
    00
  • C++驱动bash的实现代码

    要实现C++驱动bash,我们需要理解两件事情:首先是调用shell命令,其次是获取shell命令的输出。下面是完整的攻略。 调用shell命令 在C++中调用shell命令的最常用的方法是使用system函数。该函数可以在程序中执行给定的命令,并等待该命令完成。例如,在Linux中,我们可以使用以下代码执行ls命令: #include <stdlib…

    C 2023年5月23日
    00
合作推广
合作推广
分享本页
返回顶部