K-近邻算法的python实现代码分享

yizhihongxing

下面是详细讲解“K-近邻算法的Python实现代码分享”的完整攻略。

K-近邻算法

K-近邻算法是一种常用的分类算法,其基本思想是在训练集中找到与测试样本最近的K个样本,然后根据这K个样本的类别投票,将测试样本归为票数最多的类别。

下面是一个Python实现K-近邻算法的示例:

import numpy as np

def knn(X_train, y_train, X_test, k=3):
    distances = []
    for i in range(len(X_train)):
        distance = np.sqrt(np.sum(np.square(X_test - X_train[i, :])))
        distances.append([distance, i])
    distances = sorted(distances)
    k_neighbors = [y_train[distances[i][1]] for i in range(k)]
    return max(k_neighbors, key=k_neighbors.count)

X_train = np.array([[1, 2], [2, 3], [3, 1], [4, 2]])
y_train = np.array([0, 0, 1, 1])
X_test = np.array([3, 2])

prediction = knn(X_train, y_train, X_test, k=3)
print("Prediction: ", prediction)

上述代码中,首先定义了一个knn函数,函数接受训练集X_train、训练集标签y_train、测试集X_test和K值k。在函数中,计算测试样本与训练样本之间的距离,并将距离和训练样本的索引存储在distances列表中。然后,对distances列表进行排序,并选取前K个距离最近的训练样本的标签,将存储在k_neighbors列表中。最后,返回k_neighbors中出现次数最多的标签。

然后,定义了一个训练集X_train、练集标签y_train和测试集X_test。在本例中,训练集包含4个样本,每个样本有2个特征,标签分别为0和1。测试集包含1个样本,也有2个特征。

最后,使用测试集调用knn函数,计算测试样本的标签。

K-近邻算法的优化

K-邻算法的计算复杂度较高,因为需要计算测试样本与所有训练样本之间的距离。为了提高算的效率,可以使用KD树来优化K-近邻算法。

下面是一个使用KD树优化K-近邻算法的Python示例:

from collections import Counter
import numpy as np

class KDTree:
    def __init__(self, data, depth=0):
        if len(data) > 0:
            k = len(data[0])
            axis = depth % k
            sorted_data = sorted(data, key=lambda x: x[axis])
            mid = len(sorted_data) // 2
            self.location = sorted_data[mid]
            self.left_child = KDTree(sorted_data[:mid], depth+1)
            self.right_child = KDTree(sorted_data[mid+1:], depth+1)
        else:
            self.location = None
            self.left_child = None
            self.right_child = None

    def search_knn(self, point, k=3, dist_func=lambda x, y: np.sqrt(np.sum(np.square(x - y)))):
        knn = []
        self._search_knn(point, k, knn, dist_func)
        return [x[1] for x in sorted(knn)]

    def _search_knn(self, point, k, knn, dist_func):
        if self.location is None:
            return
        distance = dist_func(point, self.location)
        if len(knn) < k:
            knn.append((distance, self.location))
        elif distance < knn[-1][0]:
            knn.pop()
            knn.append((distance, self.location))
        axis = len(point) % len(self.location)
        if point[axis] < self.location[axis]:
            self.left_child._search_knn(point, k, knn, dist_func)
        else:
            self.right_child._search_knn(point, k, knn, dist_func)

def knn(X_train, y_train, X_test, k=3):
    tree = KDTree(X_train)
    knn_indices = tree.search_knn(X_test, k=k)
    k_neighbors = [y_train[i] for i in knn_indices]
    return Counter(k_neighbors).most_common(1)[0][0]

X_train = np.array([[1, 2], [2, 3], [3, 1], [4, 2]])
y_train = np.array([0, 0, 1, 1])
X_test = np.array([3, 2])

prediction = knn(X_train, y_train, X_test, k=3)
print("Prediction: ", prediction)

上述代码中,首先定义了一个KDTree类,该类用于构建KD树。在类的构造函数中,根据当前深度选择划分的维度,然后将数据集按照该维度排序,并选择中位数作为当前节点的位置。然后,递归构建左子树和右子树。

然后,定义了一个search_knn方法,该方法用于搜索距离测试样本最近的K个训练样本。在方法中,使用递归搜索KD树,找到距离测试样本最近的K个训练样本,并将其存储在knn列表中。

最后,定义了一个knn函数,该函数接训练集X_train、训练集标签y_train、测试集X_test和K值k。在函数中,使用KD树搜索距离测试样本最近的K个训练样本的标签,并返回出现次数最多的标签。

然后,定义了一个训练集X_train、训练集标签y_train和测试集X_test。在本例中,训练集包4个样本,每个样本有2个特征,标签分别为0和1。测试集包含1个样本,也有2个征。

最后,使用测试集调用knn函数,计算测试样本的标。

总结

K-近邻算法是一种常用的分类算法,可以使用KD树来优化算法的效率。Python中可以使用NumPy库和collections库进行实现。在实现过程中,需要定义KDTree类和knn函数,并使用递归KD树,找到距离测试样本最近的K个训练样本的标签。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:K-近邻算法的python实现代码分享 - Python技术站

(0)
上一篇 2023年5月14日
下一篇 2023年5月14日

相关文章

  • Python中利用all()来优化减少判断的实例分析

    在Python中使用all()函数可以用来简化代码并提高程序效率,本攻略将为大家详细介绍利用all()函数来优化减少判断的实例分析。 1. all()函数的基本用法 all()函数可以接受一个可迭代对象作为参数,返回值为True或False。当参数中所有元素都为True时,all()函数的返回值为True;当参数中存在一个False元素时,all()函数的返…

    python 2023年6月3日
    00
  • Python实战之整蛊神器合集加速友尽

    Python实战之整蛊神器合集加速友尽攻略 背景介绍 在日常生活、工作中,使用整蛊神器来逗乐朋友、增加生活趣味性已经成为一种常见现象。本攻略将向大家分享如何使用Python实现各种有趣的整蛊神器,并加速友谊的建立。 整蛊神器合集 整蛊神器合集是众多有趣的小工具的合集,其中包含了许多既能逗乐朋友,又具有实用价值的小工具,如抢课、获取美女照片等。 攻略讲解 整蛊…

    python 2023年5月23日
    00
  • 一文带你吃透Python中的日期时间模块

    一文带你吃透Python中的日期时间模块 Python中的datetime模块提供了处理日期和时间的标准接口。该模块包含多个类和函数,可以很便捷地进行日期和时间的处理。在这篇文章中,我们将介绍如何使用datetime模块来格式化、解析、计算日期和时间。 获取当前日期和时间 在Python中,我们可以使用datetime模块的datetime类来获取当前的日期…

    python 2023年5月14日
    00
  • Python语言描述KNN算法与Kd树

    下面是关于Python语言描述KNN算法与Kd树的攻略。 KNN算法是什么? KNN算法全称为K-近邻算法,基于特征之间的相似度计算样本之间的距离,进而来进行分类或回归。KNN是一个简单但十分有效的算法,它的主要思想是:新样本到训练样本中距离最近的K个样本的类别来决定它的类别。 KNN算法的应用场景 KNN算法适用于数据比较大、准确度要求不是那么高的场景,比…

    python 2023年6月3日
    00
  • python微信跳一跳系列之棋子定位颜色识别

    下面是“Python微信跳一跳系列之棋子定位颜色识别”的完整攻略。 前言 本攻略是关于使用Python实现微信跳一跳自动玩游戏的系列文章之一,主要介绍棋子定位和颜色识别的方法,用于辅助自动玩游戏。 棋子定位 在跳一跳游戏中,我们利用手机截图并导入电脑后,需要先找到当前界面中棋子所在的位置,从而计算出距离和方向。因此,在Python中需要实现棋子的定位操作。 …

    python 2023年6月6日
    00
  • python可视化分析绘制散点图和边界气泡图

    当我们需要展示数据之间的关系或趋势时,可视化分析是非常有用的工具。散点图和边界气泡图是其中两个常用的表现形式。以下是Python中使用Matplotlib库可视化分析绘制散点图和边界气泡图的完整攻略。 准备工作 在绘制散点图和边界气泡图之前,我们需要安装相关的库。我们可以通过在终端中运行以下命令安装: pip install matplotlib 绘制散点图…

    python 2023年6月3日
    00
  • Python 编码Basic Auth使用方法简单实例

    下面开始讲解“Python 编码Basic Auth使用方法简单实例”的攻略: 1. 什么是Basic Auth Basic Auth 是一种 HTTP 认证机制,它是通过 Authorization 头传递用户名和密码的方式来完成身份验证。在 HTTP 请求头中,Authorization 头的内容格式通常是:“Basic base64(username:…

    python 2023年5月31日
    00
  • 全面了解python字符串和字典

    全面了解Python字符串和字典 字符串 什么是字符串 字符串是在Python中最常用的数据类型之一。它是一个由字符组成的序列。可以使用单引号(‘)或双引号(“)来表示字符串。 示例代码: s1 = "Hello, World!" # 使用双引号来表示字符串 s2 = ‘Hello, World!’ # 使用单引号来表示字符串 print…

    python 2023年5月13日
    00
合作推广
合作推广
分享本页
返回顶部