教你用c++从头开始实现决策树

教你用c++从头开始实现决策树

决策树介绍

决策树是一种树形结构,它可以用于分类和回归分析。在这个树结构中,叶子节点表示分类或回归结果,而其他结点表示基于属性值对数据集进行分割的条件。决策树可以理解为一个类似流程图的结构,在这个流程图中,每次判断输入数据的属性值,然后根据属性值分支到不同的子结点,直到达到某个叶子结点得到分类或回归结果。

构建决策树的过程

决策树的构建过程如下:

  1. 选择最好的数据集划分方式(生成算法);
  2. 创建一个树结点;
  3. 将数据集按照选定的划分方式划分为子数据集;
  4. 对子数据集按照1到3的过程递归地创建结点;
  5. 停止条件:递归建树,直到所有数据被分类或者没有属性可用于分类为止。

实现过程

步骤1:定义数据格式

我们定义一个数据结构来表示样本:

class Sample {
private:
    vector<double> features;
    double label;
public:
    Sample(vector<double> _features, double _label) {
        features = _features;
        label = _label;
    }
    vector<double> getFeatures() {
        return features;
    }
    double getLabel() {
        return label;
    }
};

步骤2:定义计算信息熵的函数

首先我们需要实现计算信息熵的函数,信息熵越小,样本的纯度越高。因此,选择最小熵的划分方式就等价于选择最优化分方式。

double entropy(vector<Sample> samples) {
    map<double, int> labelCount;
    for (auto sample : samples) {
        if (labelCount.find(sample.getLabel()) == labelCount.end()) {
            labelCount[sample.getLabel()] = 0;
        }
        labelCount[sample.getLabel()] += 1;
    }
    double ent = 0.0;
    for (auto item : labelCount) {
        double prob = item.second / (double)samples.size();
        ent -= prob * log2(prob);
    }
    return ent;
}

步骤3:定义寻找最优属性的函数

我们需要寻找最优属性,对数据集进行划分。通过计算每个属性的信息增益,可以得出最优属性。

void splitSamples(vector<Sample> samples, int featIndex, double threshold,
                  vector<Sample>& leftSamples, vector<Sample>& rightSamples) {
    for (auto sample : samples) {
        if (sample.getFeatures()[featIndex] <= threshold) {
            leftSamples.push_back(sample);
        } else {
            rightSamples.push_back(sample);
        }
    }
}

double gain(vector<Sample> samples, int featIndex, double threshold) {
    vector<Sample> leftSamples, rightSamples;
    splitSamples(samples, featIndex, threshold, leftSamples, rightSamples);
    double leftRatio = leftSamples.size() / (double)samples.size();
    double rightRatio = rightSamples.size() / (double)samples.size();

    double gain = entropy(samples) - leftRatio * entropy(leftSamples) - rightRatio * entropy(rightSamples);

    return gain;
}

struct Split {
    int featIndex;
    double threshold;
    double gain;
};

Split findSplit(vector<Sample> samples, double minGain) {
    Split bestSplit;
    bestSplit.gain = -1.0;
    for (int i = 0; i < samples[0].getFeatures().size(); i++) {
        double maxFeat = -1e9, minFeat = 1e9;
        for (auto sample : samples) {
            double feat = sample.getFeatures()[i];
            if (feat > maxFeat) maxFeat = feat;
            if (feat < minFeat) minFeat = feat;
        }
        double step = (maxFeat - minFeat) / 10.0;
        for (double threshold = minFeat; threshold <= maxFeat + step; threshold += step) {
            double currentGain = gain(samples, i, threshold);
            if (currentGain > bestSplit.gain && currentGain > minGain) {
                bestSplit.gain = currentGain;
                bestSplit.featIndex = i;
                bestSplit.threshold = threshold;
            }
        }
    }
    return bestSplit;
}

步骤4:定义建树的函数

现在我们来实现建树的函数,采用递归的方法,对样本进行分割。每个节点包含三个数据成员:样本、阀值(切分标准)、左右子节点。

class Node {
private:
    vector<Sample> samples;
    double threshold;
    Node* leftChild;
    Node* rightChild;
public:
    friend class DecisionTree;
    Node(vector<Sample> _samples, double _threshold) {
        samples = _samples;
        threshold = _threshold;
        leftChild = nullptr;
        rightChild = nullptr;
    }
};

class DecisionTree {
private:
    double minGain;
    int maxHeight;
public:
    DecisionTree(double _minGain, int _maxHeight) {
        minGain = _minGain;
        maxHeight = _maxHeight;
    }

    Node* buildTree(vector<Sample> samples) {
        return buildTreeHelper(samples, 0);
    }

    Node* buildTreeHelper(vector<Sample> samples, int height) {
        if (samples.empty() || (maxHeight > 0 && height > maxHeight)) {
            return nullptr;
        }
        double majorityLabel;
        if (entropy(samples) == 0) {
            majorityLabel = samples[0].getLabel();
            return new Node(samples, majorityLabel);
        }
        Split bestSplit = findSplit(samples, minGain);
        if (bestSplit.gain == -1.0) {
            return nullptr;
        }
        vector<Sample> leftSamples, rightSamples;
        splitSamples(samples, bestSplit.featIndex, bestSplit.threshold, leftSamples, rightSamples);

        Node* node = new Node(samples, bestSplit.threshold);
        node->leftChild = buildTreeHelper(leftSamples, height + 1);
        node->rightChild = buildTreeHelper(rightSamples, height + 1);
        return node;
    }
};

步骤5:用示例检验实现

首先,我们准备一些人造数据集,然后进行训练,最后把测试集的样本输入到训练的决策树,得到分类结果。

void trainTestDecisionTree() {
    vector<Sample> samples = {{vector<double>({0, 0}), 0}, {vector<double>({1, 0}), 1},
                              {vector<double>({0, 1}), 1}, {vector<double>({1, 1}), 1}};
    DecisionTree tree(0.0, 0);
    Node* rootNode = tree.buildTree(samples);
    cout << "Root threshold: " << rootNode->threshold << endl;
    cout << "Left samples: " << endl;
    for (auto sample : rootNode->leftChild->samples) {
        cout << "(" << sample.getFeatures()[0] << ", " << sample.getFeatures()[1] << "): " << sample.getLabel() << endl;
    }
    cout << "Right samples: " << endl;
    for (auto sample : rootNode->rightChild->samples) {
        cout << "(" << sample.getFeatures()[0] << ", " << sample.getFeatures()[1] << "): " << sample.getLabel() << endl;
    }

    Sample testSample(vector<double>({0, 0}), 0);
    Node* node = rootNode;
    while (node) {
        if (node->leftChild == nullptr && node->rightChild == nullptr) {
            cout << "test sample label: " << node->samples[0].getLabel() << endl;
            break;
        }
        double featValue = testSample.getFeatures()[node->samples[0].getFeatures()[node->threshold]];
        if (featValue <= node->threshold) {
            node = node->leftChild;
        } else {
            node = node->rightChild;
        }
    }
}

另一个实例:使用Iris数据集进行分类

为进一步检验决策树的效果,我们使用一组真实的数据集进行分类。我们选择Iris数据集,这个数据集包含150个样本,其中每个样本包含4个特征和1个标签。特征分别为花萼长度(sepal length)、花萼宽度(sepal width)、花瓣长度(petal length)、花瓣宽度(petal width),标签有三种,分别是:山鸢尾(Iris Setosa)、变色鸢尾(Iris Versicolour)、维吉尼亚鸢尾(Iris Virginica)。我们将数据集的前100个样本作为训练集,后50个样本作为测试集,使用决策树对其进行分类。

void irisClassification() {
    vector<Sample> samples;
    ifstream fin("../data/iris.data");
    while (!fin.eof()) {
        vector<double> features(4, 0.0);
        double label = -1;
        fin >> features[0] >> features[1] >> features[2] >> features[3];
        string labelStr;
        fin >> labelStr;
        if (labelStr == "Iris-setosa") {
            label = 0;
        } else if (labelStr == "Iris-versicolor") {
            label = 1;
        } else if (labelStr == "Iris-virginica") {
            label = 2;
        }
        if (label >= 0) {
            samples.push_back({features, label});
        }
    }
    fin.close();
    random_shuffle(samples.begin(), samples.end());

    // split train and test samples
    int trainCnt = 100;
    vector<Sample> trainSamples(samples.begin(), samples.begin() + trainCnt);
    vector<Sample> testSamples(samples.begin() + trainCnt, samples.end());

    DecisionTree tree(0.0, 5);
    Node* rootNode = tree.buildTree(trainSamples);

    int correctCnt = 0;
    for (auto sample : testSamples) {
        Node* node = rootNode;
        while (node) {
            if (node->leftChild == nullptr && node->rightChild == nullptr) {
                if (node->samples[0].getLabel() == sample.getLabel()) {
                    correctCnt += 1;
                }
                break;
            }
            double featValue = sample.getFeatures()[node->samples[0].getFeatures()[node->threshold]];
            if (featValue <= node->threshold) {
                node = node->leftChild;
            } else {
                node = node->rightChild;
            }
        }
    }
    cout << "accuracy: " << correctCnt / (double)testSamples.size() << endl;
}

这里我们使用trainSamples对决策树进行训练,使用testSamples进行测试,最后输出分类准确率。

总结

上述介绍了使用C++从头实现决策树的完整过程,包括数据结构、信息熵计算、寻找最优属性、建树等具体步骤,在样本集较小时能够正常分类,但样本集过大时易趋于复杂,进而泛化能力降低。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:教你用c++从头开始实现决策树 - Python技术站

(0)
上一篇 2023年5月22日
下一篇 2023年5月22日

相关文章

  • C语言中如何进行动态内存分配?

    C语言中的动态内存分配功能是通过函数库和提供的。动态内存分配指的是程序在运行过程中,根据需要在堆区或自由存储区中动态地为变量分配所需的内存空间,使得程序可以根据需要动态地使用内存,从而更加灵活和高效地使用计算机的资源。 在C语言中,动态内存分配的过程可以分为以下三个步骤: 申请内存空间:使用malloc()函数在堆区分配一块适当大小的内存空间。malloc(…

    C 2023年4月27日
    00
  • C语言详解实现猜数字游戏步骤

    C语言详解实现猜数字游戏步骤 在这个攻略中,我们将使用C语言来实现猜数字游戏。首先,让我们讲一下游戏的规则: 游戏开始时,系统会随机生成一个数字在1-100之间。玩家需要猜出这个数字是多少。如果玩家猜错了,系统会提示玩家数字是高还是低。玩家需要不断猜测直到猜对为止。 下面是实现猜数字游戏的完整步骤: 1. 生成随机数 首先,我们需要生成1-100之间的随机数…

    C 2023年5月22日
    00
  • C语言队列和应用详情

    C 语言队列和应用详情 什么是队列 队列是一种数据结构,可以用来存储一组按顺序排列的元素。队列的特点就是先进先出,即First In First Out,缩写为 FIFO。也就是说,最先插入队列的元素会最先被取出,最后插入队列的元素则会最后被取出。常见的生活中队列应用包括的排队取号,排队坐火车,排队打饭等等。 C 语言实现队列 在 C 语言中,我们可以通过数…

    C 2023年5月23日
    00
  • C++中拷贝构造函数的应用详解

    C++中拷贝构造函数的应用详解 什么是拷贝构造函数 在 C++ 中拷贝构造函数是一种特殊的构造函数,其用途是从一个已经存在的对象复制数据到一个新创建的对象中。拷贝构造函数以引用的方式传递源对象并创建新的对象之后,将源对象的值复制到新对象中。拷贝构造函数的形式为 ClassName (const ClassName &obj),其中 obj 是要复制的…

    C 2023年5月22日
    00
  • C/C++百行代码实现热门游戏消消乐功能的示例代码

    C/C++百行代码实现热门游戏消消乐功能的示例代码攻略 简介 消消乐是一款非常流行的益智游戏,其核心游戏玩法是三消规则,在有限的步数内将相同颜色(或形状)的方块消除。本文将通过C/C++语言编写少于100行代码来实现消消乐游戏功能。 实现步骤 第一步:定义方块 我们需要定义游戏中的方块,方块应该包含颜色、形状以及消除状态等属性。具体实现如下: struct …

    C 2023年5月24日
    00
  • C++实现简单酒店管理系统

    C++实现简单酒店管理系统攻略 简介 C++实现简单酒店管理系统是一个典型的控制台应用程序,用于对酒店客房进行预定、入住、退房、查询、统计等操作。 设计 整个酒店管理系统可以分为以下几个部分: 客房类型 客房类型编号 客房类型名称 客房单价 客房信息 客房编号 客房类型 客房状态(已预订、已入住、空闲) 入住人姓名 入住人电话 入住日期 离店日期 订单信息 …

    C 2023年5月23日
    00
  • 解决JSON.parse转化不规范json字符串的问题

    当JSON.parse遇到不规范的JSON字符串时,它将会抛出JSON.parse错误,导致代码无法继续执行。这时可以采用一些技巧和工具来解决这个问题。 1.使用try-catch语句 在JSON.parse方法周围包裹try-catch语句是解决这个问题的一种常见方式。这样如果JSON.parse方法抛出异常,我们就可以在catch语句中捕获这个异常,然后…

    C 2023年5月23日
    00
  • Go语言的数据结构转JSON

    首先,在Go语言中将数据结构转换为JSON格式,需要使用标准库中的encoding/json包。下面是将数据结构转换为JSON的完整攻略: 步骤一:定义你的数据结构 首先,你需要定义一个数据结构,该数据结构将被转换成JSON格式。在这里,我们假设有一个Student结构体,该结构体包含了学生的姓名和年龄信息。 type Student struct { Na…

    C 2023年5月23日
    00
合作推广
合作推广
分享本页
返回顶部