state_dict()函数可以返回所有的状态数据。load_state_dict()函数可以加载这些状态数据。

推荐使用:

#保存
t.save(net.state_dict(),"net.pth")
#加载
net2=Net()
net2.load_state_dict(t.load("net.pth"))

不推荐直接save与load,因为这种方式严重依赖模型定义方法以及文件路径结构等,容易出问题。

t.save(net,"net.pth")
net2=t.load("net.pth")

 【PyTorch中已封装的网络模型】https://pytorch.org/docs/stable/torchvision/index.html

PyTorch保存、加载模型,PyTorch中已封装的网络模型

 从上图看出,有针对分类问题、语义分割、目标识别、视频分类的模型。

以分类模型为例,PyTorch中已封装的模型如下:

PyTorch保存、加载模型,PyTorch中已封装的网络模型

 使用方式,参考标黄部分

######################################## 1、使用torchvision加载并预处理数据集

#### datasets的ImageFolder读图
from torchvision.datasets import ImageFolder
dataset=ImageFolder("E:/data/dogcat_2/train/") #获取路径,返回的是所有图的data、label
from torchvision import transforms as T #设置格式化条件
transform=T.Compose([T.Resize((64,64)), 
                     T.ToTensor(), #PIL Image转Tensor,[0,255]自动归一化为[0,1]
                     T.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]) #标准化,减均值除标准差
                    ])
dataset=ImageFolder("E:/data/dogcat_2/train/",transform=transform)
testset=ImageFolder("E:/data/dogcat_2/test/",transform=transform)

#### DataLoader
from torch.utils.data import DataLoader
dataloader=DataLoader( dataset,batch_size=4,shuffle=True,num_workers=2 )
testloader=DataLoader(testset,batch_size=4,shuffle=True,num_workers=2)

#### 显示第1个batch的4幅图(随机)
from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid
dataiter = iter(dataloader)
(images, labels) = dataiter.next()
print(labels) #打印标签
show=ToPILImage() 
show(make_grid(images*0.5+0.5)).resize((4*64,64)) 

######################################## 2、定义网络
from torchvision import models
net=models.alexnet()

######################################## 3、定义损失函数和优化器
import torch.nn as nn
from torch import optim
criterion=nn.CrossEntropyLoss() #交叉熵损失函数
optimizer=optim.SGD(net.parameters(),lr=0.001,momentum=0.9) #随机梯度下降法,指定要调整的参数和学习率,动量算法加速更新权重

######################################## 4、训练网络并更新网络参数
for epoch in range(2):  # 在整个数据集上轮番训练多次,轮训一次叫一个回合(epoch)

    running_loss = 0.0
    for i, data in enumerate(dataloader, 0):
        
        # 输入数据
        inputs, labels = data
        
        # 梯度清零
        optimizer.zero_grad()

        # forward + backward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        
        #更新参数
        optimizer.step()

        # 打印一些关于训练的统计信息
        running_loss += loss.item()
        if i % 200 == 199:    # 每 200 个batch打印一次
            print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 200))
            running_loss = 0.0

print('Finished Training')

######################################## 5、测试网络
import torchvision as tv
import torch as t
#datasets测试集中前4幅图,并输出标签
dataiter = iter(testloader)
(images, labels) = dataiter.next() #返回1个batch(4张图)

# 输出图像和正确的类标签
#print('实际的label:', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
show(tv.utils.make_grid((images+1)/2)).resize((400,100))

#测试
outputs = net(images) #预测上边得到的batch(4张图),返回得分(每一类都打分)
_, predicted = t.max(outputs, 1) #每1张图得分最高的那个类的下标

print(outputs)
print(predicted)
#print('预测结果:', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
show(tv.utils.make_grid((images+1)/2)).resize((400,100))

#测试整个测试集
correct = 0 #预测正确的图片数
total = 0 #总共的图片数
with t.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = t.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('10000张测试集中的准确率: %d %%' % (100 * correct / total))