model.state_dict()中保存了{参数名:参数值}的字典
import torchvision.models as models
resnet34 = models.resnet34(pretrained=True)
resnet34.state_dict().keys()
for param in resnet34.parameters():
param.requires_grad = False
resnet.fc = nn.Linear(resnet.fc.in_features, 100)
# resnet.fc = nn.Sequential(nn.Linear(512, 100),
# nn.ReLU(),
# nn.Linear(100, 10))
保存模型torch.save(model.state_dict(), PATH) # 保存模型为pth
导入模型
model = ModelClass() # 需要先建立模型
model.load_state_dict(torch.load(PATH)) # 加载模型
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch保存模型和导入模型以及预训练模型 - Python技术站