PyTorch零基础入门之逻辑斯蒂回归
本文将介绍如何使用PyTorch实现逻辑斯蒂回归模型。逻辑斯蒂回归是一种二元分类模型,它可以用于预测一个样本属于两个类别中的哪一个。
1. 数据集
我们将使用Iris数据集进行逻辑斯蒂回归模型的训练和测试。该数据集包含150个样本,每个样本包含4个特征和1个标签。我们将使用前100个样本作为训练集,后50个样本作为测试集。
import pandas as pd
import numpy as np
# 加载数据集
url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data'
names = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'class']
dataset = pd.read_csv(url, names=names)
# 将标签转换为0和1
dataset['class'] = np.where(dataset['class'] == 'Iris-setosa', 0, 1)
# 将数据集分为训练集和测试集
train_dataset = dataset[:100]
test_dataset = dataset[100:]
2. 模型搭建
我们将使用PyTorch搭建一个简单的逻辑斯蒂回归模型。该模型包含一个线性层和一个sigmoid激活函数。
import torch.nn as nn
class LogisticRegression(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(4, 1)
def forward(self, x):
x = self.linear(x)
x = nn.functional.sigmoid(x)
return x
model = LogisticRegression()
3. 模型训练
我们将使用二元交叉熵损失函数和随机梯度下降(SGD)优化器进行模型训练。我们将模型训练100个epoch,并在每个epoch结束时计算训练集和测试集的准确率。
import torch.optim as optim
# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 训练模型
for epoch in range(100):
# 在训练集上训练模型
train_correct = 0
train_total = 0
for index, row in train_dataset.iterrows():
optimizer.zero_grad()
inputs = torch.tensor(row[:4].values, dtype=torch.float32)
label = torch.tensor(row[4], dtype=torch.float32)
outputs = model(inputs)
loss = criterion(outputs, label.unsqueeze(0))
loss.backward()
optimizer.step()
predicted = torch.round(outputs)
train_total += 1
train_correct += (predicted == label).sum().item()
# 在测试集上测试模型
test_correct = 0
test_total = 0
with torch.no_grad():
for index, row in test_dataset.iterrows():
inputs = torch.tensor(row[:4].values, dtype=torch.float32)
label = torch.tensor(row[4], dtype=torch.float32)
outputs = model(inputs)
predicted = torch.round(outputs)
test_total += 1
test_correct += (predicted == label).sum().item()
# 打印准确率
print(f'Epoch {epoch+1}, Train Accuracy: {train_correct/train_total}, Test Accuracy: {test_correct/test_total}')
4. 模型预测
我们使用训练好的模型来预测测试集中的样本,并将预测结果与实际结果进行比较。
# 预测测试集中的样本
with torch.no_grad():
inputs = torch.tensor(test_dataset.iloc[0,:4].values, dtype=torch.float32)
label = torch.tensor(test_dataset.iloc[0,4], dtype=torch.float32)
outputs = model(inputs)
predicted = torch.round(outputs).item()
# 打印预测结果和实际结果
print(f'Actual: {label}, Predicted: {predicted}')
示例2:使用GPU加速的逻辑斯蒂回归模型
如果你的机器上有GPU,你可以使用PyTorch的GPU加速功能来加速模型训练和预测。以下是使用GPU加速的逻辑斯蒂回归模型的示例代码。
import torch.nn as nn
class LogisticRegression(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(4, 1)
def forward(self, x):
x = self.linear(x)
x = nn.functional.sigmoid(x)
return x
model = LogisticRegression().cuda()
# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 训练模型
for epoch in range(100):
# 在训练集上训练模型
train_correct = 0
train_total = 0
for index, row in train_dataset.iterrows():
optimizer.zero_grad()
inputs = torch.tensor(row[:4].values, dtype=torch.float32).cuda()
label = torch.tensor(row[4], dtype=torch.float32).cuda()
outputs = model(inputs)
loss = criterion(outputs, label.unsqueeze(0))
loss.backward()
optimizer.step()
predicted = torch.round(outputs)
train_total += 1
train_correct += (predicted == label).sum().item()
# 在测试集上测试模型
test_correct = 0
test_total = 0
with torch.no_grad():
for index, row in test_dataset.iterrows():
inputs = torch.tensor(row[:4].values, dtype=torch.float32).cuda()
label = torch.tensor(row[4], dtype=torch.float32).cuda()
outputs = model(inputs)
predicted = torch.round(outputs)
test_total += 1
test_correct += (predicted == label).sum().item()
# 打印准确率
print(f'Epoch {epoch+1}, Train Accuracy: {train_correct/train_total}, Test Accuracy: {test_correct/test_total}')
# 预测测试集中的样本
with torch.no_grad():
inputs = torch.tensor(test_dataset.iloc[0,:4].values, dtype=torch.float32).cuda()
label = torch.tensor(test_dataset.iloc[0,4], dtype=torch.float32).cuda()
outputs = model(inputs)
predicted = torch.round(outputs).item()
# 打印预测结果和实际结果
print(f'Actual: {label}, Predicted: {predicted}')
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch零基础入门之逻辑斯蒂回归 - Python技术站