python使用KNN算法手写体识别
介绍
K最近邻(K-Nearest Neighbor,KNN)算法是一种用于分类和回归的非参数方法。在模型管理中,KNN被认为是一种有监督的学习方法,其中非标记数据分类或回归信息传递给最近邻居的标记数据来预测新输入的标记。
本文将会使用Python编程语言和KNN算法来手写体识别。下面是一个完整的攻略:
总体步骤
步骤1:数据收集
手写数字数据集MNIST,其中包含有60,000个示例的训练集以及10,000个示例的测试集。本文将使用这个数据集,该数据集可在http://yann.lecun.com/exdb/mnist/下载。
在这个数据集中,每个图像都是28×28像素的灰度图像,并且已经标记为0-9的数字之一。
步骤2:数据预处理
在这一步骤中,我们通过Python代码将数据预处理为KNN算法可用的格式。
代码示例:
# 导入相关库
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import pandas as pd
# 加载数据集
train_df = pd.read_csv('mnist_train.csv')
test_df = pd.read_csv('mnist_test.csv')
# 分离特征与标签
X_train = train_df.iloc[:, 1:].values
y_train = train_df.iloc[:, 0].values
X_test = test_df.iloc[:, 1:].values
y_test = test_df.iloc[:, 0].values
# 将特征归一化
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0
# 将数据集分为训练集和验证集
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)
步骤3:训练模型
在这一步骤中,我们将使用KNN算法在数据集上训练模型。
代码示例:
# 构建模型
knn_model = KNeighborsClassifier(n_neighbors=5, weights='uniform', p=2, metric='minkowski')
# 训练模型
knn_model.fit(X_train, y_train)
步骤4:模型评估
在这一步骤中,我们评估训练出来的模型的准确率。
代码示例:
# 预测验证集
y_val_pred = knn_model.predict(X_val)
# 计算验证集的准确率
val_accuracy = accuracy_score(y_val, y_val_pred)
# 输出准确率
print("Validation Accuracy: {:.2f}%".format(val_accuracy * 100))
步骤5:使用模型预测
在这一步骤中,我们将使用训练好的模型对测试集中的手写数字图像进行预测。
代码示例:
# 预测测试集
y_test_pred = knn_model.predict(X_test)
# 计算测试集的准确率
test_accuracy = accuracy_score(y_test, y_test_pred)
# 输出准确率
print("Test Accuracy: {:.2f}%".format(test_accuracy * 100))
示例1:KNN手写体识别代码的完整实现
下面是使用KNN算法进行手写体识别的完整Python代码:
# 导入相关库
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import pandas as pd
# 加载数据集
train_df = pd.read_csv('mnist_train.csv')
test_df = pd.read_csv('mnist_test.csv')
# 分离特征与标签
X_train = train_df.iloc[:, 1:].values
y_train = train_df.iloc[:, 0].values
X_test = test_df.iloc[:, 1:].values
y_test = test_df.iloc[:, 0].values
# 将特征归一化
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0
# 将数据集分为训练集和验证集
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)
# 构建模型
knn_model = KNeighborsClassifier(n_neighbors=5, weights='uniform', p=2, metric='minkowski')
# 训练模型
knn_model.fit(X_train, y_train)
# 预测验证集
y_val_pred = knn_model.predict(X_val)
# 计算验证集的准确率
val_accuracy = accuracy_score(y_val, y_val_pred)
# 输出准确率
print("Validation Accuracy: {:.2f}%".format(val_accuracy * 100))
# 预测测试集
y_test_pred = knn_model.predict(X_test)
# 计算测试集的准确率
test_accuracy = accuracy_score(y_test, y_test_pred)
# 输出准确率
print("Test Accuracy: {:.2f}%".format(test_accuracy * 100))
示例2:调参
在这个示例中,我们尝试通过调整不同的参数来提高模型的准确率。
代码示例:
# 导入相关库
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import pandas as pd
# 加载数据集
train_df = pd.read_csv('mnist_train.csv')
test_df = pd.read_csv('mnist_test.csv')
# 分离特征与标签
X_train = train_df.iloc[:, 1:].values
y_train = train_df.iloc[:, 0].values
X_test = test_df.iloc[:, 1:].values
y_test = test_df.iloc[:, 0].values
# 将特征归一化
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0
# 将数据集分为训练集和验证集
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)
for k in range(1, 11):
for metric in ['euclidean', 'manhattan', 'minkowski']:
# 构建模型
knn_model = KNeighborsClassifier(n_neighbors=k, weights='uniform', p=2, metric=metric)
# 训练模型
knn_model.fit(X_train, y_train)
# 预测验证集
y_val_pred = knn_model.predict(X_val)
# 计算验证集的准确率
val_accuracy = accuracy_score(y_val, y_val_pred)
# 输出准确率
print("k: {}, Metric: {}, Validation Accuracy: {:.2f}%".format(k, metric, val_accuracy * 100))
结论
在本文中,我们介绍了利用Python和KNN算法进行手写体识别的完整攻略,包括数据收集、数据预处理、训练模型、模型评估和使用模型预测等步骤。示例代码也提供了两个具体的案例,读者可在此基础上进一步实践和改进,提高认识和运用KNN算法的能力。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:python使用KNN算法手写体识别 - Python技术站