1 import torch
  2 from torch import optim,nn
  3 import visdom
  4 import torchvision
  5 from torch.utils.data import DataLoader
  6 
  7 from pokemon import Pokemon
  8 
  9 # from resnet import ResNet18
# torchvision's resnet18 can directly load ready-made pretrained weights
 11 from torchvision.models import resnet18
 12 
 13 from utils import Flatten
 14 
 15 batchsz = 32
 16 lr = 1e-3
 17 epochs = 10
 18 
 19 device = torch.device('cuda')
 20 # 设置随机种子保证能够复现出来
 21 torch.manual_seed(1234)
 22 
 23 train_db = Pokemon('pokemon',224,mode = 'train')
 24 val_db = Pokemon('pokemon',224,mode = 'val')
 25 test_db = Pokemon('pokemon',224,mode = 'test')
 26 
 27 train_loader = DataLoader(train_db,batch_size = batchsz,shuffle = True,num_workers = 4)
 28 val_loader = DataLoader(val_db,batch_size = batchsz,num_workers = 2)
 29 test_loader = DataLoader(test_db,batch_size = batchsz,num_workers = 2)
 30 
 31 # visdom
 32 viz = visdom.Visdom()
 33 
 34 def evalute(model,loader):
 35     
 36     correct = 0
 37     total = len(loader.dataset)
 38     
 39     for x,y in loader:
 40         x,y = x.to(device),y.to(device)
 41         with torch.no_grad():
 42             logits = model(x)
 43             pred = logits.argmax(dim = 1)
 44             correct += torch.eq(pred,y).sum().float().item()
 45         
 46     return correct / total
 47 
 48 def main():
 49     
 50     # model = ResNet18(5).to(device)
 51     trained_model = resnet18(pretrained = True)
 52     # 取出前17层,加*打散数据
 53     model = nn.Sequential(*list(train_model.children())[:-1], # [b,512,1,1]
 54                           Flatten(), # [b,512,1,1] --> [b,512]
 55                           nn.Linear(512,5)
 56                           ).to(device)
 57     
 58     optimizer = optim.Adam(model.parameters().lr = lr)
 59     criteon = nn.CrossEntropyLoss
 60     
 61     best_acc,best_epoch = 0,0
 62     global_step = 0
 63     # visdom
 64     viz.line([0],[-1],win = 'loss',opts = dict(title = 'loss'))
 65     viz.line([0],[-1],win = 'val_acc',opts = dict(title = 'val_acc'))
 66     
 67     for epoch in range(epochs):
 68         
 69         for step,(x,y) in enumerate(train_loader):
 70             
 71             # x: [b,3,224,224] ,y : [b]
 72             x,y = x.to(device),y.to(device)
 73             
 74             # logits是没经过loss的
 75             logits = model(x)
 76             # CrossEntropyLoss会在内部进行onehot,所以不需要自己写
 77             loss = criteon(logits,y).item()
 78             
 79             optimizer.zero_grad()
 80             loss.backward()
 81             optimizer.step()
 82             
 83             # visdom
 84             viz.line([loss.item()],[global_step],win = 'loss',update = 'append')
 85             global_step += 1
 86             
 87         if epoch % 2 == 0:
 88         
 89             val_acc = evalute(model,val_loader)
 90             
 91             if val_acc > best_acc:
 92                 best_epoch = epoch
 93                 best_acc = val_acc
 94                 
 95                 torch.save(model.state_dict(),'best.mdl')
 96                 # visdom
 97                 viz.line([val_acc],[global_step],win = 'val_acc',update = 'append')
 98                 
 99     print('best_acc:',best_acc,'best_epoch',best_epoch)
100     
101     model.load_state_dict(torch.load('best.mdl'))
102     print('loaded from skpt!')
103     
104     test_acc = evalute(model,test_loader)
105     print('test_acc',test_acc)
106 
107 
# Only train when run as a script; importing this module (e.g. from a
# DataLoader worker process) should not start training.
if __name__ == '__main__':
    main()