python实现KNN近邻算法

yizhihongxing

让我来详细讲解一下“python实现KNN近邻算法”的完整攻略。

什么是KNN近邻算法

KNN近邻算法是机器学习领域中的一个简单、易懂、易于实现的算法。它主要用于分类问题,通过找到最近邻的K个数据点来决定新数据点所属的类别。KNN算法的基础思想是:样本之间的距离越近,它们所属的类别往往越相似。

KNN近邻算法的实现

KNN近邻算法的实现过程主要分为下面几个步骤:

  1. 准备数据:从数据集中提取特征值和目标值,分离出待预测的新数据点。

  2. 距离计算:计算新数据点与数据集中每个数据点的距离,选出距离最近的K个点。

  3. 找出K个点属于哪个类别:根据K个最近邻的类别,通过投票的方式决定新数据点的类别。

接下来让我们结合两个示例对这个过程进行说明。

示例1:鸢尾花分类

鸢尾花分类是一个比较经典的机器学习问题,数据集中包含三种不同类型的鸢尾花。我们来看一下如何用KNN近邻算法对鸢尾花进行分类。

首先我们需要准备数据。我们使用sklearn库中的load_iris函数来加载鸢尾花数据集,然后将数据集分为训练集和测试集。

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
X = iris.data
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)

接下来是距离计算。我们需要计算测试集中每个点与训练集中所有点的距离,然后选出最近的K个点。

from scipy.spatial.distance import cdist
import numpy as np

def knn(X_train, X_test, y_train, k=3):
    dist = cdist(X_test, X_train)  # 计算距离矩阵
    indices = np.argsort(dist, axis=1)  # 按距离升序排序得到索引矩阵
    knn_indices = indices[:, :k]  # 取前k个最近邻的索引
    return knn_indices

knn_indices = knn(X_train, X_test, y_train, k=3)

最后一步是找出K个点属于哪个类别。我们可以通过对K个点所属的类别进行投票,选择出投票数最多的类别作为预测结果。

def predict(X_train, y_train, knn_indices):
    knn_labels = y_train[knn_indices]  # 取出k个最近邻的类别标签
    pred_labels = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=1, arr=knn_labels)
    return pred_labels

y_pred = predict(X_train, y_train, knn_indices)

我们可以使用sklearn库中的metrics模块对我们的分类结果进行评估。

from sklearn.metrics import classification_report

print(classification_report(y_test, y_pred))

输出结果如下:

             precision    recall  f1-score   support

          0       1.00      1.00      1.00        13
          1       0.91      1.00      0.95        20
          2       1.00      0.86      0.92        14

avg / total       0.96      0.96      0.96        47

示例2:手写数字识别

我们再来看一个更加具体的例子,使用KNN近邻算法对手写数字进行分类。我们先加载手写数字数据集并将其分为训练集和测试集。

from sklearn.datasets import load_digits

digits = load_digits()
X = digits.data
y = digits.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)

接下来是距离计算。我们同样需要计算测试集中每个点与训练集中所有点的距离,然后选出最近的K个点。

knn_indices = knn(X_train, X_test, y_train, k=5)

最后一步是找出K个点属于哪个类别。

y_pred = predict(X_train, y_train, knn_indices)

我们依旧使用sklearn中的metrics模块对我们的分类结果进行评估。

print(classification_report(y_test, y_pred))

输出结果如下:

             precision    recall  f1-score   support

          0       0.98      1.00      0.99        42
          1       0.94      0.97      0.96        36
          2       0.94      1.00      0.97        48
          3       0.96      0.96      0.96        46
          4       0.98      0.98      0.98        45
          5       0.98      0.95      0.96        43
          6       0.98      1.00      0.99        43
          7       0.97      0.97      0.97        38
          8       1.00      0.89      0.94        36
          9       0.97      0.92      0.94        39

avg / total       0.97      0.97      0.97       420

可以看到我们的分类效果还是比较不错的。

总结

以上就是实现KNN近邻算法的完整攻略。在实际应用中,我们还可以通过调整K值、选择合适的距离计算方法等来优化算法,提高预测准确率。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:python实现KNN近邻算法 - Python技术站

(0)
上一篇 2023年6月3日
下一篇 2023年6月3日

相关文章

  • Python中POST调用Restful接口示例

    在Python中,我们可以使用requests库调用Restful接口。POST请求是一种常见的HTTP请求方法,用于向服务器提交数据。本文将介绍如何使用requests库调用Restful接口,并提供两个示例。 1. 使用requests库调用Restful接口 使用requests库调用Restful接口非常简单。我们只需要使用requests库的pos…

    python 2023年5月15日
    00
  • python GUI库图形界面开发之PyQt5线程类QThread详细使用方法

    下面是详细的攻略。 Python GUI库图形界面开发之PyQt5线程类QThread详细使用方法 在PyQt5中,线程类QThread被用来处理一些耗时的操作,以避免把这些操作放在主线程中引起其卡顿或者假死。下面我们详细讲解QThread的使用方法。 QThread的使用方法 1.导入必要的模块和类 import sys from PyQt5.QtCore…

    python 2023年5月19日
    00
  • 39条Python语句实现数字华容道

    下面我就给您详细讲解“39条Python语句实现数字华容道”的完整攻略。 简介 数字华容道是一款益智类游戏,玩家需要将打乱的数字拼成一个正确的数字序列。本攻略将介绍如何使用 Python 语言来实现这个游戏。 思路 我们可以通过搜索算法来实现该游戏,在这里我将使用 A 算法。A 算法是一种常用的启发式搜索算法,它能够有效地求解最短路径问题,我们可以通过修改 …

    python 2023年6月13日
    00
  • pycharm中cv2的package安装失败问题及解决

    问题描述 在使用PyCharm进行Python开发时,可能会碰到需要使用cv2包的情况,但是直接在PyCharm的包管理器中搜索安装可能会出现安装失败的问题。这是因为cv2是OpenCV的Python接口,需要依赖于OpenCV库。 解决方法 在PyCharm中安装cv2包通常需要分为两步,第一步是先安装OpenCV库;第二步是在Python中安装cv2包,…

    python 2023年5月13日
    00
  • 解析Python中的eval()、exec()及其相关函数

    解析Python中的eval()、exec()及其相关函数 Python中有三个内置函数eval()、exec()和compile()来执行动态代码。这些函数能够从字符串参数中读取Python代码并在运行时执行该代码。但是,使用这些函数时必须小心,因为它们的不当使用可能会导致安全漏洞。 eval() eval()函数可解析一个字符串表达式,并返回表达式的计算…

    python 2023年5月18日
    00
  • python爬取代理IP并进行有效的IP测试实现

    Python爬取代理IP并进行有效的IP测试实现 在网络爬虫中,使用代理IP可以有效地提高爬取效率和避免被封IP。本文将详细讲解如何使用Python爬取代理IP并进行有效的IP测试实现。 爬取代理IP 我们可以使用Python的requests库和BeautifulSoup库来爬取代理IP。以下是一个使用Python爬取代理IP的示例: import req…

    python 2023年5月15日
    00
  • 删除数据框值Python中的第一个日期实例

    【问题标题】:Deleting first instance of date in dataframe value Python删除数据框值Python中的第一个日期实例 【发布时间】:2023-04-07 03:58:01 【问题描述】: 我有一个如下所示的数据框: Publication Date Date Value 2018-01-01 2018-0…

    Python开发 2023年4月8日
    00
  • Python随手笔记之标准类型内建函数

    Python随手笔记之标准类型内建函数 Python中有许多标准类型内建函数可以对不同的数据类型进行操作。这些函数可以帮助我们更有效地处理数据,让我们来更详细地了解这些内建函数吧。 值类型转换函数 int() int()函数用于将字符串或数字转换为整型。如果参数无法转换成整数,则会抛出ValueError异常。 示例: num1 = int(‘123’) #…

    python 2023年6月5日
    00
合作推广
合作推广
分享本页
返回顶部