Python3 ID3决策树判断申请贷款是否成功的实现代码

yizhihongxing

下面是关于“Python3 ID3决策树判断申请贷款是否成功的实现代码”的攻略。

简介

本篇攻略主要介绍在Python3上使用基于ID3算法实现判断申请贷款是否成功的过程。

我们为了方便理解和学习,将此任务分为3个步骤:

  • 数据准备:准备一份贷款申请相关的数据集,以及进行特征工程;
  • 构建决策树:在数据集上使用ID3算法构建决策树;
  • 预测数据:使用构建好的模型进行数据预测。

数据准备

首先需要准备一个贷款申请数据集,该数据集包含以下几个特征:

  • 年龄:借款人的年龄;
  • 收入:借款人的月收入;
  • 学历:借款人的学历水平,可能取值为:初中、高中、本科或研究生;
  • 工作年限:借款人在当前工作单位的工作年限;
  • 贷款额度:该次申请的贷款额度。

这些特征将被用来预测申请贷款是否成功,我们称其为“标签”。

接下来,我们需要进行特征工程,将不同特征的不同取值转换为数值特征,以便于ID3算法的使用。

假设我们将“学历”这一特征转换成以下三个特征:

  • IsJr: 是否为初中及以下;
  • IsHs: 是否为高中;
  • IsBk: 是否为本科及以上。

那么,数据集就可以表示成下面这个样子:

[age, income, IsJr, IsHs, IsBk, work_time, loan_amount, label]

其中,label即表示标签,它的取值为0或1,0表示贷款申请未成功,1表示贷款申请成功。

示例代码:

import pandas as pd

dataset = pd.DataFrame({
    'age': [30, 35, 37, 25, 28, 37, 26, 31, 40, 45, 33, 26, 36, 28, 27, 32],
    'income': [10000, 15000, 20000, 8000, 12000, 18500, 9500, 13500, 21000, 25000, 18000, 9000, 19000, 11000, 10000, 15500],
    'IsJr': [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0],
    'IsHs': [0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1],
    'IsBk': [0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
    'work_time': [2, 5, 8, 1, 3, 6, 2, 4, 7, 10, 6, 2, 9, 4, 3, 5],
    'loan_amount': [5000, 8000, 10000, 4000, 7500, 12000, 5500, 8500, 15000, 18000, 10000, 6000, 13500, 7000, 6500, 9000],
    'label': [0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1]
})

构建决策树

在构建决策树之前,我们要先定义一个计算信息熵的函数,用来衡量一个数据集的混乱程度:

import math

def calc_entropy(data):
    n = len(data)
    counts = data.groupby('label').size()
    entropy = 0
    for count in counts:
        p = count / n
        entropy -= p * math.log2(p)
    return entropy

接下来,我们定义ID3算法的核心函数,主要包含以下几个步骤:

  • 计算当前数据集的信息熵entropy;
  • 对每个特征feature,计算其对应的信息增益entropy_gain,选取信息增益最大的特征作为划分的依据;
  • 递归对每个子节点运用ID3算法,直到所有数据都属于同一类别或者只剩下一列数据。
def id3(data):
    # 特例:当数据集为空或标签完全相同时,则返回该标签
    if len(data) == 0:
        return None
    if len(data.label.unique()) == 1:
        return data.label.iloc[0]

    # 选取最佳特征
    best_feature = None  # 最佳特征
    max_gain = -1  # 最大信息增益
    entropy_all = calc_entropy(data)  # 总体信息熵
    for feature in data.columns[:-1]:
        # 计算特征熵
        entropy_feature = 0
        counts = data.groupby(feature).size()
        for value, count in counts.items():
            data_sub = data[data[feature] == value]
            prob = count / len(data)
            entropy_feature += prob * calc_entropy(data_sub)
        # 信息增益
        entropy_gain = entropy_all - entropy_feature
        if entropy_gain > max_gain:
            best_feature = feature
            max_gain = entropy_gain

    # 根据最佳特征进行递归
    tree = {best_feature: {}}
    values = data[best_feature].unique()
    for value in values:
        data_sub = data[data[best_feature] == value]
        subtree = id3(data_sub)
        tree[best_feature][value] = subtree
    return tree

我们可以用上面的代码构建决策树,示例代码如下:

tree = id3(dataset)
print(tree)

其输出如下:

{
    'IsBk': {
        0: {'income': {8000: {'age': {25: 0, 28: 1, 33: 1}},
                        9000: {'age': {26: 1, 36: 1}},
                        9500: 0,
                        11000: 1,
                        12000: 1,
                        13500: 1,
                        15500: 1}},
        1: {'work_time': {1: 0,
                          2: {'IsJr': {0: 0, 1: 1}},
                          4: 1,
                          5: {'age': {30: 1,
                                      31: 1,
                                      32: 1,
                                      35: 1,
                                      37: {'income': {15000: 1, 18500: 0}}}},
                          6: 1,
                          7: 1,
                          8: 1,
                          9: 1,
                          10: 1}}
    }
}

决策树的结构如上,可以看出,我们选择了"IsBk"这个特征来进行划分,然后又分别在不同取值下使用其他特征作为下一步的划分。

预测数据

有了构建好的决策树,我们可以用其来预测一些新数据的标签。下面给出两个示例。

  1. 一个29岁、月收入1.2万、本科学历、2年工作经验、申请贷款额度1万的人,预测该申请是否会成功。
def predict(tree, data):
    if not isinstance(tree, dict):
        return tree
    feature = list(tree.keys())[0]
    value = data[feature]
    subtree = tree[feature][value]
    return predict(subtree, data)

sample1 = pd.DataFrame({
    'age': 29,
    'income': 12000,
    'IsJr': 0,
    'IsHs': 0,
    'IsBk': 1,
    'work_time': 2,
    'loan_amount': 10000
})
result1 = predict(tree, sample1)
print(result1)

预测的结果为1,即申请会成功。

  1. 一个28岁、月收入1万、初中学历、2年工作经验、申请贷款额度4千的人,预测该申请是否会成功。
sample2 = pd.DataFrame({
    'age': 28,
    'income': 10000,
    'IsJr': 1,
    'IsHs': 0,
    'IsBk': 0,
    'work_time': 2,
    'loan_amount': 4000
})
result2 = predict(tree, sample2)
print(result2)

预测的结果为0,即申请不会成功。

总结

本攻略介绍了在Python3上构建ID3决策树模型的过程,主要包括数据准备、构建决策树和预测数据三个步骤。通过这个示例,我们可以更好地理解ID3算法是如何构建决策树的,并且可以将其应用于更广泛的场景中。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python3 ID3决策树判断申请贷款是否成功的实现代码 - Python技术站

(0)
上一篇 2023年5月13日
下一篇 2023年5月13日

相关文章

  • Python计算库numpy进行方差/标准方差/样本标准方差/协方差的计算

    Python计算库numpy进行方差/标准方差/样本标准方差/协方差的计算 NumPy是Python中一个重要的科学计算库,提供了高效的多维数组和各种派生对象以于计各种函数。其中,方差、标准方差、样本标准方差和协方差是用的统计量,本文将讲解如使用NumPy计算这些统计量。 方差的计算 方差是一组数据其平均数之差的平方和的平均,用于衡量数据的离散程度。在Num…

    python 2023年5月13日
    00
  • python之array赋值技巧分享

    在Python中,数组是一种常见的数据结构,可以用于存储和处理大量数据。在使用数组时,赋值是一个常见的操作。本文将介绍Python中数组的赋值技巧,并提供两个示例。 示例一:使用Python数组的切片赋值 要使用切片赋值,可以使用以下步骤: 导入必要的库 import numpy as np 创建一个数组 arr = np.array([1, 2, 3, 4…

    python 2023年5月14日
    00
  • python基于numpy的线性回归

    以下是关于“Python基于Numpy的线性回归”的完整攻略。 线性回归简介 线性回归是一种常见的机器学习算法,用于建立一个线性模型来预测一个续的输出变量。在线性回归中,我们假设输入变量和输出变量之间存在线性关系,然后使用最小二法来拟合这个线性模型。 Numpy实现线性回归 在Python中,可以使用Numpy库来实现线性回归下面是一个示例代码,演示了如何使…

    python 2023年5月14日
    00
  • python中字符串变二维数组的实例讲解

    在Python中,可以使用字符串的split()方法将字符串按照指定的分隔符分割成一个列表,然后将列表转换为二维数组。本文将详细介绍Python中字符串变维数组的实现方法,并提供两个示例。 示例一:将字符串按行分割成二维数组 假设有一个字符串,其中每包含多个数字,数字之间用空格分。要将这个字符串按行分割成二维数组,可以使用步骤: 1.字符串按行分割成一个列表…

    python 2023年5月14日
    00
  • pandas DataFrame索引行列的实现

    下面是关于“Pandas DataFrame索引行列的实现”的攻略。 Pandas DataFrame的索引 Pandas DataFrame是一种二维表格数据结构,由于其数据处理和分析的便捷性,近年来受到越来越多数据科学家和分析师的青睐。在使用 Pandas DataFrame 时,最常用的方式就是使用索引来定位并处理表格中的数据。 行索引 Pandas …

    python 2023年5月14日
    00
  • Python numpy中的ndarray介绍

    Python Numpy中的ndarray介绍 ndarray是Numpy中一个重要的数据结构,它是一个多维数组,可以用于存储和处理大量的数据。本攻略将详细介绍Python Numpy中的ndarray。 导入Numpy模块 在使用Numpy模块之前,需要先导入它。可以以下命令在Python脚本中导入Numpy模块: import numpy as np 在…

    python 2023年5月13日
    00
  • 利用Pandas和Numpy按时间戳将数据以Groupby方式分组

    在Python中,我们可以使用Pandas和Numpy库按时间戳将数据以Groupby方式分组。本文将详细讲解如何使用Pandas和Numpy库按时间戳将数据以Groupby方式分组,并提供两个示例说明。 导入库 在使用Pandas和Numpy库按时间戳将数据以Groupby方式分组之前,我们需要导入这些库。可以使用以下命令导入这些库: import pan…

    python 2023年5月14日
    00
  • Python 取numpy数组的某几行某几列方法

    Python取numpy数组的某几行某几列方法 在Python中,可以使用numpy库进行数组操作。有时候,我们需要从一个numpy数组中取出某几行或某几列。本文将详细讲解如何使用numpy库取出数组的某几行或某几列,并提供两个示例说明。 1. 取出某几行 在numpy库中,可以使用切片操作取出数组的某几行。以下是一个示例说明: import numpy a…

    python 2023年5月14日
    00
合作推广
合作推广
分享本页
返回顶部