pytorch 网络参数 weight bias 初始化详解

yizhihongxing

以下是PyTorch网络参数weight bias初始化的详细攻略,包括两个示例说明。

1. 网络参数初始化

在PyTorch中,网络参数的初始化是非常重要的,因为它可以影响模型的收敛速度和最终的性能。PyTorch提供了多种初始化方法,包括常见的均匀分布、正态分布、Xavier初始化和Kaiming初始化等。

1.1 均匀分布初始化

均匀分布初始化是一种简单的初始化方法,它将权重初始化为在[-a, a]之间的均匀分布,其中a是一个常数。在PyTorch中,可以使用torch.nn.init.uniform_()函数来进行均匀分布初始化。

import torch.nn as nn
import torch.nn.init as init

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 10)

        # 均匀分布初始化
        init.uniform_(self.conv1.weight, a=-0.1, b=0.1)
        init.uniform_(self.conv2.weight, a=-0.1, b=0.1)
        init.uniform_(self.fc1.weight, a=-0.1, b=0.1)
        init.uniform_(self.fc2.weight, a=-0.1, b=0.1)

1.2 正态分布初始化

正态分布初始化是一种常用的初始化方法,它将权重初始化为均值为0、标准差为std的正态分布。在PyTorch中,可以使用torch.nn.init.normal_()函数来进行正态分布初始化。

import torch.nn as nn
import torch.nn.init as init

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 10)

        # 正态分布初始化
        init.normal_(self.conv1.weight, mean=0, std=0.01)
        init.normal_(self.conv2.weight, mean=0, std=0.01)
        init.normal_(self.fc1.weight, mean=0, std=0.01)
        init.normal_(self.fc2.weight, mean=0, std=0.01)

1.3 Xavier初始化

Xavier初始化是一种常用的初始化方法,它根据输入和输出的维度自适应地调整权重的初始化范围。在PyTorch中,可以使用torch.nn.init.xavier_uniform_()函数或torch.nn.init.xavier_normal_()函数来进行Xavier初始化。

import torch.nn as nn
import torch.nn.init as init

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 10)

        # Xavier初始化
        init.xavier_uniform_(self.conv1.weight)
        init.xavier_uniform_(self.conv2.weight)
        init.xavier_uniform_(self.fc1.weight)
        init.xavier_uniform_(self.fc2.weight)

1.4 Kaiming初始化

Kaiming初始化是一种针对ReLU激活函数的初始化方法,它根据输入和输出的维度自适应地调整权重的初始化范围。在PyTorch中,可以使用torch.nn.init.kaiming_uniform_()函数或torch.nn.init.kaiming_normal_()函数来进行Kaiming初始化。

import torch.nn as nn
import torch.nn.init as init

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 10)

        # Kaiming初始化
        init.kaiming_uniform_(self.conv1.weight, mode='fan_in', nonlinearity='relu')
        init.kaiming_uniform_(self.conv2.weight, mode='fan_in', nonlinearity='relu')
        init.kaiming_uniform_(self.fc1.weight, mode='fan_in', nonlinearity='relu')
        init.kaiming_uniform_(self.fc2.weight, mode='fan_in', nonlinearity='relu')

2. bias初始化

在PyTorch中,bias的初始化通常使用常数初始化,可以使用torch.nn.init.constant_()函数来进行常数初始化。

import torch.nn as nn
import torch.nn.init as init

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 10)

        # bias初始化
        init.constant_(self.conv1.bias, 0)
        init.constant_(self.conv2.bias, 0)
        init.constant_(self.fc1.bias, 0)
        init.constant_(self.fc2.bias, 0)

以上就是PyTorch网络参数weight bias初始化的详细攻略,包括两个示例说明。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch 网络参数 weight bias 初始化详解 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • Pytorch:数据增强与标准化

    本文对transforms.py中的各个预处理方法进行介绍和总结。主要从官方文档中总结而来,官方文档只是将方法陈列,没有归纳总结,顺序很乱,这里总结一共有四大类,方便大家索引: 裁剪——Crop 中心裁剪:transforms.CenterCrop 随机裁剪:transforms.RandomCrop 随机长宽比裁剪:transforms.RandomRes…

    PyTorch 2023年4月6日
    00
  • M1 mac安装PyTorch的实现步骤

    M1 Mac是苹果公司推出的基于ARM架构的芯片,与传统的x86架构有所不同。因此,在M1 Mac上安装PyTorch需要一些特殊的步骤。本文将介绍M1 Mac上安装PyTorch的实现步骤,并提供两个示例说明。 步骤一:安装Miniforge Miniforge是一个轻量级的Anaconda发行版,专门为ARM架构的Mac电脑设计。我们可以使用Minifo…

    PyTorch 2023年5月15日
    00
  • PyTorch自定义数据集

    数据传递机制 我们首先回顾识别手写数字的程序: … Dataset = torchvision.datasets.MNIST(root=’./mnist/’, train=True, transform=transform, download=True,) dataloader = torch.utils.data.DataLoader(dataset=…

    2023年4月7日
    00
  • pytorch中tensor的属性 类型转换 形状变换 转置 最大值

    import torch import numpy as np a = torch.tensor([[[1]]]) #只有一个数据的时候,获取其数值 print(a.item()) #tensor转化为nparray b = a.numpy() print(b,type(b),type(a)) #获取张量的形状 a = torch.tensor(np.ara…

    PyTorch 2023年4月8日
    00
  • OpenCV加载Pytorch模型出现Unsupported Lua type 解决方法

    原因 Torch有两个版本,一个就叫Torch一个专门给Python用的Pytorch,它们训练完之后保存下来的模型是不一样的.说到这问题就很清楚了.OpenCV的ReadNetFromTorch支持的是前者… 解决方法 那么有没有解决办法呢,答案是有的.PyTorch支持把模型保存为ONNX格式.而这个格式在opencv是支持的.操作如下: impor…

    PyTorch 2023年4月8日
    00
  • PyTorch——(4)where条件判断、gather查表

    where() 条件判断 gather()查表 input :待查的表dim : 在input的哪个维度上查表index: 待查表的索引值

    2023年4月8日
    00
  • pytorch的.item()方法

    python的.item()用于将字典中每对key和value组成一个元组,并把这些元组放在列表中返回例如person={‘name’:‘lizhong’,‘age’:‘26’,‘city’:‘BeiJing’,‘blog’:‘www.jb51.net’} for key,value in person.items():print ‘key=’,key,’,…

    PyTorch 2023年4月8日
    00
  • 浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式

    在PyTorch中,我们可以使用不同的文件格式来保存模型,包括.pt、.pth和.pkl。这些文件格式之间有一些区别,本文将对它们进行详细讲解,并提供两个示例说明。 .pt和.pth文件 .pt和.pth文件是PyTorch中最常用的模型保存格式。它们都是二进制文件,可以保存模型的参数、状态和结构。.pt文件通常用于保存单个模型,而.pth文件通常用于保存多…

    PyTorch 2023年5月16日
    00
合作推广
合作推广
分享本页
返回顶部