教你用c++从头开始实现决策树

教你用c++从头开始实现决策树

决策树介绍

决策树是一种树形结构,它可以用于分类和回归分析。在这个树结构中,叶子节点表示分类或回归结果,而其他结点表示基于属性值对数据集进行分割的条件。决策树可以理解为一个类似流程图的结构,在这个流程图中,每次判断输入数据的属性值,然后根据属性值分支到不同的子结点,直到达到某个叶子结点得到分类或回归结果。

构建决策树的过程

决策树的构建过程如下:

  1. 选择最好的数据集划分方式(生成算法);
  2. 创建一个树结点;
  3. 将数据集按照选定的划分方式划分为子数据集;
  4. 对子数据集按照1到3的过程递归地创建结点;
  5. 停止条件:递归建树,直到所有数据被分类或者没有属性可用于分类为止。

实现过程

步骤1:定义数据格式

我们定义一个数据结构来表示样本:

class Sample {
private:
    vector<double> features;
    double label;
public:
    Sample(vector<double> _features, double _label) {
        features = _features;
        label = _label;
    }
    vector<double> getFeatures() {
        return features;
    }
    double getLabel() {
        return label;
    }
};

步骤2:定义计算信息熵的函数

首先我们需要实现计算信息熵的函数,信息熵越小,样本的纯度越高。因此,选择最小熵的划分方式就等价于选择最优化分方式。

double entropy(vector<Sample> samples) {
    map<double, int> labelCount;
    for (auto sample : samples) {
        if (labelCount.find(sample.getLabel()) == labelCount.end()) {
            labelCount[sample.getLabel()] = 0;
        }
        labelCount[sample.getLabel()] += 1;
    }
    double ent = 0.0;
    for (auto item : labelCount) {
        double prob = item.second / (double)samples.size();
        ent -= prob * log2(prob);
    }
    return ent;
}

步骤3:定义寻找最优属性的函数

我们需要寻找最优属性,对数据集进行划分。通过计算每个属性的信息增益,可以得出最优属性。

void splitSamples(vector<Sample> samples, int featIndex, double threshold,
                  vector<Sample>& leftSamples, vector<Sample>& rightSamples) {
    for (auto sample : samples) {
        if (sample.getFeatures()[featIndex] <= threshold) {
            leftSamples.push_back(sample);
        } else {
            rightSamples.push_back(sample);
        }
    }
}

double gain(vector<Sample> samples, int featIndex, double threshold) {
    vector<Sample> leftSamples, rightSamples;
    splitSamples(samples, featIndex, threshold, leftSamples, rightSamples);
    double leftRatio = leftSamples.size() / (double)samples.size();
    double rightRatio = rightSamples.size() / (double)samples.size();

    double gain = entropy(samples) - leftRatio * entropy(leftSamples) - rightRatio * entropy(rightSamples);

    return gain;
}

struct Split {
    int featIndex;
    double threshold;
    double gain;
};

Split findSplit(vector<Sample> samples, double minGain) {
    Split bestSplit;
    bestSplit.gain = -1.0;
    for (int i = 0; i < samples[0].getFeatures().size(); i++) {
        double maxFeat = -1e9, minFeat = 1e9;
        for (auto sample : samples) {
            double feat = sample.getFeatures()[i];
            if (feat > maxFeat) maxFeat = feat;
            if (feat < minFeat) minFeat = feat;
        }
        double step = (maxFeat - minFeat) / 10.0;
        for (double threshold = minFeat; threshold <= maxFeat + step; threshold += step) {
            double currentGain = gain(samples, i, threshold);
            if (currentGain > bestSplit.gain && currentGain > minGain) {
                bestSplit.gain = currentGain;
                bestSplit.featIndex = i;
                bestSplit.threshold = threshold;
            }
        }
    }
    return bestSplit;
}

步骤4:定义建树的函数

现在我们来实现建树的函数,采用递归的方法,对样本进行分割。每个节点包含三个数据成员:样本、阀值(切分标准)、左右子节点。

class Node {
private:
    vector<Sample> samples;
    double threshold;
    Node* leftChild;
    Node* rightChild;
public:
    friend class DecisionTree;
    Node(vector<Sample> _samples, double _threshold) {
        samples = _samples;
        threshold = _threshold;
        leftChild = nullptr;
        rightChild = nullptr;
    }
};

class DecisionTree {
private:
    double minGain;
    int maxHeight;
public:
    DecisionTree(double _minGain, int _maxHeight) {
        minGain = _minGain;
        maxHeight = _maxHeight;
    }

    Node* buildTree(vector<Sample> samples) {
        return buildTreeHelper(samples, 0);
    }

    Node* buildTreeHelper(vector<Sample> samples, int height) {
        if (samples.empty() || (maxHeight > 0 && height > maxHeight)) {
            return nullptr;
        }
        double majorityLabel;
        if (entropy(samples) == 0) {
            majorityLabel = samples[0].getLabel();
            return new Node(samples, majorityLabel);
        }
        Split bestSplit = findSplit(samples, minGain);
        if (bestSplit.gain == -1.0) {
            return nullptr;
        }
        vector<Sample> leftSamples, rightSamples;
        splitSamples(samples, bestSplit.featIndex, bestSplit.threshold, leftSamples, rightSamples);

        Node* node = new Node(samples, bestSplit.threshold);
        node->leftChild = buildTreeHelper(leftSamples, height + 1);
        node->rightChild = buildTreeHelper(rightSamples, height + 1);
        return node;
    }
};

步骤5:用示例检验实现

首先,我们准备一些人造数据集,然后进行训练,最后把测试集的样本输入到训练的决策树,得到分类结果。

void trainTestDecisionTree() {
    vector<Sample> samples = {{vector<double>({0, 0}), 0}, {vector<double>({1, 0}), 1},
                              {vector<double>({0, 1}), 1}, {vector<double>({1, 1}), 1}};
    DecisionTree tree(0.0, 0);
    Node* rootNode = tree.buildTree(samples);
    cout << "Root threshold: " << rootNode->threshold << endl;
    cout << "Left samples: " << endl;
    for (auto sample : rootNode->leftChild->samples) {
        cout << "(" << sample.getFeatures()[0] << ", " << sample.getFeatures()[1] << "): " << sample.getLabel() << endl;
    }
    cout << "Right samples: " << endl;
    for (auto sample : rootNode->rightChild->samples) {
        cout << "(" << sample.getFeatures()[0] << ", " << sample.getFeatures()[1] << "): " << sample.getLabel() << endl;
    }

    Sample testSample(vector<double>({0, 0}), 0);
    Node* node = rootNode;
    while (node) {
        if (node->leftChild == nullptr && node->rightChild == nullptr) {
            cout << "test sample label: " << node->samples[0].getLabel() << endl;
            break;
        }
        double featValue = testSample.getFeatures()[node->samples[0].getFeatures()[node->threshold]];
        if (featValue <= node->threshold) {
            node = node->leftChild;
        } else {
            node = node->rightChild;
        }
    }
}

另一个实例:使用Iris数据集进行分类

为进一步检验决策树的效果,我们使用一组真实的数据集进行分类。我们选择Iris数据集,这个数据集包含150个样本,其中每个样本包含4个特征和1个标签。特征分别为花萼长度(sepal length)、花萼宽度(sepal width)、花瓣长度(petal length)、花瓣宽度(petal width),标签有三种,分别是:山鸢尾(Iris Setosa)、变色鸢尾(Iris Versicolour)、维吉尼亚鸢尾(Iris Virginica)。我们将数据集的前100个样本作为训练集,后50个样本作为测试集,使用决策树对其进行分类。

void irisClassification() {
    vector<Sample> samples;
    ifstream fin("../data/iris.data");
    while (!fin.eof()) {
        vector<double> features(4, 0.0);
        double label = -1;
        fin >> features[0] >> features[1] >> features[2] >> features[3];
        string labelStr;
        fin >> labelStr;
        if (labelStr == "Iris-setosa") {
            label = 0;
        } else if (labelStr == "Iris-versicolor") {
            label = 1;
        } else if (labelStr == "Iris-virginica") {
            label = 2;
        }
        if (label >= 0) {
            samples.push_back({features, label});
        }
    }
    fin.close();
    random_shuffle(samples.begin(), samples.end());

    // split train and test samples
    int trainCnt = 100;
    vector<Sample> trainSamples(samples.begin(), samples.begin() + trainCnt);
    vector<Sample> testSamples(samples.begin() + trainCnt, samples.end());

    DecisionTree tree(0.0, 5);
    Node* rootNode = tree.buildTree(trainSamples);

    int correctCnt = 0;
    for (auto sample : testSamples) {
        Node* node = rootNode;
        while (node) {
            if (node->leftChild == nullptr && node->rightChild == nullptr) {
                if (node->samples[0].getLabel() == sample.getLabel()) {
                    correctCnt += 1;
                }
                break;
            }
            double featValue = sample.getFeatures()[node->samples[0].getFeatures()[node->threshold]];
            if (featValue <= node->threshold) {
                node = node->leftChild;
            } else {
                node = node->rightChild;
            }
        }
    }
    cout << "accuracy: " << correctCnt / (double)testSamples.size() << endl;
}

这里我们使用trainSamples对决策树进行训练,使用testSamples进行测试,最后输出分类准确率。

总结

上述介绍了使用C++从头实现决策树的完整过程,包括数据结构、信息熵计算、寻找最优属性、建树等具体步骤,在样本集较小时能够正常分类,但样本集过大时易趋于复杂,进而泛化能力降低。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:教你用c++从头开始实现决策树 - Python技术站

(0)
上一篇 2023年5月22日
下一篇 2023年5月22日

相关文章

  • lenovo c4030一体机怎么拆机添加内存条?

    拆卸Lenovo C4030一体机并添加内存条需要进行以下步骤: 步骤一:准备工具和材料 在拆卸和添加内存条之前,请确保您拥有以下工具和材料: 适当大小的螺丝刀 ESD防静电处理工具(可选) 合适的内存条 请注意,添加内存条前请检查您的内存条支持的规格,如DDR3或DDR4,并确保您购买的内存条与您的机器配备的类型匹配。 步骤二:关闭电源并拆下机器后盖 在拆…

    C 2023年5月23日
    00
  • Android的日志系统分层与logcat使用

    Android的日志系统分为四层,分别是: 核心层(kernel):负责底层的输入输出、内存、磁盘、进程等操作。本层日志主要是通过printk等函数输出,并存储在ring buffer中,只有在安卓手机发生严重错误时才需要查看。 系统层(system):包括system_server、Zygote和ActivityManager等系统服务,存放的是系统服务的…

    C 2023年5月24日
    00
  • C++实现学生住宿管理系统

    C++实现学生住宿管理系统攻略 系统介绍 学生住宿管理系统主要功能是管理学生住宿信息,包括学生的基本信息和住宿信息,如宿舍楼、宿舍号、床位号等。该系统可以实现学生住宿信息的增删改查等基本操作,方便学生和管理员进行管理。 系统设计 数据库设计 首先,我们需要设计一个数据库,用来存储学生信息和住宿信息。可以使用MySQL或SQLite等关系型数据库,也可以使用文…

    C 2023年5月23日
    00
  • c++实现LinkBlockedQueue的问题

    让我们来详细讲解“c++实现LinkBlockedQueue的问题”该如何解决。 首先,我们需要阅读题目并理解其中所涉及的术语。“LinkBlockedQueue”是一个队列类,其中“Link”指的是链表,“Blocked”指的是阻塞,即队列为空时,出队操作会一直阻塞等待直到队列中有元素可出队。 接下来,我们可以通过以下步骤实现LinkBlockedQueu…

    C 2023年5月23日
    00
  • c++入门必学算法之快速幂思想及实现

    以下是“C++入门必学算法之快速幂思想及实现”的攻略。 教程概述 快速幂是一种计算幂运算(类似于指数运算)的高效算法。在求解幂运算时,我们通常是采用暴力方法进行连乘,这样的时间复杂度为 $O(n)$,效率较低。而快速幂算法能够在 $O(log_2(n))$ 的时间复杂度内完成幂运算,提高了计算效率。 在本教程中,我们将会介绍快速幂算法的思想和具体实现方法,并…

    C 2023年5月22日
    00
  • PHP简洁函数(PHP简单明了函数语法)

    PHP简洁函数(PHP简单明了函数语法) PHP简洁函数是一种通过使用闭包函数创建匿名函数来减少不必要的代码和提高代码可读性的技术。它允许你在需要的地方定义函数同时避免使用全局变量和函数名冲突的问题。PHP简洁函数的语法非常简单明了,它的形式如下: $func = function($arg1, $arg2, …) { // function body …

    C 2023年5月22日
    00
  • c++隐式类型转换存在的问题解析

    c++隐式类型转换存在的问题解析 什么是c++隐式类型转换 在C++中,隐式类型转换(Implicit Type Conversion)指的是在程序中自动进行的类型转换,而不需要程序员手动调用类型转换函数。隐式类型转换是由C++编译器自动完成的。 例如,我们可以将一个int类型的变量赋值给一个double类型的变量,编译器会自动把int类型转换成double…

    C 2023年5月23日
    00
  • C语言Make命令用法讲解

    C语言Make命令用法讲解 简介 Make命令是一种构建工具,可以用来自动化执行多个编译步骤,从而生成可执行文件,库文件等。在C语言编程中,Make命令可用于自动化编译操作,减少开发者的工作量,提高程序的可维护性。 安装 Make命令在GNU编译器套件(GCC)中自带,因此大多数Linux、Unix系统中已经预安装了Make。在Windows操作系统中,可以…

    C 2023年5月22日
    00
合作推广
合作推广
分享本页
返回顶部