当进行PyTorch实验时,我们经常需要使用一些常用的代码段来完成模型训练、数据处理、可视化等任务。本文将详细讲解PyTorch实验常用代码段汇总,并提供两个示例说明。
1. 模型训练
在PyTorch中,我们可以使用torch.optim模块中的优化器和nn模块中的损失函数来训练模型。以下是模型训练的示例代码:
import torch
import torch.nn as nn
import torch.optim as optim
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 实例化模型、损失函数和优化器
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 训练模型
for epoch in range(10):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print('Epoch %d loss: %.3f' % (epoch + 1, running_loss / len(trainloader)))
在上面的代码中,我们首先定义了一个包含两个全连接层的模型Net。然后,我们实例化了该模型、损失函数和优化器。接下来,我们使用for循环训练模型,其中每个epoch包含多个batch。在每个batch中,我们首先将优化器的梯度清零,然后计算模型的输出和损失,并使用反向传播更新模型参数。最后,我们输出每个epoch的平均损失。
2. 数据处理
在PyTorch中,我们可以使用torch.utils.data模块中的Dataset和DataLoader来处理数据。以下是数据处理的示例代码:
import torch
import torchvision
import torchvision.transforms as transforms
# 定义数据增强和标准化
transform = transforms.Compose(
[transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 加载CIFAR10数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
# 输出数据集大小
print('Trainset size:', len(trainset))
# 输出数据集类别
classes = trainset.classes
print('Classes:', classes)
在上面的代码中,我们首先定义了数据增强和标准化的方法,并使用transforms.Compose()方法将它们组合起来。然后,我们使用torchvision.datasets模块中的CIFAR10()方法加载CIFAR10数据集,并使用torch.utils.data模块中的DataLoader()方法将数据集转换为可迭代的数据加载器。接下来,我们输出了数据集的大小和类别。
3. 示例3:模型保存和加载
在PyTorch中,我们可以使用torch.save()方法将模型保存到文件中,并使用torch.load()方法从文件中加载模型。以下是模型保存和加载的示例代码:
import torch
import torch.nn as nn
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 实例化模型
net = Net()
# 保存模型
torch.save(net.state_dict(), 'model.pth')
# 加载模型
net.load_state_dict(torch.load('model.pth'))
在上面的代码中,我们首先定义了一个包含两个全连接层的模型Net,并实例化了该模型。然后,我们使用torch.save()方法将模型的参数保存到文件model.pth中。接下来,我们使用torch.load()方法从文件中加载模型的参数,并使用net.load_state_dict()方法将参数加载到模型中。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch实验常用代码段汇总 - Python技术站