〇、基本流程

加载数据->搭建模型->训练->测试

 

一、加载数据

通过使用torch.utils.data.DataLoader和torchvision.datasets两个模块可以很方便地去获取常用数据集(手写数字MNIST、分类CIFAR),以及将其加载进来。

1.加载内置数据集


 import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

 1 train_loader = torch.utils.data.DataLoader(
 2     torchvision.datasets.MNIST('mnist_data', train=True, download=True,
 3                                transform=torchvision.transforms.Compose([
 4                                    torchvision.transforms.ToTensor(),
 5                                    torchvision.transforms.Normalize(
 6                                        (0.1307,), (0.3081,))
 7                                ])),
 8     batch_size=batch_size, shuffle=True)
 9 # train 是否为训练集,         download 数据集不存在时是否下载数据集
10 # ToTensor() 转换成tensor格式,Normalize() 归一化,将数据作(data-mean)/std
11 # batch_size 加载一批数量,    shuffle 是否打散数据

2.加载自定义数据集

用torchvision.datasets.ImageFolder加载图片数据集

 

二、搭建模型

一个模型可以表示为python的一个类,这个类要继承torch.nn.modules.Module,并且实现forward( )方法

 1 class Lenet5(nn.Module):
 2     """
 3     for CIFAR10
 4     """
 5     def __init__(self):
 6         super(Lenet5, self).__init__()
 7 
 8         # 两层卷积
 9         self.conv_unit = nn.Sequential(
10             
11             # 3表示input,可以理解为图片的通道数量,即我的卷积核一次要到几个tensor去作卷积
12             # 6表示有多少个卷积核
13             # stride表示卷积核移动步长,padding表示边缘扩充
14             nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),    # 卷积
15             nn.AvgPool2d(kernel_size=2, stride=2, padding=0),      # 池化
16 
17             nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
18             nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
19         )
20 
21         # 3层全连接层
22         self.fc_unit = nn.Sequential(
23             nn.Linear(16*5*5, 120),
24             nn.ReLU(),
25             nn.Linear(120, 84),
26             nn.ReLU(),
27             nn.Linear(84, 10)
28         )
31 
32 
33     def forward(self, x):                # 数据从此进来,经过定义好的各层网络,最终输出
34         batchsz = x.size(0)
35         x = self.conv_unit(x)
36         x = x.view(batchsz, 16*5*5)          # 经过卷积层后,对数据维度作处理,以适应全连接层
37         logits = self.fc_unit(x)
38         return logits

 

三、训练

训练过程可以认为是对参数优化的过程,通过输入数据,得到输出,计算损失(误差),再经过误差反向传播得到梯度信息,以更新参数。

 1     # 实例模型、配置损失函数、优化器
 2     device = torch.device('cuda')                       # 转为GPU上执行
 3     model = Lenet5().to(device)                         # 实例化模型
 4     criteon = nn.CrossEntropyLoss().to(device)          # 损失函数
 5     optimizer = optim.Adam(model.parameters(), lr=1e-3) # 优化器
 6     print(model)
 7 
 8     # 训练
 9     for epoch in range(1000):                   # 迭代1000次
10         model.train()                           # 模型切换为训练模式
11         for batchidx, (x, label) in enumerate(cifar_train):
12             x, label = x.to(device), label.to(device)
13             logist = model(x)                   # 得到模型的输出
14             loss = criteon(logist, label)       # 计算损失
15             optimizer.zero_grad()               # 旧梯度清零
16             loss.backward()                     # 误差反向传播
17             optimizer.step()                    # 梯度更新
18 
19         print(epoch, loss.item())

 

四、测试

当模型训练完毕后,进行数据测试。

 1      model.eval()                 # 切换为验证模式
 2         with torch.no_grad():            # 不进行梯度更新
 3             total_correct = 0                   # 记录正确的数据量
 4             total_num = 0                       # 记录总数据量
 5             for x, label in cifar_test:
 6                 x, label = x.to(device), label.to(device)
 7                 logist = model(x)               # 获得模型输出
 8                 pred = logist.argmax(dim=1)     # 取值最大的下标,在这里恰好对应图片标签
 9                 
10                 # eq(pred, label)表示比较预测值和实际标签
11                 total_correct += torch.eq(pred, label).float().sum().item()
12                 total_num += x.size(0)
13             acc = total_correct / total_num     # 计算正确率

 

五、其他

1、保存与加载模型

即当模型训练好之后,将模型保存,下一次可以直接使用。

1 torch.save(model.state_dict(), 'best.mdl')      # 保存模型
2 
3 model.load_state_dict(torch.load('best.mdl'))   # 加载模型