# ====================LeNet-5_main.py===============
# pytorch+torchvision+visdom
  1 # -*- coding: utf-8 -*-
  2 """
  3 Created on Sun May 26 22:53:52 2019
  4 
  5 @author: jiangshan
  6 """
  7 #A modified LeNet-5 [LeCun et al., 1998a] on the MNIST dataset.
  8 import torch
  9 import torch.nn as nn
 10 import torch.optim as optim
 11 from torchvision.datasets.mnist import MNIST
 12 import torchvision.transforms as transforms
 13 from torch.utils.data import DataLoader
 14 import visdom
 15 from collections import OrderedDict
 16 
 17 class LeNet5(nn.Module):
 18     """
 19     Input - 1x32x32
 20     C1 - 6@28x28 (5x5 kernel)
 21     relu
 22     S2 - 6@14x14 (2x2 kernel, stride 2) Subsampling
 23     C3 - 16@10x10 (5x5 kernel, complicated shit)
 24     relu
 25     S4 - 16@5x5 (2x2 kernel, stride 2) Subsampling
 26     C5 - 120@1x1 (5x5 kernel)
 27     F6 - 84
 28     relu
 29     F7 - 10 (Output)
 30     """
 31     def __init__(self):
 32         super(LeNet5, self).__init__()
 33 
 34         self.convnet = nn.Sequential(OrderedDict([
 35             ('c1', nn.Conv2d(1, 6, kernel_size=(5, 5))),
 36             ('relu1', nn.ReLU()),
 37             ('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
 38             ('c3', nn.Conv2d(6, 16, kernel_size=(5, 5))),
 39             ('relu3', nn.ReLU()),
 40             ('s4', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
 41             ('c5', nn.Conv2d(16, 120, kernel_size=(5, 5))),
 42             ('relu5', nn.ReLU())
 43         ]))
 44 
 45         self.fc = nn.Sequential(OrderedDict([
 46             ('f6', nn.Linear(120, 84)),
 47             ('relu6', nn.ReLU()),
 48             ('f7', nn.Linear(84, 10)),
 49             ('sig7', nn.LogSoftmax(dim=-1))
 50         ]))
 51 
 52     def forward(self, img):
 53         output = self.convnet(img)
 54         output = output.view(img.size(0), -1)
 55         output = self.fc(output)
 56         return output
 57 
 58 
 59 viz = visdom.Visdom()
 60 data_train = MNIST('./data/mnist',
 61                    download=True,
 62                    transform=transforms.Compose([
 63                        transforms.Resize((32, 32)),
 64                        transforms.ToTensor()]))
 65 data_test = MNIST('./data/mnist',
 66                   train=False,
 67                   download=True,
 68                   transform=transforms.Compose([
 69                       transforms.Resize((32, 32)),
 70                       transforms.ToTensor()]))
 71 data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
 72 data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8)
 73 
 74 net = LeNet5()
 75 criterion = nn.CrossEntropyLoss()
 76 optimizer = optim.Adam(net.parameters(), lr=2e-3)
 77 
 78 cur_batch_win = None
 79 cur_batch_win_opts = {
 80     'title': 'Epoch Loss Trace',
 81     'xlabel': 'Batch Number',
 82     'ylabel': 'Loss',
 83     'width': 1200,
 84     'height': 600,
 85 }
 86 
 87 
 88 def train(epoch):
 89     global cur_batch_win
 90     net.train()
 91     loss_list, batch_list = [], []
 92     for i, (images, labels) in enumerate(data_train_loader):
 93         optimizer.zero_grad()
 94 
 95         output = net(images)
 96 
 97         loss = criterion(output, labels)
 98 
 99         loss_list.append(loss.detach().cpu().item())
100         batch_list.append(i+1)
101 
102         if i % 10 == 0:
103             print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.detach().cpu().item()))
104 
105         # Update Visualization
106         if viz.check_connection():
107             cur_batch_win = viz.line(torch.Tensor(loss_list), torch.Tensor(batch_list),
108                                      win=cur_batch_win, name='current_batch_loss',
109                                      update=(None if cur_batch_win is None else 'replace'),
110                                      opts=cur_batch_win_opts)
111         loss.backward()
112         optimizer.step()
113 
114 
def test():
    """Evaluate ``net`` on ``data_test_loader`` and print loss and accuracy.

    Prints the per-sample average loss and the fraction of correctly
    classified samples over ``data_test``. Returns nothing.

    The module-level ``criterion`` is ``nn.CrossEntropyLoss()`` with its
    default reduction, so it yields the MEAN loss of each batch. Each batch
    mean is therefore weighted by the batch size before the total is divided
    by the dataset size; summing the raw batch means and dividing by the
    sample count would understate the loss by roughly the batch size.
    """
    net.eval()
    total_correct = 0
    total_loss = 0.0
    # No gradients are needed for evaluation; skipping autograd avoids
    # building and retaining a graph for every test batch.
    with torch.no_grad():
        for images, labels in data_test_loader:
            output = net(images)
            batch_size = labels.size(0)
            # Convert the batch-mean loss back into a batch-sum contribution.
            total_loss += criterion(output, labels).item() * batch_size
            # Index of the highest score per sample is the predicted class.
            pred = output.max(1)[1]
            total_correct += pred.eq(labels.view_as(pred)).sum().item()

    avg_loss = total_loss / len(data_test)
    print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss, float(total_correct) / len(data_test)))
127 
128 
def train_and_test(epoch):
    """Train for one epoch, then evaluate on the test set.

    ``epoch`` is passed through to ``train``; ``test`` takes no arguments.
    """
    train(epoch)
    test()
132 
133 
def main(epochs=15):
    """Run the train/evaluate cycle for ``epochs`` epochs.

    Epochs are numbered from 1 so that progress output starts at
    "Epoch 1". The default of 15 matches the previously hard-coded count.
    """
    for e in range(1, epochs + 1):
        train_and_test(e)


if __name__ == '__main__':
    main()

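# Start the visdom server first so the loss curve can be visualized:
#
#     python -m visdom.server
#
# Run the program:
#
#     python LeNet-5_main.py
#
# Open a browser to view the live graph:
#
#     http://localhost:8097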