pytorch 网络参数 weight bias 初始化详解

以下是PyTorch网络参数weight bias初始化的详细攻略,包括两个示例说明。

1. 网络参数初始化

在PyTorch中,网络参数的初始化是非常重要的,因为它可以影响模型的收敛速度和最终的性能。PyTorch提供了多种初始化方法,包括常见的均匀分布、正态分布、Xavier初始化和Kaiming初始化等。

1.1 均匀分布初始化

均匀分布初始化是一种简单的初始化方法,它将权重初始化为在[-a, a]之间的均匀分布,其中a是一个常数。在PyTorch中,可以使用torch.nn.init.uniform_()函数来进行均匀分布初始化。

import torch.nn as nn
import torch.nn.init as init

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 10)

        # 均匀分布初始化
        init.uniform_(self.conv1.weight, a=-0.1, b=0.1)
        init.uniform_(self.conv2.weight, a=-0.1, b=0.1)
        init.uniform_(self.fc1.weight, a=-0.1, b=0.1)
        init.uniform_(self.fc2.weight, a=-0.1, b=0.1)

1.2 正态分布初始化

正态分布初始化是一种常用的初始化方法,它将权重初始化为均值为0、标准差为std的正态分布。在PyTorch中,可以使用torch.nn.init.normal_()函数来进行正态分布初始化。

import torch.nn as nn
import torch.nn.init as init

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 10)

        # 正态分布初始化
        init.normal_(self.conv1.weight, mean=0, std=0.01)
        init.normal_(self.conv2.weight, mean=0, std=0.01)
        init.normal_(self.fc1.weight, mean=0, std=0.01)
        init.normal_(self.fc2.weight, mean=0, std=0.01)

1.3 Xavier初始化

Xavier初始化是一种常用的初始化方法,它根据输入和输出的维度自适应地调整权重的初始化范围。在PyTorch中,可以使用torch.nn.init.xavier_uniform_()函数或torch.nn.init.xavier_normal_()函数来进行Xavier初始化。

import torch.nn as nn
import torch.nn.init as init

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 10)

        # Xavier初始化
        init.xavier_uniform_(self.conv1.weight)
        init.xavier_uniform_(self.conv2.weight)
        init.xavier_uniform_(self.fc1.weight)
        init.xavier_uniform_(self.fc2.weight)

1.4 Kaiming初始化

Kaiming初始化是一种针对ReLU激活函数的初始化方法,它根据输入和输出的维度自适应地调整权重的初始化范围。在PyTorch中,可以使用torch.nn.init.kaiming_uniform_()函数或torch.nn.init.kaiming_normal_()函数来进行Kaiming初始化。

import torch.nn as nn
import torch.nn.init as init

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 10)

        # Kaiming初始化
        init.kaiming_uniform_(self.conv1.weight, mode='fan_in', nonlinearity='relu')
        init.kaiming_uniform_(self.conv2.weight, mode='fan_in', nonlinearity='relu')
        init.kaiming_uniform_(self.fc1.weight, mode='fan_in', nonlinearity='relu')
        init.kaiming_uniform_(self.fc2.weight, mode='fan_in', nonlinearity='relu')

2. bias初始化

在PyTorch中,bias的初始化通常使用常数初始化,可以使用torch.nn.init.constant_()函数来进行常数初始化。

import torch.nn as nn
import torch.nn.init as init

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 10)

        # bias初始化
        init.constant_(self.conv1.bias, 0)
        init.constant_(self.conv2.bias, 0)
        init.constant_(self.fc1.bias, 0)
        init.constant_(self.fc2.bias, 0)

以上就是PyTorch网络参数weight bias初始化的详细攻略,包括两个示例说明。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch 网络参数 weight bias 初始化详解 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • win10/windows 安装Pytorch

    https://pytorch.org/get-started/locally/ 去官网,选择你需要的版本。   把 pip install torch==1.5.0+cu101 torchvision==0.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html 命令行执行。    C…

    2023年4月8日
    00
  • PyTorch中,关于model.eval()和torch.no_grad()

    一直对于model.eval()和torch.no_grad()有些疑惑 之前看博客说,只用torch.no_grad()即可 但是今天查资料,发现不是这样,而是两者都用,因为两者有着不同的作用 引用stackoverflow: Use both. They do different things, and have different scopes.wit…

    PyTorch 2023年4月8日
    00
  • windows10 安装 Anaconda 并配置 pytorch1.0

    官网下载Anaconda安装包,按步骤安装即可安装完后,打开DOS,或Anaconda自带的Anaconda Prompt终端查看Anaconda已安装的安装包C:\Users\jiangshan>conda list安装 matplotlibC:\Users\jiangshan>conda install matplotlib设置镜像# 添加A…

    PyTorch 2023年4月8日
    00
  • VScode中pytorch出现Module ‘torch’ has no ‘xx’ member错误

           因为代码变量太多,使用Sublime text并能很好地跳转,所以使用VsCode 神器。     导入Pytorch模块后出现了   Module ‘torch’ has no cat member,所以在网上找解决办法,这位博主的文章很好用,一路解决。        我的版本python3.7无Anacada,解决办法,打开设置,搜索pyt…

    2023年4月8日
    00
  • 关于PyTorch环境配置及安装教程(Windows10)

    关于 PyTorch 环境配置及安装教程(Windows10) PyTorch 是一个基于 Python 的科学计算库,它主要用于深度学习研究。在 Windows10 系统下,我们可以通过 Anaconda 或 pip 来安装 PyTorch 环境。本文将详细讲解 PyTorch 环境配置及安装教程,并提供两个示例说明。 1. 使用 Anaconda 安装 …

    PyTorch 2023年5月16日
    00
  • Jupyter Notebook远程登录及密码设置操作

    Jupyter Notebook远程登录及密码设置操作 Jupyter Notebook是一种非常流行的交互式计算环境,它可以让用户在浏览器中编写和运行代码。本文将介绍如何在远程服务器上设置Jupyter Notebook,并设置密码以保护您的笔记本。 远程登录Jupyter Notebook 要在远程服务器上登录Jupyter Notebook,您需要执行…

    PyTorch 2023年5月15日
    00
  • Pytorch:学习率调整

    PyTorch学习率调整策略通过torch.optim.lr_scheduler接口实现。PyTorch提供的学习率调整策略分为三大类,分别是: 有序调整:等间隔调整(Step),按需调整学习率(MultiStep),指数衰减调整(Exponential)和 余弦退火CosineAnnealing 自适应调整:自适应调整学习率 ReduceLROnPlate…

    2023年4月6日
    00
  • pytorch sampler对数据进行采样的实现

    PyTorch中的Sampler是一个用于对数据进行采样的工具,它可以用于实现数据集的随机化、平衡化等操作。本文将深入浅析PyTorch的Sampler的实现方法,并提供两个示例说明。 1. PyTorch的Sampler的实现方法 PyTorch的Sampler的实现方法如下: sampler = torch.utils.data.Sampler(data…

    PyTorch 2023年5月15日
    00
合作推广
合作推广
分享本页
返回顶部