Python 实现LeNet网络模型的训练及预测

Python实现LeNet网络模型的训练及预测

LeNet是一种经典的卷积神经网络模型,由Yann LeCun等人于1998年提出,主要用于手写数字识别。本文将详细讲解如何使用Python实现LeNet网络模型的训练及预测,包括数据集准备、模型的搭建、训练和预测等。

数据集准备

在实现LeNet网络模型之前,需要准备一个合适的数据集。在本文中,我们将使用MNIST数据集,它包含了60000张28x28像素的手写数字图片,共分为10个类别。可以使用以下代码和加载MNIST数据集:

import tensorflow as tf
from tensorflow.keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

模型搭建

在数据集准备好之后,可以开始搭建LeNet网络模型。以下是LeNet网络模型的代码实现:

from tensorflow.keras import layers, models

model = models.Sequential()
model.add(layers.Conv2D(6, (5, 5), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(16, (5, 5), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Flatten())
model.add(layers.Dense(120, activation='relu'))
model.add(layers.Dense(84, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

上面的代码使用了Keras API搭建了LeNet网络模型,其中Conv2D层和MaxPooling2D层分别表示卷积层和池化层,Flatten层用于将卷积层的输出展平,Dense层表示全连接层,softmax函数用于多分类问题的输出。

模型训练

在搭建好LeNet网络模型之后,可以开始训练模型。以下是模型训练的代码实现:

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test))

上面的代码使用了compile函数编译了模型,使用了fit函数训练了模型。在编译模型时,使用了adam优化器、sparse_categorical_crossentropy损失函数和accuracy评估指标。在训练模型时,使用了训练集和测试集的数据和标签,设置了5个epochs和64个batch_size。

模型预测

在训练好LeNet网络模型之后,可以使用模型进行预测。以下是模型预测的代码实现:

import numpy as np

predictions = model.predict(x_test)
y_pred = np.argmax(predictions, axis=1)

print(y_pred[:10])
print(y_test[:10])

上面的代码使用了predict函数对测试集进行预测,使用了argmax函数获取预测结果中概率最大的类别,然后输出了前10个预测结果和真实标签。

示例一:完整代码实现

以下是完整的LeNet网络模型的训练和预测的代码实现:

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras import layers, models
import numpy as np

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape((60000, 28, 28, 1))
x_test = x_test.reshape((10000, 28, 28, 1))
x_train, x_test = x_train / 255.0, x_test / 255.0

model = models.Sequential()
model.add(layers.Conv2D(6, (5, 5), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(16, (5, 5), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Flatten())
model.add(layers.Dense(120, activation='relu'))
model.add(layers.Dense(84, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test))

predictions = model.predict(x_test)
y_pred = np.argmax(predictions, axis=1)

print(y_pred[:10])
print(y_test[:10])

示例二:可视化训练过程

可以使用Matplotlib库可视化LeNet网络模型的训练过程。以下是可视化训练过程的代码实现:

import matplotlib.pyplot as plt

history = model.fit(x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test))

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(1, len(acc) + 1)

plt.plot(epochs, acc, 'bo', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()

plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation')
plt.title('Training and validation loss')
plt.legend()

plt.show()

上面的代码使用了fit函数训练模型,并将训练过程中的准确率和损失值保存在``变量中。然后使用Matplotlib库绘制了训练和验证准确和损失值的曲线图。

总结

本文详细讲解了如何使用Python实现LeNet网络模型的训练及预测,包括数据集的准备、模型的搭建、训练和预测等。在实现LeNet网络模型时,需要注意数据集的格式、模型的层次结构和参数设置,以及训练和预测的过程。LeNet网络模型是深度学习领域的经典模型,可以用于手写数字识别等多种任务。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python 实现LeNet网络模型的训练及预测 - Python技术站

(0)
上一篇 2023年5月14日
下一篇 2023年5月14日

相关文章

  • numpy数组切片的使用

    以下是关于“numpy数组切片的使用”的完整攻略。 背景 在NumPy中,我们可以使用切片(slice)来访问数组中的元素。本攻略将介绍如何使用NumPy数组切片,并提供两个示例来演示如何使用这些方法。 NumPy数组切片 以下是使用NumPy数组切片的示例: import numpy as np # 创建一个数组 arr = np.array([1, 2,…

    python 2023年5月14日
    00
  • 详解NumPy中数组的索引和取值

    在NumPy中,可以使用索引和切片操作来获取数组中的元素和子数组。下面详细介绍NumPy数组的索引和取值方法。 NumPy数组索引 NumPy数组可以像Python列表一样使用索引来获取元素。数组的索引从0开始,可以是负数,表示从末尾开始索引。可以使用以下方法对NumPy数组进行索引: 单个元素索引 可以通过指定元素的下标来获取数组中的单个元素,如: imp…

    2023年2月28日
    00
  • Python中Numpy mat的使用详解

    以下是关于“Python中Numpy.mat的使用详解”的完整攻略。 Numpy.mat的使用 Numpy.mat是Numpy中的一个子类,它提供了一些特殊的矩阵运算方法。使用Numpy创建矩阵的方法非常简单,只需要使用np.mat()函数即可。下面是Numpy.mat的使用示例: 创建矩阵 使用Numpy.mat创建矩阵的方法非简单,只需要使用np.mat…

    python 2023年5月14日
    00
  • PyTorch中 tensor.detach() 和 tensor.data 的区别解析

    当我们使用PyTorch时,经常会遇到需要“切断计算图”的情况,同时需要保留某些tensor的值。两个常用的方法就是 detach() 和 data,但它们具有一些区别。 detach()和data的基本作用 detach(): 用于将一个tensor从计算图上分离出来,并返回一个新的不与计算图相连接的tensor。使用detach()可以阻止梯度反向传播算…

    python 2023年5月14日
    00
  • Numpy中的shape函数的用法详解

    以下是关于“Numpy中的shape函数的用法详解”的完整攻略。 Numpy中的shape函数 在Numpy中,shape函数用于获取数组的形状,即数组的维度和大小。shape函数返回一个元组,元组中的每个元素表示数组在对应维度上的大小。 获取数组的形状 下面是一个使用shape函数获取数组形状的示例代码: import numpy as np # 创建一个…

    python 2023年5月14日
    00
  • Python压缩解压缩zip文件及破解zip文件密码的方法

    Python压缩解压缩zip文件及破解zip文件密码的方法 Python提供了标准库 zipfile 来对zip文件进行压缩解压缩操作,并且可以在这个库的基础上扩展实现zip文件的密码破解。 压缩zip文件 使用 zipfile 库中的 ZipFile() 函数可以创建一个zip文件,并且可以使用 write() 函数向zip文件中添加文件。 import …

    python 2023年5月14日
    00
  • Python 提速器numba

    当你需要加速Python代码时,Numba是一个非常有用的工具。Numba是一个开源的JIT(即时编译器),它可以将Python代码转换为本地机器代码,从而提高代码的执行速度。下面是使用Numba的完整攻略: 安装Numba 在终端中运行以下命令来安装Numba: pip install numba 导入Numba 在Python脚本中导入Numba: im…

    python 2023年5月14日
    00
  • conda虚拟环境默认路径的修改方法

    Conda虚拟环境默认路径的修改方法 在本攻略中,我们将介绍如何修改Conda虚拟环境默认路径。以下是整个攻略,含两个示例说明。 示例1:使用conda config命令修改默认路径 以下是使用conda config命令修改默认路径的步骤: 打开终端。可以使用以下快捷键打开终端: Windows:Win + R,输入cmd,按Enter键 macOS:Co…

    python 2023年5月14日
    00
合作推广
合作推广
分享本页
返回顶部