Python3 ID3决策树判断申请贷款是否成功的实现代码

下面是关于“Python3 ID3决策树判断申请贷款是否成功的实现代码”的攻略。

简介

本篇攻略主要介绍在Python3上使用基于ID3算法实现判断申请贷款是否成功的过程。

我们为了方便理解和学习,将此任务分为3个步骤:

  • 数据准备:准备一份贷款申请相关的数据集,以及进行特征工程;
  • 构建决策树:在数据集上使用ID3算法构建决策树;
  • 预测数据:使用构建好的模型进行数据预测。

数据准备

首先需要准备一个贷款申请数据集,该数据集包含以下几个特征:

  • 年龄:借款人的年龄;
  • 收入:借款人的月收入;
  • 学历:借款人的学历水平,可能取值为:初中、高中、本科或研究生;
  • 工作年限:借款人在当前工作单位的工作年限;
  • 贷款额度:该次申请的贷款额度。

这些特征将被用来预测申请贷款是否成功,我们称其为“标签”。

接下来,我们需要进行特征工程,将不同特征的不同取值转换为数值特征,以便于ID3算法的使用。

假设我们将“学历”这一特征转换成以下三个特征:

  • IsJr: 是否为初中及以下;
  • IsHs: 是否为高中;
  • IsBk: 是否为本科及以上。

那么,数据集就可以表示成下面这个样子:

[age, income, IsJr, IsHs, IsBk, work_time, loan_amount, label]

其中,label即表示标签,它的取值为0或1,0表示贷款申请未成功,1表示贷款申请成功。

示例代码:

import pandas as pd

dataset = pd.DataFrame({
    'age': [30, 35, 37, 25, 28, 37, 26, 31, 40, 45, 33, 26, 36, 28, 27, 32],
    'income': [10000, 15000, 20000, 8000, 12000, 18500, 9500, 13500, 21000, 25000, 18000, 9000, 19000, 11000, 10000, 15500],
    'IsJr': [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0],
    'IsHs': [0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1],
    'IsBk': [0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
    'work_time': [2, 5, 8, 1, 3, 6, 2, 4, 7, 10, 6, 2, 9, 4, 3, 5],
    'loan_amount': [5000, 8000, 10000, 4000, 7500, 12000, 5500, 8500, 15000, 18000, 10000, 6000, 13500, 7000, 6500, 9000],
    'label': [0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1]
})

构建决策树

在构建决策树之前,我们要先定义一个计算信息熵的函数,用来衡量一个数据集的混乱程度:

import math

def calc_entropy(data):
    n = len(data)
    counts = data.groupby('label').size()
    entropy = 0
    for count in counts:
        p = count / n
        entropy -= p * math.log2(p)
    return entropy

接下来,我们定义ID3算法的核心函数,主要包含以下几个步骤:

  • 计算当前数据集的信息熵entropy;
  • 对每个特征feature,计算其对应的信息增益entropy_gain,选取信息增益最大的特征作为划分的依据;
  • 递归对每个子节点运用ID3算法,直到所有数据都属于同一类别或者只剩下一列数据。
def id3(data):
    # 特例:当数据集为空或标签完全相同时,则返回该标签
    if len(data) == 0:
        return None
    if len(data.label.unique()) == 1:
        return data.label.iloc[0]

    # 选取最佳特征
    best_feature = None  # 最佳特征
    max_gain = -1  # 最大信息增益
    entropy_all = calc_entropy(data)  # 总体信息熵
    for feature in data.columns[:-1]:
        # 计算特征熵
        entropy_feature = 0
        counts = data.groupby(feature).size()
        for value, count in counts.items():
            data_sub = data[data[feature] == value]
            prob = count / len(data)
            entropy_feature += prob * calc_entropy(data_sub)
        # 信息增益
        entropy_gain = entropy_all - entropy_feature
        if entropy_gain > max_gain:
            best_feature = feature
            max_gain = entropy_gain

    # 根据最佳特征进行递归
    tree = {best_feature: {}}
    values = data[best_feature].unique()
    for value in values:
        data_sub = data[data[best_feature] == value]
        subtree = id3(data_sub)
        tree[best_feature][value] = subtree
    return tree

我们可以用上面的代码构建决策树,示例代码如下:

tree = id3(dataset)
print(tree)

其输出如下:

{
    'IsBk': {
        0: {'income': {8000: {'age': {25: 0, 28: 1, 33: 1}},
                        9000: {'age': {26: 1, 36: 1}},
                        9500: 0,
                        11000: 1,
                        12000: 1,
                        13500: 1,
                        15500: 1}},
        1: {'work_time': {1: 0,
                          2: {'IsJr': {0: 0, 1: 1}},
                          4: 1,
                          5: {'age': {30: 1,
                                      31: 1,
                                      32: 1,
                                      35: 1,
                                      37: {'income': {15000: 1, 18500: 0}}}},
                          6: 1,
                          7: 1,
                          8: 1,
                          9: 1,
                          10: 1}}
    }
}

决策树的结构如上,可以看出,我们选择了"IsBk"这个特征来进行划分,然后又分别在不同取值下使用其他特征作为下一步的划分。

预测数据

有了构建好的决策树,我们可以用其来预测一些新数据的标签。下面给出两个示例。

  1. 一个29岁、月收入1.2万、本科学历、2年工作经验、申请贷款额度1万的人,预测该申请是否会成功。
def predict(tree, data):
    if not isinstance(tree, dict):
        return tree
    feature = list(tree.keys())[0]
    value = data[feature]
    subtree = tree[feature][value]
    return predict(subtree, data)

sample1 = pd.DataFrame({
    'age': 29,
    'income': 12000,
    'IsJr': 0,
    'IsHs': 0,
    'IsBk': 1,
    'work_time': 2,
    'loan_amount': 10000
})
result1 = predict(tree, sample1)
print(result1)

预测的结果为1,即申请会成功。

  1. 一个28岁、月收入1万、初中学历、2年工作经验、申请贷款额度4千的人,预测该申请是否会成功。
sample2 = pd.DataFrame({
    'age': 28,
    'income': 10000,
    'IsJr': 1,
    'IsHs': 0,
    'IsBk': 0,
    'work_time': 2,
    'loan_amount': 4000
})
result2 = predict(tree, sample2)
print(result2)

预测的结果为0,即申请不会成功。

总结

本攻略介绍了在Python3上构建ID3决策树模型的过程,主要包括数据准备、构建决策树和预测数据三个步骤。通过这个示例,我们可以更好地理解ID3算法是如何构建决策树的,并且可以将其应用于更广泛的场景中。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python3 ID3决策树判断申请贷款是否成功的实现代码 - Python技术站

(0)
上一篇 2023年5月13日
下一篇 2023年5月13日

相关文章

  • pandas删除行删除列增加行增加列的实现

    Pandas是一个基于NumPy的Python库,常用于数据分析和处理。在数据分析和处理过程中,有时需要删除指定的行、列或者增加新的行、列,本文将介绍如何使用Pandas实现这些操作。 删除行和列 Pandas中删除行和列的方式比较灵活,常用的方法有drop()和pop()。 drop方法 # 删除行 df.drop(index=[1, 3], inplac…

    python 2023年5月14日
    00
  • NumPy实现ndarray多维数组操作

    NumPy是Python中一个重要的科学计算库,提供了高效的多维数组和各种派生对象及计算种函数。NumPy中,可以使用ndarray多维数组来进行各种操作,包括创建、索引、切片、运算等。本文将详细讲解NumPy实现ndarray多维数组操作的完整攻略,并提供了两个示例。 创建ndarray多维数组 在NumPy中,可以使用array()函数来创建ndarra…

    python 2023年5月13日
    00
  • numpy数组最常用的4个搜索方法

    NumPy提供了一些搜索和查找数组中元素的方法,包括: np.where(condition[, x, y]):返回满足条件的元素的下标。可以指定x和y参数,如果不指定,则返回元素下标。 np.argwhere(condition):返回满足条件的元素的下标,与where()方法类似,但返回的是一个包含下标的数组,而不是元组。 np.searchsorted…

    2023年3月1日
    00
  • Python如何用NumPy读取和保存点云数据

    以下是关于Python如何用NumPy读取和保存点云数据的攻略: NumPy读取点云数据 NumPy可以用来读取点云数据以下是一些实现方法: 读取文本文件 可以使用NumPy的loadtxt()函数来读取文本文件中的点云数据。是一个示例: import numpy as np # 读取文本文件 data = np.loadtxt(‘point_cloud.t…

    python 2023年5月14日
    00
  • Python中的imread()函数用法说明

    以下是关于“Python中的imread()函数用法说明”的完整攻略。 背景 imread()函数是Python中常用的图像处理函数之一,用于读取图像文件并将其转换为NumPy数组。本攻略将介绍imread()函数的用法及示例。 步骤 步骤一:导入模块 在使用imread()函数之前,需要导入相关的模块。以下是示例代码: import cv2 import …

    python 2023年5月14日
    00
  • Python如何遍历numpy数组

    Python如何遍历NumPy数组 在Python中,遍历NumPy数组有多种方法,包括使用for循环、使用nditer()函数、使用flat属性等。下面将详细讲解这些方法。 使用for循环遍历NumPy数组 使用循环遍历NumPy数组是最简单的方法。下面是一个示例: import numpy as np # 创建NumPy a = np.array([[1…

    python 2023年5月14日
    00
  • tensorflow模型的save与restore,及checkpoint中读取变量方式

    TensorFlow是一个强大的机器学习框架,它提供了许多工具和API来构建、训练和部署机器学习模型。在TensorFlow中,我们可以使用save和restore函数来保存和加载模型,以及使用checkpoint来保存和恢复变量。 保存和加载模型 保存模型 在TensorFlow中,我们可以使用save函数将模型保存到磁盘上。以下是一个保存模型的示例: i…

    python 2023年5月14日
    00
  • pytorch实现图像识别(实战)

    PyTorch实现图像识别(实战)攻略 前言 图像识别是计算机视觉领域的一个重要应用,而深度学习技术在图像识别中发挥了重要作用。PyTorch是深度学习领域的一个强大工具,本文将介绍如何使用PyTorch实现图像识别。 环境 在实现图像识别之前,需要确保安装了正确的开发环境,包括: Python 3.x版本 PyTorch 1.x版本 Torchvision…

    python 2023年5月13日
    00
合作推广
合作推广
分享本页
返回顶部