Python机器学习pytorch交叉熵损失函数的深刻理解

Python机器学习pytorch交叉熵损失函数的深刻理解

交叉熵损失函数是一种常用的损失函数,它在分类问题中非常有效。在PyTorch中,我们可以使用nn.CrossEntropyLoss()函数来计算交叉熵损失。本文将提供一个完整的攻略,介绍如何使用Python和PyTorch实现交叉熵损失函数,并提供两个示例,分别是使用交叉熵损失函数进行多分类和使用交叉熵损失函数进行图像分类。

交叉熵损失函数的现

交叉熵损失函数是一种常用的损失函数,它在分类问题中非常有效。在PyTorch中,我们可以使用nn.CrossEntropyLoss()函数来计算交叉熵损失。交叉熵损失函数的公式如下:

$$
\text{loss} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{C}y_{ij}\log(p_{ij})
$$

其中,$N$是样本数量,$C$是类别数量,$y_{ij}$是第$i$个样本的第$j$个类别的真实标签,$p_{ij}$是第$i$个样本的第$j$个类别的预测概率。

示例1:使用交叉熵损失函数进行多分类

以下是一个示例,展示如何使用交叉熵损失函数进行多分类。

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train).long()

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(4, 10)
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1)

for epoch in range(1000):
    optimizer.zero_grad()
    outputs = net(X_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()

X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test).long()

with torch.no_grad():
    outputs = net(X_test)
    _, predicted = torch.max(outputs, 1)
    accuracy = (predicted == y_test).sum().item() / len(y_test)

print(f'Accuracy: {accuracy}')

在这个示例中,我们使用鸢尾花数据集进行多分类。我们首先加载数据集,然后将数据集分为训练集和测试集。接下来,我们定义一个神经网络模型,并定义交叉熵损失函数和优化器。在训练过程中,我们首先将梯度清零,然后计算输出和损失函数的值。接下来,我们计算梯度并更新权重。最后,我们使用测试集评估模型的准确性。

示例2:使用交叉熵损失函数进行图像分类

以下是一个示例,展示如何使用交叉熵损失函数进行图像分类。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

train_dataset = MNIST(root='data', train=True, transform=ToTensor(), download=True)
test_dataset = MNIST(root='data', train=False, transform=ToTensor(), download=True)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = self.pool(x)
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 64 * 7 * 7)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1)

for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')

correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {correct / total}')

在这个示例中,我们使用MNIST数据集进行图像分类。我们首先加载数据集,并将数据集分为训练集和测试集。接下来,我们定义一个卷积神经网络模型,并定义交叉熵损失函数和优化器。在训练过程中,我们使用数据加载器来加载数据,并在每个epoch中计算损失函数的值。最后,我们使用测试集评估模型的准确性。

总结

本文提供了一个完整的攻略,介绍了如何使用Python和PyTorch实现交叉熵损失函数,并提供了两个示例,分别是使用交叉熵损失函数进行多分类和使用交叉熵损失函数进行图像分类。在实现过程中,我们使用了PyTorch和scikit-learn等库,并介绍了一些常用的函数和技术。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python机器学习pytorch交叉熵损失函数的深刻理解 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • PyTorch数据处理,datasets、DataLoader及其工具的使用

    torchvision是PyTorch的一个视觉工具包,提供了很多图像处理的工具。 datasets使用ImageFolder工具(默认PIL Image图像),获取定制化的图片并自动生成类别标签。如裁剪、旋转、标准化、归一化等(使用transforms工具)。 DataLoader可以把datasets数据集打乱,分成batch,并行加速等。 一、data…

    2023年4月8日
    00
  • Pytorch 随机数种子设置

    一般而言,可以按照如下方式固定随机数种子,以便复现实验: # 来自相关于 GCN 代码: 例如 grand.py 等的代码 parser.add_argument(‘–seed’, type=int, default=42, help=’Random seed.’) np.random.seed(args.seed) torch.manual_seed(a…

    PyTorch 2023年4月6日
    00
  • pytorch 7 optimizer 优化器 加速训练

    import torch import torch.utils.data as Data import torch.nn.functional as F import matplotlib.pyplot as plt # torch.manual_seed(1) # reproducible 超参数设置 LR = 0.01 BATCH_SIZE = 32 E…

    2023年4月8日
    00
  • Python笔记之a = [0]*x格式的含义及说明

    在Python中,a = [0]*x是一种常见的列表初始化方式,其中x是一个整数。这种方式会创建一个长度为x的列表,其中每个元素都是0。下面是一个示例: a = [0]*5 print(a) # 输出 [0, 0, 0, 0, 0] 在这个示例中,我们创建了一个长度为5的列表a,其中每个元素都是0。 这种方式的好处是可以快速创建一个指定长度的列表,并且所有元…

    PyTorch 2023年5月15日
    00
  • pytorch中permute()函数用法补充说明(矩阵维度变化过程)

    PyTorch中permute()函数用法补充说明 在PyTorch中,permute()函数用于对张量的维度进行重新排列。本文将详细介绍permute()函数的用法,并提供两个示例说明。 permute()函数的用法 permute()函数的语法如下: torch.Tensor.permute(*dims) 其中,*dims表示一个可变参数,用于指定新的维…

    PyTorch 2023年5月15日
    00
  • Pytorch训练模型得到输出后计算F1-Score 和AUC的操作

    以下是“PyTorch训练模型得到输出后计算F1-Score和AUC的操作”的完整攻略,包含两个示例说明。 示例1:计算F1-Score 步骤1:准备数据 首先,我们需要准备一些数据来计算F1-Score。假设我们有一个二分类问题,我们有一些真实标签和一些预测标签。我们可以使用sklearn库中的precision_recall_fscore_support…

    PyTorch 2023年5月15日
    00
  • PyTorch在Windows环境搭建的方法步骤

    PyTorch在Windows环境搭建的方法步骤 在本文中,我们将介绍如何在Windows环境下搭建PyTorch。我们将提供两个示例,一个是使用Anaconda安装PyTorch,另一个是使用pip安装PyTorch。 示例1:使用Anaconda安装PyTorch 以下是使用Anaconda安装PyTorch的步骤: 下载并安装Anaconda。可以从A…

    PyTorch 2023年5月16日
    00
  • pytorch中tensor张量的创建

    import torch import numpy as np print(torch.tensor([1,2,3])) print(torch.tensor(np.arange(15).reshape(3,5))) print(torch.empty([3,4])) print(torch.ones([3,4])) print(torch.zeros([3…

    PyTorch 2023年4月8日
    00
合作推广
合作推广
分享本页
返回顶部