state_dict()函数可以返回所有的状态数据。load_state_dict()函数可以加载这些状态数据。
推荐使用:
#保存 t.save(net.state_dict(),"net.pth") #加载 net2=Net() net2.load_state_dict(t.load("net.pth"))
不推荐直接save与load,因为这种方式严重依赖模型定义方法以及文件路径结构等,容易出问题。
t.save(net,"net.pth") net2=t.load("net.pth")
【PyTorch中已封装的网络模型】https://pytorch.org/docs/stable/torchvision/index.html
从上图看出,有针对分类问题、语义分割、目标识别、视频分类的模型。
以分类模型为例,PyTorch中已封装的模型如下:
使用方式,参考标黄部分
######################################## 1、使用torchvision加载并预处理数据集 #### datasets的ImageFolder读图 from torchvision.datasets import ImageFolder dataset=ImageFolder("E:/data/dogcat_2/train/") #获取路径,返回的是所有图的data、label from torchvision import transforms as T #设置格式化条件 transform=T.Compose([T.Resize((64,64)), T.ToTensor(), #PIL Image转Tensor,[0,255]自动归一化为[0,1] T.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]) #标准化,减均值除标准差 ]) dataset=ImageFolder("E:/data/dogcat_2/train/",transform=transform) testset=ImageFolder("E:/data/dogcat_2/test/",transform=transform) #### DataLoader from torch.utils.data import DataLoader dataloader=DataLoader( dataset,batch_size=4,shuffle=True,num_workers=2 ) testloader=DataLoader(testset,batch_size=4,shuffle=True,num_workers=2) #### 显示第1个batch的4幅图(随机) from torchvision.transforms import ToPILImage from torchvision.utils import make_grid dataiter = iter(dataloader) (images, labels) = dataiter.next() print(labels) #打印标签 show=ToPILImage() show(make_grid(images*0.5+0.5)).resize((4*64,64)) ######################################## 2、定义网络 from torchvision import models net=models.alexnet() ######################################## 3、定义损失函数和优化器 import torch.nn as nn from torch import optim criterion=nn.CrossEntropyLoss() #交叉熵损失函数 optimizer=optim.SGD(net.parameters(),lr=0.001,momentum=0.9) #随机梯度下降法,指定要调整的参数和学习率,动量算法加速更新权重 ######################################## 4、训练网络并更新网络参数 for epoch in range(2): # 在整个数据集上轮番训练多次,轮训一次叫一个回合(epoch) running_loss = 0.0 for i, data in enumerate(dataloader, 0): # 输入数据 inputs, labels = data # 梯度清零 optimizer.zero_grad() # forward + backward outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() #更新参数 optimizer.step() # 打印一些关于训练的统计信息 running_loss += loss.item() if i % 200 == 199: # 每 200 个batch打印一次 print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 200)) running_loss = 0.0 print('Finished Training') ######################################## 5、测试网络 import torchvision as tv import torch as t #datasets测试集中前4幅图,并输出标签 dataiter = iter(testloader) (images, labels) = dataiter.next() #返回1个batch(4张图) # 输出图像和正确的类标签 #print('实际的label:', ' '.join('%5s' % classes[labels[j]] for j in range(4))) show(tv.utils.make_grid((images+1)/2)).resize((400,100)) #测试 outputs = net(images) #预测上边得到的batch(4张图),返回得分(每一类都打分) _, predicted = t.max(outputs, 1) #每1张图得分最高的那个类的下标 print(outputs) print(predicted) #print('预测结果:', ' '.join('%5s' % classes[predicted[j]] for j in range(4))) show(tv.utils.make_grid((images+1)/2)).resize((400,100)) #测试整个测试集 correct = 0 #预测正确的图片数 total = 0 #总共的图片数 with t.no_grad(): for data in testloader: images, labels = data outputs = net(images) _, predicted = t.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print('10000张测试集中的准确率: %d %%' % (100 * correct / total))
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch保存、加载模型,PyTorch中已封装的网络模型 - Python技术站