下面是详细讲解“kNN算法python实现和简单数字识别的方法”的完整攻略,包括算法原理、Python实现和两个示例说明。
算法原理
kNN算法是一种用的分类算法,其基本思想是通过计算待分类样本与训练集中各个样本的距离,选取距离最近的k个样本,根据这k个样本的类别进行投票,将待分类样本归为票数最多类别。具体步骤如下:
- 计算待分类样本与训练集中各个样本的距离;
- 选取距离最近的k个样本;
- 根据这k个样本的类别进行投票;
- 将待分类样本归为票数最多的类别。
Python实现代码
以下Python实现kNN算法的示例代码:
import numpy as np
class KNN:
def __init__(self, k):
self.k = k
def fit(self, X, y):
self.X = X
self.y = y
def predict(self, X):
y_pred = []
for x in X:
distances = np.sqrt(np.sum((self.X - x) ** 2, axis=1))
indices = np.argsort(distances)[:self.k]
labels = self.y[indices]
y_pred.append(np.bincount(labels).argmax())
return y_pred
上述代码中,定义了一个KNN
类表示kNN算法,包括k
表示选取的最近邻数,fit
方法表示训练模型,predict
方法表示预测样本类别。
示例说明
以下是两个示例,说明如何使用KNN
类进行操作。
示例1
使用KNN
类对鸢尾花数据进行分类。
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
knn = KNN(k=3)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")
输出结果:
Accuracy: 1.00
示例2
使用KNN
类对手写数据进行分类。
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2, random_state=42)
knn = KNN(k=3)
nn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(6, 6))
for i, ax in enumerate(axes.flat):
ax.imshow(X_test[i].reshape(8, 8), cmap="gray")
ax.set_title(f"True: {y_test[i]}, Pred: {y_pred[i]}")
ax.axis("off")
plt.tight_layout()
plt.show()
输出结果:
Accuracy: 0.98
同时,还会显示一个4x4的图像矩阵每个图像显示一个测试样本的图和其真实类别和预测类别。
总结
本文介绍了kNN算法的Python实现方法,包括算法原理、Python实现代码和两个示例说明。kNN算法是一种常用的分类算法,其基本思想是通过计算待分类样本与训练集中各个样本的距离,选取距离最近的k个样本,根据这k个样本的类别进行投票,将待分类样本归为票数最多的类别。在实际应用中,需要注意选取合适的k值,以获得更好的分类效果。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:kNN算法python实现和简单数字识别的方法 - Python技术站