• We know RNNs are effective at handling sequence problems, but can they work on images too?

  • Let's try it on the MNIST handwritten-digit dataset.

  • Each batch has shape batch_size × 1 × 28 × 28; by unrolling each 28×28 image row by row into a sequence of 28 vectors of length 28, we can feed it to a recurrent network (see the reshape sketch after this list). Here we use the LSTM, an improved variant of the vanilla RNN.

  • In practice the RNN reaches about 98% accuracy, close to what we got earlier with a CNN.

  • But MNIST is too easy, so we switch to the slightly harder Fashion-MNIST dataset, where the RNN only approaches 90% accuracy, while even a simple CNN easily exceeds 90%.

  • This outcome can also be anticipated from the two architectures' designs: the RNN turns the image into a flat sequence, so unlike a CNN it cannot capture the information expressed by neighboring pixel patches, and it is not tailored to the spatial structure of images.
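
A minimal sketch of that row-by-row unrolling, using a dummy batch (shapes only, no real data):

import torch

images = torch.randn(512, 1, 28, 28)  # (batch, channel, height, width)
sequences = images.view(-1, 28, 28)   # (batch, seq_len=28, input_size=28)
print(sequences.shape)                # torch.Size([512, 28, 28])

Each of the 28 rows becomes one time step carrying a 28-dimensional input vector; this is exactly the view(-1, 28, 28) call used in the training loop below.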

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

BATCH_SIZE = 512
INPUT_SIZE = 28    # each 28-pixel row of an image is one LSTM time step
EPOCHS = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LR = 0.01
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        root='data',
        train=True,
        download=True,                 # no-op if the data is already in place
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))   # standard MNIST mean/std
        ])),
    batch_size=BATCH_SIZE,
    shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        root='data',
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
    batch_size=BATCH_SIZE,
    shuffle=False)                     # no need to shuffle the test set
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.rnn = nn.LSTM(
            input_size=INPUT_SIZE,   # each time step is one 28-pixel row
            hidden_size=64,
            num_layers=1,
            batch_first=True,        # input/output shaped (batch, seq, feature)
        )
        self.out = nn.Linear(64, 10)

    def forward(self, x):
        # r_out holds outputs at every time step; h_n/c_n are the final hidden/cell states
        r_out, (h_n, c_n) = self.rnn(x, None)   # None -> zero initial state
        out = self.out(r_out[:, -1, :])         # classify from the last time step
        return out
    
rnn = RNN().to(DEVICE)
optimizer = optim.Adam(rnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
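
A quick sanity check (a sketch, not part of the original run): push one dummy batch through the model to confirm the output shape.

with torch.no_grad():
    dummy = torch.randn(4, 28, 28, device=DEVICE)  # 4 fake images, already "unrolled"
    print(rnn(dummy).shape)                        # expected: torch.Size([4, 10])
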
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        data = data.view(-1, 28, 28)    # (batch, 1, 28, 28) -> (batch, 28, 28)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_func(output, target)
        loss.backward()
        optimizer.step()
        if (batch_idx + 1) % 30 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            data = data.view(-1, 28, 28)
            output = model(data)
            test_loss += loss_func(output, target).item()   # sum of per-batch mean losses
            pred = output.max(1, keepdim=True)[1]           # index of the largest logit
            correct += pred.eq(target.view_as(pred)).sum().item()

    # dividing the summed per-batch means by the dataset size is why the
    # printed average loss comes out so small
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
for epoch in range(1, EPOCHS + 1):
    train(rnn, DEVICE, train_loader, optimizer, epoch)
    test(rnn, DEVICE, test_loader)
Train Epoch: 1 [14848/60000 (25%)]	Loss: 0.695383
Train Epoch: 1 [30208/60000 (50%)]	Loss: 0.288098
Train Epoch: 1 [45568/60000 (75%)]	Loss: 0.255972

Test set: Average loss: 0.0003, Accuracy: 9468/10000 (95%)

Train Epoch: 2 [14848/60000 (25%)]	Loss: 0.155514
Train Epoch: 2 [30208/60000 (50%)]	Loss: 0.118997
Train Epoch: 2 [45568/60000 (75%)]	Loss: 0.116895

Test set: Average loss: 0.0002, Accuracy: 9667/10000 (97%)

Train Epoch: 3 [14848/60000 (25%)]	Loss: 0.103602
Train Epoch: 3 [30208/60000 (50%)]	Loss: 0.128102
Train Epoch: 3 [45568/60000 (75%)]	Loss: 0.109712

Test set: Average loss: 0.0002, Accuracy: 9765/10000 (98%)

Train Epoch: 4 [14848/60000 (25%)]	Loss: 0.060537
Train Epoch: 4 [30208/60000 (50%)]	Loss: 0.064028
Train Epoch: 4 [45568/60000 (75%)]	Loss: 0.079152

Test set: Average loss: 0.0002, Accuracy: 9765/10000 (98%)

Train Epoch: 5 [14848/60000 (25%)]	Loss: 0.065045
Train Epoch: 5 [30208/60000 (50%)]	Loss: 0.050917
Train Epoch: 5 [45568/60000 (75%)]	Loss: 0.071716

Test set: Average loss: 0.0001, Accuracy: 9806/10000 (98%)

Train Epoch: 6 [14848/60000 (25%)]	Loss: 0.027889
Train Epoch: 6 [30208/60000 (50%)]	Loss: 0.057783
Train Epoch: 6 [45568/60000 (75%)]	Loss: 0.057316

Test set: Average loss: 0.0002, Accuracy: 9788/10000 (98%)

Train Epoch: 7 [14848/60000 (25%)]	Loss: 0.026589
Train Epoch: 7 [30208/60000 (50%)]	Loss: 0.046758
Train Epoch: 7 [45568/60000 (75%)]	Loss: 0.048475

Test set: Average loss: 0.0001, Accuracy: 9806/10000 (98%)

Train Epoch: 8 [14848/60000 (25%)]	Loss: 0.051083
Train Epoch: 8 [30208/60000 (50%)]	Loss: 0.047659
Train Epoch: 8 [45568/60000 (75%)]	Loss: 0.044323

Test set: Average loss: 0.0001, Accuracy: 9846/10000 (98%)

Train Epoch: 9 [14848/60000 (25%)]	Loss: 0.041711
Train Epoch: 9 [30208/60000 (50%)]	Loss: 0.071301
Train Epoch: 9 [45568/60000 (75%)]	Loss: 0.059197

Test set: Average loss: 0.0001, Accuracy: 9818/10000 (98%)

Train Epoch: 10 [14848/60000 (25%)]	Loss: 0.029158
Train Epoch: 10 [30208/60000 (50%)]	Loss: 0.022632
Train Epoch: 10 [45568/60000 (75%)]	Loss: 0.049115

Test set: Average loss: 0.0001, Accuracy: 9839/10000 (98%)
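
The cells below swap the loaders to Fashion-MNIST but keep training the same rnn instance, so this run continues from the MNIST-trained weights. For a strictly from-scratch comparison, you could re-initialize the model and optimizer first, e.g.:

# optional: restart from fresh weights instead of the MNIST-trained ones
rnn = RNN().to(DEVICE)
optimizer = optim.Adam(rnn.parameters(), lr=LR)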
train_loader = torch.utils.data.DataLoader(
    datasets.FashionMNIST(
        root='data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            # MNIST statistics reused as-is; Fashion-MNIST's own mean/std
            # (roughly 0.286/0.353) would be slightly more appropriate
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
    batch_size=BATCH_SIZE,
    shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.FashionMNIST(
        root='data',
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
    batch_size=BATCH_SIZE,
    shuffle=False)
for epoch in range(1, EPOCHS + 1):
    train(rnn, DEVICE, train_loader, optimizer, epoch)
    test(rnn, DEVICE, test_loader)
Train Epoch: 1 [14848/60000 (25%)]	Loss: 0.780703
Train Epoch: 1 [30208/60000 (50%)]	Loss: 0.591628
Train Epoch: 1 [45568/60000 (75%)]	Loss: 0.437340

Test set: Average loss: 0.0009, Accuracy: 8311/10000 (83%)

Train Epoch: 2 [14848/60000 (25%)]	Loss: 0.432366
Train Epoch: 2 [30208/60000 (50%)]	Loss: 0.460083
Train Epoch: 2 [45568/60000 (75%)]	Loss: 0.401738

Test set: Average loss: 0.0009, Accuracy: 8441/10000 (84%)

Train Epoch: 3 [14848/60000 (25%)]	Loss: 0.339167
Train Epoch: 3 [30208/60000 (50%)]	Loss: 0.364077
Train Epoch: 3 [45568/60000 (75%)]	Loss: 0.393577

Test set: Average loss: 0.0008, Accuracy: 8573/10000 (86%)

Train Epoch: 4 [14848/60000 (25%)]	Loss: 0.377391
Train Epoch: 4 [30208/60000 (50%)]	Loss: 0.332370
Train Epoch: 4 [45568/60000 (75%)]	Loss: 0.396648

Test set: Average loss: 0.0008, Accuracy: 8648/10000 (86%)

Train Epoch: 5 [14848/60000 (25%)]	Loss: 0.295036
Train Epoch: 5 [30208/60000 (50%)]	Loss: 0.307383
Train Epoch: 5 [45568/60000 (75%)]	Loss: 0.370547

Test set: Average loss: 0.0007, Accuracy: 8693/10000 (87%)

Train Epoch: 6 [14848/60000 (25%)]	Loss: 0.279029
Train Epoch: 6 [30208/60000 (50%)]	Loss: 0.322744
Train Epoch: 6 [45568/60000 (75%)]	Loss: 0.324377

Test set: Average loss: 0.0007, Accuracy: 8635/10000 (86%)

Train Epoch: 7 [14848/60000 (25%)]	Loss: 0.249461
Train Epoch: 7 [30208/60000 (50%)]	Loss: 0.297717
Train Epoch: 7 [45568/60000 (75%)]	Loss: 0.339295

Test set: Average loss: 0.0007, Accuracy: 8735/10000 (87%)

Train Epoch: 8 [14848/60000 (25%)]	Loss: 0.306152
Train Epoch: 8 [30208/60000 (50%)]	Loss: 0.317238
Train Epoch: 8 [45568/60000 (75%)]	Loss: 0.282634

Test set: Average loss: 0.0007, Accuracy: 8751/10000 (88%)

Train Epoch: 9 [14848/60000 (25%)]	Loss: 0.260427
Train Epoch: 9 [30208/60000 (50%)]	Loss: 0.238925
Train Epoch: 9 [45568/60000 (75%)]	Loss: 0.262525

Test set: Average loss: 0.0007, Accuracy: 8801/10000 (88%)

Train Epoch: 10 [14848/60000 (25%)]	Loss: 0.269448
Train Epoch: 10 [30208/60000 (50%)]	Loss: 0.225632
Train Epoch: 10 [45568/60000 (75%)]	Loss: 0.249671

Test set: Average loss: 0.0007, Accuracy: 8721/10000 (87%)