Python实现KNN(K-近邻)算法的示例代码

yizhihongxing

下面是详细讲解“Python实现KNN(K-近邻)算法的示例代码”的完整攻略,包括算法原理、Python实现和两个示例。

算法原理

KNN(K近邻)算法是一种基于实例的学习算法,其主要思想是通过计算样本间的距离,找到与目标样本最近的K个样本,然后根据这K个样本的类别,来预测目标样本的类别。

KNN算法的实现过程如下:

  1. 计算目标样本与每个样本之间的距离。
  2. 选取与目标样本距离最近的K个样本。
  3. 根据这K个样本的类别,来预测目标样本的类别。

KNN算法的核心在于如何计算样本之间的距离,常见的距离计算方法包括欧氏距离、曼哈顿距离和余弦距离等。

Python实现

以下是Python实现KNN算法的示例代码:

import numpy as np
from collections import Counter

class KNN:
    def __init__(self, k=3, distance='euclidean'):
        self.k = k
        self.distance = distance

    def fit(self, X, y):
        self.X_train = X
        self.y_train = y

    def predict(self, X):
        y_pred = []
        for x in X:
            distances = []
            for x_train in self.X_train:
                if self.distance == 'euclidean':
                    dist = np.sqrt(np.sum((x - x_train) ** 2))
                elif self.distance == 'manhattan':
                    dist = np.sum(np.abs(x - x_train))
                elif self.distance == 'cosine':
                    dist = np.dot(x, x_train) / (np.linalg.norm(x) * np.linalg.norm(x_train))
                distances.append(dist)
            k_nearest = np.argsort(distances)[:self.k]
            k_nearest_labels = [self.y_train[i] for i in k_nearest]
            most_common = Counter(k_nearest_labels).most_common(1)
            y_pred.append(most_common[0][0])
        return y_pred

上述代码中,使用Python实现了KNN算法。首先定义了一个KNN类,表示KNN算法,包括K值和距离计算方法。在KNN类中,定义了拟合函数fit和预测函数predict。然后使用KNN算法进行分类,返回预测结果。

示例说明

以下两个示例,说明如何使用上述代码进行KNN算法。

示例1

使用KNN算法鸢尾花数据集进行分类。

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)

knn = KNN(k=3, distance='euclidean')
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
print(accuracy_score(y_test, y_pred))

运行上述代码,输出结果如下:

1.0

上述代码,使用KNN算法对鸢尾花数据集进行分类。首先使用train_test_split函数将数据集分为训练集和测试集,然后使用KNN算法进行分类,最后使用accuracy_score函数计算分类准确率。运行结果为分类确率。

示例2

使用KNN算法对手写数字数据集进行分类。

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2, random_state=42)

knn = KNN(k=3, distance='euclidean')
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
print(accuracy_score(y_test, y_pred))

运行上述代码,输出结果如下:

0.9861111111111112

上述代码中,使用KNN算对手写数字数据集进行分类。首先使用train_test_split函数将数据集分为训练集和测试集,然后使用KNN算法进行分类,最后使用accuracy_score函数计算分类准确率。运行结果为分类准确率。

结语

本文介绍了如何使用Python实现KNN算法,包括算法原理、Python实现和两个示例说明。KNN算法是一种常用的分类算法,其主要思想是通过计算样本之间的距离,找到与目标样本最近的K个样本,然后根据这K个样本的类别,预测目标样本的类别。在实现中,需要注意选择合适的K值和距离计算方法,并根据具体情况进行调整。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python实现KNN(K-近邻)算法的示例代码 - Python技术站

(0)
上一篇 2023年5月14日
下一篇 2023年5月14日

相关文章

  • 一篇文章带你了解kali局域网攻击

    一篇文章带你了解kali局域网攻击 什么是Kali Linux? Kali Linux 是基于 Debian 的 GNU/Linux 发行版。设计用于数字鉴定和渗透测试。此操作系统包含了数百个预先安装好的工具,可以用于测试网络安全性,包括端口扫描、漏洞攻击、渗透测试等。 实施攻击前需要知道的基础信息 在使用Kali Linux进行攻击之前,需要先收集一些基础…

    python 2023年5月20日
    00
  • Python实现解析命令行参数的常见方法总结

    标题:Python实现解析命令行参数的常见方法总结 引言:命令行参数是指在控制台或者终端中输入的参数,对于很多脚本程序及应用程序,都需要支持特定的命令行参数。Python提供了许多解析命令行参数的库,本文将会介绍两种常见的方法:argparse和getopt。 正文: 一、argparse解析命令行参数 1. argparse库的安装 pip install…

    python 2023年6月2日
    00
  • Python使用ntplib库同步校准当地时间的方法

    当我们需要精确地获取当地的时间,或者需要与其他国家、地区的服务器时间同步,我们可以使用Python的ntplib库来实现。 使用ntplib库同步校准当地时间的方法 以下是使用Python的ntplib库同步校准当地时间的方法。 1. 导入ntplib库 当我们需要使用ntplib库来操作时间时,我们需要先导入这个库: import ntplib 2. 创建…

    python 2023年6月2日
    00
  • Python文件时间操作步骤代码详解

    Python文件时间操作步骤代码详解 1. 文件时间戳 1.1 获取文件最后的访问时间、修改时间和状态时间 在Python中,我们可以通过os.path模块下的getatime、getmtime和getctime函数分别获取文件的最后访问时间、最后修改时间和最后状态改变时间。这些返回值为从1970年1月1日到当前时间的秒数,是一个浮点数。 import os…

    python 2023年6月3日
    00
  • python去除字符串中的空格、特殊字符和指定字符的三种方法

    下面对三种方法进行详细讲解。 方法一:使用Python内置的字符串函数 Python内置的字符串函数strip()、replace()和translate()可以方便地去除字符串中的空格、特殊字符和指定字符。 1. 去除空格 string_with_spaces = " This is a string with spaces. " st…

    python 2023年6月5日
    00
  • 如何使用 Redis 的缓存功能来提高网站性能?

    以下是详细讲解如何使用 Redis 的缓存功能来提高网站性能的完整使用攻略。 Redis 缓存简介 Redis 是一种高性能的键值存储数据库,支持多种结构和高级功能。其中,缓存是 Redis 的一个重要功能,可以用于提高网站性能。Redis 缓存的特点如下: Redis 缓存是基于内存,读写速度非常快。 Redis 缓存是分布式的,可以将缓存数据分布在个节点…

    python 2023年5月12日
    00
  • 如何使用Python在MySQL中删除索引?

    要使用Python在MySQL中删除索引,可以使用Python的内置模块sqlite3或第三方库mysql-connector-python。以下是使用mysql-connector-python在MySQL中删除索引的完整攻略: 连接 要连接到MySQL,需要提供MySQL的主机、用户名、和密码。可以使用以下代码连接MySQL: mysql.connect…

    python 2023年5月12日
    00
  • 对Python中 \r, \n, \r\n的彻底理解

    下面是对Python中\r、\n和\r\n的详细解释。 背景 在计算机中,换行分两种:回车(Carriage Return)和换行(Line Feed)。在以前的打字机时代,回车的操作是由一个机械装置来完成的,它会把打印头快速地移回行首,这个操作会造成打印纸移动一行的效果。而换行则是让打印头下移一行。 在计算机中,我们通常使用的是ASCII码作为字符编码,其…

    python 2023年5月31日
    00
合作推广
合作推广
分享本页
返回顶部