在PyTorch中,我们通常使用nn.Module类来定义神经网络模型。在定义模型时,我们需要实现__init__()、forward()和__call__()方法。这些方法分别用于初始化模型参数、定义前向传播过程和调用模型。
init()方法
init()方法用于初始化模型参数。在该方法中,我们通常定义模型的各个层,并初始化它们的参数。以下是一个示例代码,演示了如何在__init__()方法中定义模型的各个层:
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
在上面的代码中,我们定义了一个Net类,该类继承自nn.Module类。在Net类的构造函数中,我们定义了模型的各个层,包括两个卷积层、两个池化层和三个全连接层。我们使用nn.Conv2d()函数定义卷积层,使用nn.MaxPool2d()函数定义池化层,使用nn.Linear()函数定义全连接层。
forward()方法
forward()方法用于定义模型的前向传播过程。在该方法中,我们通常将输入传递给模型的各个层,并计算输出。以下是一个示例代码,演示了如何在forward()方法中定义模型的前向传播过程:
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
在上面的代码中,我们在Net类中定义了forward()方法。在该方法中,我们首先将输入x传递给第一个卷积层,并使用ReLU激活函数和池化层。接下来,我们将输出传递给第二个卷积层,并再次使用ReLU激活函数和池化层。然后,我们将输出展平,并传递三个全连接层。最后,我们返回输出。
call()方法
call()方法用于调用模型。在该方法中,我们通常将输入传递给forward()方法,并计算输出。以下是一个示例代码,演示了如何在__call__()方法中调用模型:
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def __call__(self, x):
return self.forward(x)
在上面的代码中,我们在Net类中定义了__call__()方法。在该方法中,我们将输入x传递给forward()方法,并返回输出。这样,我们就可以使用Net类的实例来调用模型。例如,我们可以使用以下代码来调用模型:
net = Net()
output = net(input)
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch __init__、forward与__call__的用法小结 - Python技术站