1. K-fold cross-validation

Split the training set into K folds; one fold serves as the validation set and the remaining K-1 folds are used for training. Each of the K folds gets a turn as the validation set.
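Before the full training script, here is a minimal sketch of that rotation, assuming a hypothetical fold count K = 5 and the same MNIST dataset used in the code below; each fold takes one turn as the validation set while the other K-1 folds form the training set:

    import torch
    from torch.utils.data import DataLoader, Subset
    from torchvision import datasets, transforms

    K = 5  # hypothetical fold count; any K >= 2 works

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    full_db = datasets.MNIST('datasets/mnist_data', train=True,
                             download=True, transform=transform)

    # Shuffle the indices once, then carve them into K nearly equal folds.
    indices = torch.randperm(len(full_db))
    folds = [indices[i::K] for i in range(K)]

    for k in range(K):
        # Fold k is the validation set; the other K-1 folds are the training set.
        val_idx = folds[k].tolist()
        train_idx = torch.cat([folds[i] for i in range(K) if i != k]).tolist()
        train_loader = DataLoader(Subset(full_db, train_idx),
                                  batch_size=200, shuffle=True)
        val_loader = DataLoader(Subset(full_db, val_idx),
                                batch_size=200, shuffle=False)
        # Train a fresh model on train_loader, evaluate on val_loader,
        # then average the K validation scores.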

2. Code

import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torch.nn import functional as F
import torch.optim as optim


batch_size = 200
learning_rate = 1e-2
epochs = 10

train_db = datasets.MNIST('datasets/mnist_data',
                train=True,
                download=True,
                transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),                       # convert PIL images to tensors
                torchvision.transforms.Normalize((0.1307, ), (0.3081, )) # normalize with the MNIST mean/std
    ]))

train_loader = torch.utils.data.DataLoader(
        train_db,
        batch_size=batch_size,
        shuffle=True)

test_db = datasets.MNIST('datasets/mnist_data/',
                train=False,
                download=True,
                transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307, ), (0.3081, ))
    ]))

test_loader = torch.utils.data.DataLoader(
        test_db,
        batch_size=batch_size,
        shuffle=False)  # no need to shuffle for evaluation

print('train:', len(train_db), 'test:', len(test_db))
# Hold out 10k of the 60k training images as a validation set
train_db, val_db = torch.utils.data.random_split(train_db, [50000, 10000])
print('db1:', len(train_db), 'db2:', len(val_db))
train_loader = torch.utils.data.DataLoader(
    train_db,
    batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(
    val_db,
    batch_size=batch_size, shuffle=True)

class MLP(nn.Module):

    def __init__(self):
        super(MLP, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 10),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        x = self.model(x)
        return x

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss().to(device)

for epoch in range(epochs):

    net.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)  # flatten 28x28 images into 784-d vectors
        data, target = data.to(device), target.to(device)

        logits = net(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))

    # Evaluate on the held-out validation set after every epoch
    net.eval()
    val_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            data = data.view(-1, 28 * 28)
            data, target = data.to(device), target.to(device)
            logits = net(data)
            val_loss += criteon(logits, target).item()

            pred = logits.argmax(dim=1)
            correct += pred.eq(target).sum().item()

    val_loss /= len(val_loader.dataset)
    print('\nVAL set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        val_loss, correct, len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)))

# Final evaluation on the untouched test set
net.eval()
test_loss = 0
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.to(device)
        logits = net(data)
        test_loss += criteon(logits, target).item()

        pred = logits.argmax(dim=1)
        correct += pred.eq(target).sum().item()

test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))
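Note that the listing above performs a single hold-out split: torch.utils.data.random_split carves the 60k training images into a fixed 50k/10k train/validation pair, so only one of the K folds is ever used for validation. For a true K-fold estimate, wrap the training and validation loops in the fold loop sketched in section 1, re-initialize the network for each fold, and average the K validation accuracies before the final run on the test set.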