PyTorch实现图像识别(实战)攻略
前言
图像识别是计算机视觉领域的一个重要应用,而深度学习技术在图像识别中发挥了重要作用。PyTorch是深度学习领域的一个强大工具,本文将介绍如何使用PyTorch实现图像识别。
环境
在实现图像识别之前,需要确保安装了正确的开发环境,包括:
- Python 3.x版本
- PyTorch 1.x版本
- Torchvision、Numpy等必要的库
数据集
在进行图像识别之前,需要先准备好数据集。数据集是指一组图像和对应的标签,用于训练和测试模型。在本文中,我们将使用MNIST数据集,它包含了60,000张训练图和10,000张测试图,每张图都是28x28像素的灰度图,标签分别是0到9的数字。
构建模型
在PyTorch中,可以使用nn模块构建深度学习模型。在本文中,我们将构建一个简单的卷积神经网络来识别MNIST数据集中的数字图像。
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = nn.functional.relu(x)
x = self.conv2(x)
x = nn.functional.relu(x)
x = nn.functional.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = nn.functional.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = nn.functional.softmax(x, dim=1)
return output
以上是一个简单的卷积神经网络,包括2个卷积层、2个池化层、2个丢弃层和2个全连接层,最后采用softmax激活函数输出每个数字的概率。
训练模型
下面是如何使用PyTorch训练我们构建的模型:
import torch.optim as optim
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(10):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 100 == 99:
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 100))
running_loss = 0.0
print('Finished Training')
在训练之前,需要设置损失函数、优化器和学习率,然后利用训练集进行训练。我们将训练集分batch进行训练,每一次迭代都会更新模型权重。
模型评估
训练完成后,需要使用测试集对模型进行评估。
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))
以上代码遍历测试集,输出模型在测试集上的正确率。
示例1:手写数字识别
下面是如何使用我们构建的模型识别手写数字:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# Load the image from file
img = Image.open('test.png')
# Convert the image to grayscale
img = img.convert('L')
# Resize the image to 28x28 pixels
img = img.resize((28, 28))
# Convert the image to a numpy array
img = np.asarray(img)
# Invert the pixels (0->255, 255->0)
img = 255 - img
# Normalize the pixel values to be between 0 and 1
img = img / 255
# Convert the numpy array to a PyTorch tensor
img = torch.FloatTensor(img)
# Add a batch dimension to the tensor
img = img.unsqueeze(0).unsqueeze(0)
# Run the model on the input image
output = model(img)
# Get the predicted digit
_, predicted = torch.max(output.data, 1)
# Convert the PyTorch tensor to a numpy array
img = img.squeeze().numpy()
# Display the input image and the predicted digit
plt.imshow(img, cmap='gray')
plt.title("Predicted digit: {}".format(predicted.item()))
plt.show()
该代码加载图片、转换为灰度图并缩放为28x28像素。然后,通过一系列数值计算将图像转换为张量,并添加一个批处理维度。最后,运行模型进行预测并将结果呈现为数字。
示例2:猫狗图像分类
下面是如何使用我们构建的模型识别猫狗图像:
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
# Load the image from file
img = Image.open('cat.jpg')
# Define the image transformations
transform = transforms.Compose([
transforms.Resize((28, 28)),
transforms.Grayscale(),
transforms.ToTensor()
])
# Apply the transformations to the image
img = transform(img)
# Add a batch dimension to the tensor
img = img.unsqueeze(0)
# Run the model on the input image
output = model(img)
# Get the predicted class probabilities
probs = F.softmax(output, dim=1)
# Get the predicted class label
pred_class = torch.argmax(probs)
# Display the input image and the predicted class
plt.imshow(img.squeeze().numpy().transpose(1, 2, 0), cmap='gray')
if pred_class == 0:
plt.title("Predicted class: cat")
else:
plt.title("Predicted class: dog")
plt.show()
该代码加载图片并进行预处理,然后将其转换为张量并执行预测。最后,将预测结果呈现为类别。注意:在训练阶段,应该使用猫狗图像数据集去训练模型。
结语
本文介绍了如何使用PyTorch实现图像识别,包括构建模型、训练模型和模型评估等步骤。实际应用中,我们可以根据数据集和任务需求,选择不同的深度学习模型进行构建和训练。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch实现图像识别(实战) - Python技术站