PyTorch模型保存与加载实例详解
在PyTorch中,模型的保存和加载是深度学习开发中的重要任务之一。本文将介绍如何使用PyTorch保存和加载模型,并演示两个示例。
保存模型
在PyTorch中,可以使用torch.save()函数将模型保存到磁盘上。torch.save()函数接受两个参数:要保存的对象和文件路径。下面是一个示例代码:
import torch
import torch.nn as nn
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(3, 2)
self.fc2 = nn.Linear(2, 1)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
net = Net()
# 保存模型
torch.save(net.state_dict(), 'model.pth')
在上述代码中,我们首先定义了一个模型net,然后使用torch.save()函数将其模型参数保存到文件'model.pth'中。
加载模型
在PyTorch中,可以使用torch.load()函数加载保存的模型。torch.load()函数接受一个参数:文件路径。下面是一个示例代码:
import torch
import torch.nn as nn
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(3, 2)
self.fc2 = nn.Linear(2, 1)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
net = Net()
# 加载模型
net.load_state_dict(torch.load('model.pth'))
在上述代码中,我们首先定义了一个模型net,然后使用torch.load()函数加载保存的模型参数,并使用net.load_state_dict()函数将其加载到模型中。
示例
示例一:保存和加载模型
import torch
import torch.nn as nn
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(3, 2)
self.fc2 = nn.Linear(2, 1)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
net = Net()
# 保存模型
torch.save(net.state_dict(), 'model.pth')
# 加载模型
net.load_state_dict(torch.load('model.pth'))
在上述代码中,我们首先定义了一个模型net,然后使用torch.save()函数将其模型参数保存到文件'model.pth'中。接着,我们使用torch.load()函数加载保存的模型参数,并使用net.load_state_dict()函数将其加载到模型中。
示例二:保存和加载整个模型
import torch
import torch.nn as nn
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(3, 2)
self.fc2 = nn.Linear(2, 1)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
net = Net()
# 保存整个模型
torch.save(net, 'model.pth')
# 加载整个模型
net = torch.load('model.pth')
在上述代码中,我们首先定义了一个模型net,然后使用torch.save()函数将整个模型保存到文件'model.pth'中。接着,我们使用torch.load()函数加载整个模型,并将其赋值给net变量。
总之,使用PyTorch保存和加载模型是深度学习开发中的重要任务之一。开发者可以根据自己的需求选择合适的方法来保存和加载模型。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch模型保存与加载实例详解 - Python技术站