"""Transfer learning on the Pokemon dataset with a pretrained ResNet-18.

The pretrained ResNet-18 backbone (minus its final fc layer) is combined with a
new 5-way linear head and the whole network is fine-tuned with Adam. Validation
accuracy is measured every other epoch; the best weights are saved to
``best.mdl``, reloaded after training, and evaluated on the test split.
Training loss and validation accuracy are plotted live in visdom.
"""
import torch
from torch import optim, nn
import visdom
import torchvision
from torch.utils.data import DataLoader

from pokemon import Pokemon

# from resnet import ResNet18  # hand-written alternative to the pretrained model
# torchvision's resnet18 can directly load already-trained (pretrained) weights
from torchvision.models import resnet18

from utils import Flatten

batchsz = 32
lr = 1e-3
epochs = 10

# Fall back to the CPU when no CUDA device is present instead of failing on
# the first .to(device) call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Fix the random seed so results can be reproduced.
torch.manual_seed(1234)

train_db = Pokemon('pokemon', 224, mode='train')
val_db = Pokemon('pokemon', 224, mode='val')
test_db = Pokemon('pokemon', 224, mode='test')

train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)

# visdom
viz = visdom.Visdom()


def evalute(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    The model is put in eval mode for the pass (so layers such as BatchNorm use
    their stored statistics instead of the current batch's) and then restored to
    the mode it was in before the call. Gradients are not tracked.

    Args:
        model: classifier whose output has one logit per class along dim 1.
        loader: DataLoader yielding ``(x, y)`` batches with integer labels ``y``.

    Returns:
        Accuracy as a float in [0, 1]; 0.0 if the loader's dataset is empty.
    """
    total = len(loader.dataset)
    if total == 0:
        # Avoid ZeroDivisionError on an empty split.
        return 0.0

    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                pred = logits.argmax(dim=1)
                correct += torch.eq(pred, y).sum().float().item()
    finally:
        # Restore the caller's mode even if evaluation raises.
        model.train(was_training)

    return correct / total


def main():
    """Fine-tune the pretrained ResNet-18, keep the best checkpoint, and report test accuracy."""
    # model = ResNet18(5).to(device)
    # NOTE(review): recent torchvision releases deprecate ``pretrained=`` in
    # favour of ``weights=``; kept for compatibility with older versions.
    trained_model = resnet18(pretrained=True)
    # Keep every child module except the final fc layer (``*`` unpacks the list
    # into nn.Sequential) and attach a new linear head.
    # NOTE(review): the head has 5 outputs, presumably one per Pokemon class.
    model = nn.Sequential(*list(trained_model.children())[:-1],  # [b,512,1,1]
                          Flatten(),  # [b,512,1,1] --> [b,512]
                          nn.Linear(512, 5),
                          ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Instantiate the loss module; the bare class would not compute a loss.
    criteon = nn.CrossEntropyLoss()

    # Start below any reachable accuracy so the first evaluation always saves a
    # checkpoint and best.mdl exists for the final reload.
    best_acc, best_epoch = -1.0, 0
    global_step = 0
    # visdom: create the plot windows with a placeholder first point
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        # Make sure training-mode behaviour (e.g. BatchNorm updates) is active.
        model.train()

        for x, y in train_loader:
            # x: [b,3,224,224], y: [b]
            x, y = x.to(device), y.to(device)

            # logits are the raw class scores, before any softmax/loss
            logits = model(x)
            # CrossEntropyLoss takes integer class labels directly, so no manual
            # one-hot encoding is needed. Keep the result as a tensor (no .item())
            # so backward() can propagate through it.
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # visdom
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        if epoch % 2 == 0:
            val_acc = evalute(model, val_loader)

            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc

                torch.save(model.state_dict(), 'best.mdl')
            # visdom
            viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best_acc:', best_acc, 'best_epoch', best_epoch)

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load('best.mdl', map_location=device))
    print('loaded from skpt!')

    test_acc = evalute(model, test_loader)
    print('test_acc', test_acc)


if __name__ == '__main__':
    main()
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:龙良曲pytorch学习笔记_迁移学习 - Python技术站