对Pytorch神经网络初始化kaiming分布详解

对PyTorch神经网络初始化Kaiming分布详解

在深度学习中,神经网络的初始化非常重要,它可以影响模型的收敛速度和性能。Kaiming初始化是一种常用的初始化方法,它可以有效地解决梯度消失和梯度爆炸的问题。本文将详细介绍PyTorch中Kaiming初始化的实现方法,并提供两个示例说明。

1. Kaiming初始化方法

Kaiming初始化方法是一种针对ReLU激活函数的初始化方法,它可以使得神经网络的输出具有相同的方差。具体来说,Kaiming初始化方法根据输入和输出的维度计算标准差,并将权重初始化为从均值为0、标准差为$\sqrt{\frac{2}{n_{in}}}$的正态分布中采样得到的值,其中$n_{in}$是输入的维度。

在PyTorch中,可以使用torch.nn.init模块中的kaiming_normal_函数或kaiming_uniform_函数实现Kaiming初始化。以下是一个示例代码:

import torch.nn as nn
import torch.nn.init as init

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 10)

        # 使用kaiming_normal_函数初始化卷积层和全连接层的权重
        init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
        init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')
        init.kaiming_normal_(self.fc1.weight, mode='fan_out', nonlinearity='relu')
        init.kaiming_normal_(self.fc2.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(nn.functional.relu(self.conv2(x)), 2)
        x = x.view(-1, 128 * 8 * 8)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

在上面的示例代码中,我们定义了一个名为Net的类,它继承自nn.Module。在__init__函数中,我们定义了神经网络的各个层,并使用kaiming_normal_函数初始化卷积层和全连接层的权重。在forward函数中,我们定义了模型的前向传播过程。

2. 示例1:使用Kaiming初始化训练神经网络

以下是一个使用Kaiming初始化训练神经网络的示例代码:

import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# 加载数据集
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

# 定义模型、损失函数和优化器
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# 训练模型
for epoch in range(100):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 99:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

在上面的示例代码中,我们首先加载CIFAR-10数据集,并定义模型、损失函数和优化器。然后,我们训练模型并输出每个epoch的平均损失。

3. 示例2:使用Kaiming初始化进行迁移学习

以下是一个使用Kaiming初始化进行迁移学习的示例代码:

import torch.hub
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# 加载预训练模型
pretrained_net = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True)
pretrained_net.fc = nn.Linear(512, 10)

# 使用kaiming_normal_函数初始化全连接层的权重
init.kaiming_normal_(pretrained_net.fc.weight, mode='fan_out', nonlinearity='relu')

# 将模型移动到GPU上
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pretrained_net = pretrained_net.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(pretrained_net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# 加载数据集
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

# 训练模型
for epoch in range(100):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = pretrained_net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 99:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

在上面的示例代码中,我们首先使用torch.hub.load函数加载预训练的ResNet-18模型,并将其输出层替换为一个新的全连接层。然后,我们使用kaiming_normal_函数初始化全连接层的权重,并将模型移动到GPU上。接着,我们加载CIFAR-10数据集,并定义损失函数和优化器。最后,我们训练模型并输出每个epoch的平均损失。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:对Pytorch神经网络初始化kaiming分布详解 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • PyTorch全连接ReLU网络

    PyTorch全连接ReLU网络 1.PyTorch的核心是两个主要特征: 一个n维张量,类似于numpy,但可以在GPU上运行 搭建和训练神经网络时的自动微分/求导机制 本文将使用全连接的ReLU网络作为运行示例。该网络将有一个单一的隐藏层,并将使用梯度下降训练,通过最小化网络输出和真正结果的欧几里得距离,来拟合随机生成的数据。 2.张量 2.1 热身: …

    PyTorch 2023年4月8日
    00
  • Pytorch Visdom

    fb官方的一些demo 一.  show something 1.  vis.image:显示一张图片 viz.image( np.random.rand(3, 512, 256), opts=dict(title=’Random!’, caption=’How random.’), ) opts.jpgquality:JPG质量(number0-100;默…

    2023年4月8日
    00
  • pytorch中的math operation: torch.bmm()

    torch.bmm(batch1, batch2, out=None) → Tensor Performs a batch matrix-matrix product of matrices stored in batch1 and batch2. batch1 and batch2 must be 3-D tensors each containing t…

    PyTorch 2023年4月8日
    00
  • M1 mac安装PyTorch的实现步骤

    M1 Mac是苹果公司推出的基于ARM架构的芯片,与传统的x86架构有所不同。因此,在M1 Mac上安装PyTorch需要一些特殊的步骤。本文将介绍M1 Mac上安装PyTorch的实现步骤,并提供两个示例说明。 步骤一:安装Miniforge Miniforge是一个轻量级的Anaconda发行版,专门为ARM架构的Mac电脑设计。我们可以使用Minifo…

    PyTorch 2023年5月15日
    00
  • pytorch 中tensor的加减和mul、matmul、bmm

    如下是tensor乘法与加减法,对应位相乘或相加减,可以一对多 import torch def add_and_mul(): x = torch.Tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]) y = torch.Tensor([1, 2, 3]) y = y – x print(y)…

    PyTorch 2023年4月7日
    00
  • PyTorch中Tensor和tensor的区别是什么

    这篇文章主要介绍“PyTorch中Tensor和tensor的区别是什么”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“PyTorch中Tensor和tensor的区别是什么”文章能帮助大家解决问题。 Tensor和tensor的区别 本文列举的框架源码基于PyTorch2.0,交互语句在0.4.1上测试通过 impo…

    2023年4月8日
    00
  • pytorch tensor计算三通道均值方式

    以下是PyTorch计算三通道均值的两个示例说明。 示例1:计算图像三通道均值 在这个示例中,我们将使用PyTorch计算图像三通道均值。 首先,我们需要准备数据。我们将使用torchvision库来加载图像数据集。您可以使用以下代码来加载数据集: import torchvision.datasets as datasets import torchvis…

    PyTorch 2023年5月15日
    00
  • pytorch动态网络以及权重共享实例

    以下是关于“PyTorch 动态网络以及权重共享实例”的完整攻略,其中包含两个示例说明。 示例1:动态网络 步骤1:导入必要库 在定义动态网络之前,我们需要导入一些必要的库,包括torch。 import torch 步骤2:定义动态网络 在这个示例中,我们使用动态网络来演示如何定义动态网络。 # 定义动态网络 class DynamicNet(torch.…

    PyTorch 2023年5月16日
    00
合作推广
合作推广
分享本页
返回顶部