Python LeNet网络详解及PyTorch实现
本文将介绍LeNet网络的结构和实现,并使用PyTorch实现一个LeNet网络进行手写数字识别。
1. LeNet网络结构
LeNet网络是由Yann LeCun等人在1998年提出的,是一个经典的卷积神经网络。它主要用于手写数字识别,包含两个卷积层和三个全连接层。
LeNet网络的结构如下所示:
输入层 -> 卷积层1 -> 池化层1 -> 卷积层2 -> 池化层2 -> 全连接层1 -> 全连接层2 -> 输出层
其中,卷积层1和卷积层2都使用5x5的卷积核,池化层1和池化层2都使用2x2的最大池化。全连接层1包含120个神经元,全连接层2包含84个神经元。输出层包含10个神经元,每个神经元对应一个数字。
2. LeNet网络实现
我们将使用PyTorch实现一个LeNet网络,并在MNIST数据集上完成手写数字识别的训练和测试。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# Load the MNIST dataset (downloaded into ./data on first run).
# ToTensor converts each 28x28 grayscale PIL image into a (1, 28, 28)
# float tensor scaled to [0, 1]; labels remain plain Python ints.
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)
# Define the LeNet network
class LeNet(nn.Module):
    """LeNet-5 style CNN for 28x28 single-channel images (MNIST).

    Two conv + max-pool stages followed by three fully connected layers;
    the final layer emits raw logits for the 10 digit classes.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1x28x28 -> 6x24x24 -> 6x12x12 -> 16x8x8 -> 16x4x4
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        # Classifier head on the flattened 16*4*4 = 256 features
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for input images x."""
        out = self.pool1(torch.relu(self.conv1(x)))
        out = self.pool2(torch.relu(self.conv2(out)))
        out = out.view(-1, 16 * 4 * 4)  # flatten feature maps per sample
        out = torch.relu(self.fc1(out))
        out = torch.relu(self.fc2(out))
        return self.fc3(out)
# Define the model, loss function and optimizer
model = LeNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train the model.
# BUG FIX: iterating the raw Dataset yields (tensor, int) pairs, so the
# original `labels.unsqueeze(0)` crashed on the int label. A DataLoader
# collates labels into tensors and batches the samples, which both fixes
# the crash and makes training dramatically faster than batch-size-1 SGD.
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
for epoch in range(10):
    model.train()  # ensure training-mode behavior (future-proof; no-op for plain LeNet)
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)               # (batch, 10) logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # loss.item() is the batch mean; weight by batch size so the epoch
        # figure stays a per-sample average, as in the original code.
        train_loss += loss.item() * labels.size(0)
        predicted = torch.argmax(outputs, dim=1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()
    print(f'Epoch {epoch+1}, Train Loss: {train_loss/train_total}, Train Accuracy: {train_correct/train_total}')
# Evaluate the model on the held-out test set (no gradients needed).
# Improvement over the original: batched inference via DataLoader instead
# of one forward pass per sample, plus model.eval() to disable any
# training-only layers (harmless for plain LeNet, but correct practice).
from torch.utils.data import DataLoader

test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)
model.eval()
test_correct = 0
test_total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)                    # (batch, 10) logits
        predicted = torch.argmax(outputs, dim=1)   # predicted digit per sample
        test_total += labels.size(0)
        test_correct += (predicted == labels).sum().item()
print(f'Test Accuracy: {test_correct/test_total}')
示例2:使用GPU加速的LeNet网络
如果你的机器上有GPU,你可以使用PyTorch的GPU加速功能来加速模型训练和预测。以下是使用GPU加速的LeNet网络的示例代码。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# Load the MNIST dataset (cached in ./data after the first download).
# Each sample is a (1, 28, 28) float tensor in [0, 1] plus an int label.
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)
# Define the LeNet network
class LeNet(nn.Module):
    """Classic LeNet-5 variant for MNIST digit classification.

    Pipeline: conv(5x5) -> max-pool(2x2) -> conv(5x5) -> max-pool(2x2)
    -> fc 120 -> fc 84 -> fc 10 (logits).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def _extract(self, x):
        """Run the conv feature extractor and flatten to (batch, 256)."""
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        return x.view(-1, 16 * 4 * 4)

    def forward(self, x):
        """Map a batch of images to raw logits over the 10 digit classes."""
        h = torch.relu(self.fc1(self._extract(x)))
        h = torch.relu(self.fc2(h))
        return self.fc3(h)
# Define the model, loss function and optimizer; keep parameters on the GPU
model = LeNet().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train the model.
# BUG FIX: iterating the raw Dataset yields (tensor, int) pairs, so the
# original `labels.cuda()` crashed on the int label. A DataLoader collates
# labels into tensors and batches samples — which is also essential for the
# GPU to outperform the CPU (batch-size-1 transfers waste the device).
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
for epoch in range(10):
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    for inputs, labels in train_loader:
        inputs = inputs.cuda()
        labels = labels.cuda()
        optimizer.zero_grad()
        outputs = model(inputs)               # (batch, 10) logits on GPU
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by batch size to keep a per-sample average
        train_loss += loss.item() * labels.size(0)
        predicted = torch.argmax(outputs, dim=1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()
    print(f'Epoch {epoch+1}, Train Loss: {train_loss/train_total}, Train Accuracy: {train_correct/train_total}')
# Evaluate the model on the test set on the GPU.
# BUG FIX: the original `labels.cuda()` crashed because iterating the raw
# Dataset yields int labels; the DataLoader collates them into tensors.
# Also switches to batched inference and eval mode.
from torch.utils.data import DataLoader

test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)
model.eval()
test_correct = 0
test_total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.cuda()
        labels = labels.cuda()
        outputs = model(inputs)                    # (batch, 10) logits
        predicted = torch.argmax(outputs, dim=1)
        test_total += labels.size(0)
        test_correct += (predicted == labels).sum().item()
print(f'Test Accuracy: {test_correct/test_total}')
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python LeNet网络详解及pytorch实现 - Python技术站