python KNN算法实现鸢尾花数据集分类

Python实现KNN算法对鸢尾花数据集进行分类

介绍

KNN(K-Nearest-Neighbor)算法是一种非常常用且简单的分类算法之一。它的基本思想是把未知数据的标签与训练集中最邻近的K个数据的标签相比较,得票最多的标签就是未知数据的标签。本文将介绍如何使用Python实现对鸢尾花数据集进行KNN分类。

步骤

  1. 加载数据

首先,我们需要加载鸢尾花数据集。sklearn库中提供了该数据集,我们可以使用load_iris()函数进行加载。

from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data       # 特征矩阵
y = iris.target     # 标签数组
  1. 数据预处理

为了保证KNN算法的准确性,我们需要对数据进行预处理。这里我们采用Z-score标准化方法对特征矩阵进行归一化处理。

from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
X = scaler.fit_transform(X)
  1. 分割训练集和测试集

为了避免过拟合,我们需要将数据集分为训练集和测试集。我们使用train_test_split函数来将数据集随机划分成70%的训练集和30%的测试集。

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
  1. 训练模型

现在我们可以开始训练模型了。KNN算法只有一个参数——K值。对于这个参数,我们需要进行调参。在本次实验中,我们使用交叉验证法来训练模型并选择最佳的K值。

from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as plt

k_range = range(1, 31)
scores = []
for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    score = cross_val_score(knn, X_train, y_train, cv=10, scoring='accuracy').mean()
    scores.append(score)

plt.plot(k_range, scores)
plt.xlabel('K')
plt.ylabel('Accuracy')
plt.show()

运行以上代码后,我们会得到一个准确率随着K值变化的折线图。基于该图,我们可以选择最优的K值作为KNN模型的参数。

  1. 预测

在得到最优的K值后,我们可以开始对测试集进行预测了。在本次实验中,我们选择K=5进行预测。

knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
  1. 评价模型

最后,我们需要评价模型的准确率。我们可以使用scikit-learn库中的accuracy_score函数来评价模型的准确率。

from sklearn.metrics import accuracy_score

score = accuracy_score(y_test, y_pred)
print('Accuracy:', score)

示例

下面我们给出两个使用Python实现KNN算法的鸢尾花数据集分类示例。

示例一

在此示例中,我们将调节K值,并输出最优的K值和其对应的准确率。完整代码如下:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

iris = load_iris()
X = iris.data
y = iris.target

k_range = range(1, 31)
max_score = 0
max_k = 0

for k in k_range:
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train, y_train)
    y_pred = knn.predict(X_test)
    score = accuracy_score(y_test, y_pred)

    if score > max_score:
        max_score = score
        max_k = k

print('The best accuracy:', max_score, 'with the best k:', max_k)

在运行以上代码后,我们会输出最优的K值和其对应的准确率。

示例二

在此示例中,我们将训练KNN模型,并使用该模型对测试集进行预测。完整代码如下:

from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

iris = load_iris()
X = iris.data
y = iris.target

scaler = StandardScaler()
X = scaler.fit_transform(X)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)

print('Accuracy:', accuracy_score(y_test, y_pred))

在运行以上代码后,我们会输出模型的准确率。

结论

本文介绍了如何使用Python实现对鸢尾花数据集进行KNN分类,并给出了两个具体实现示例。KNN算法简单、易于理解,它虽然不如其他一些机器学习算法精度高,但在某些问题上表现出色。我们可以通过模型调优和数据预处理等手段来提高算法的准确性。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:python KNN算法实现鸢尾花数据集分类 - Python技术站

(0)
上一篇 2023年5月19日
下一篇 2023年5月19日

相关文章

  • PHP rsa加密解密算法原理解析

    PHP RSA加密解密算法原理解析 RSA是一种非对称加密算法,它使用两个密钥:公钥和私钥。公钥可以向外公开,用于加密数据;而私钥只由数据的持有者保管,用于解密数据。在本文中,我们会使用PHP实现RSA加密解密算法,并分享一些示例代码。 RSA加密解密算法原理 RSA加密解密算法的原理主要是基于数学中的大数分解问题和欧拉定理。以下是RSA算法的一般流程: 用…

    算法与数据结构 2023年5月19日
    00
  • 利用JavaScript实现的10种排序算法总结

    作为“利用JavaScript实现的10种排序算法总结”的作者,首先需要明确以下内容: 熟悉10种排序算法的原理与流程 理解JavaScript作为一门编程语言的特点和应用场景 知道如何将算法的流程用JavaScript代码实现 针对以上内容,可以采取以下步骤: 梳理10种排序算法的流程和实现方式,用markdown文本形式编写对应的标题和文本,例如: 插入…

    算法与数据结构 2023年5月19日
    00
  • 异常点/离群点检测算法——LOF解析

    异常点/离群点检测算法——LOF解析 什么是离群点(Outlier)? 在数据分析领域中,离群点通常指的是数据集中与其他数据点显著不同的数据点,也就是说,离群点是远离其他数据点的数据点。离群点检测是一个非常重要的数据挖掘任务,被广泛应用于异常检测、金融欺诈检测、医学诊断等领域。 LOF算法简介 LOF (Local Outlier Factor) 算法是一种…

    算法与数据结构 2023年5月19日
    00
  • java插入排序 Insert sort实例

    下面我将详细讲解如何实现Java的插入排序算法。 插入排序 Insert Sort 插入排序是一种简单直观的排序算法,它的基本思想是将未排序的数据依次插入到已排序数据中的合适位置,使得插入后序列仍然有序。 插入排序的算法步骤如下: 从第一个元素开始,该元素可以认为已经被排序; 取出下一个元素,在已经排序的元素序列中从后向前扫描; 如果该元素(已排序)大于新元…

    算法与数据结构 2023年5月19日
    00
  • Java 十大排序算法之计数排序刨析

    Java 十大排序算法之计数排序刨析 算法介绍 计数排序是一个时间复杂度为O(n+k)的非基于比较的排序算法,其中n是待排序元素的个数,k是待排序元素的范围,即待排序元素的最大值减去最小值再加1。 算法通过构建一个长度为k的计数数组来统计每个元素出现的次数,然后借助计数数组按顺序输出每个元素,就完成了排序过程。 因为计数排序是非基于比较的算法,因此可以在一定…

    算法与数据结构 2023年5月19日
    00
  • JavaScript实现的七种排序算法总结(推荐!)

    JavaScript实现的七种排序算法总结(推荐!) 简介 本文介绍了JavaScript实现的七种排序算法,包括插入排序、冒泡排序、选择排序、希尔排序、归并排序、快速排序和堆排序。每种算法都有对应的JavaScript代码实现,并且详细说明了算法的原理、时间复杂度和代码实现过程。 插入排序 插入排序是一种简单的排序算法,它的基本思想是将数组分成已排序和未排…

    算法与数据结构 2023年5月19日
    00
  • c++实现排序算法之希尔排序方式

    C++实现排序算法之希尔排序 前置知识 希尔排序是一种基于插入排序的排序算法 插入排序是一种简单直观的排序算法 算法思路 希尔排序是一种分组插入排序的算法。它的基本思想是:先将待排序序列按照一定规则分成若干子序列,对各个子序列进行插入排序,然后逐步缩小子序列的长度,最终使整个序列成为一个有序序列。 例如,对于一个序列 5 2 8 9 1 3 7 6 4,我们…

    算法与数据结构 2023年5月19日
    00
  • C++详细讲解图的拓扑排序

    C++详细讲解图的拓扑排序 什么是拓扑排序 拓扑排序是对于有向无环图(Directed Acyclic Graph)的一种排序,其输出结果为图中每个节点的线性先后序列,满足如果存在一条从节点 A 到节点 B 的路径,则在序列中节点 A 出现在节点 B 的前面。 什么是有向无环图(DAG) 有向无环图是不包含环路并且有一个或多个源点和汇点的有向图。其中源点指没…

    算法与数据结构 2023年5月19日
    00
合作推广
合作推广
分享本页
返回顶部