问题描述
在使用PyTorch进行神经网络训练时,经常会遇到如下的报错信息:
TypeError: forward() takes 1 positional argument but 2 were given
这个报错信息的意思是,我们在调用神经网络的forward()函数时,给了多余的参数,而forward()函数只接收一个参数。
这个问题可能出现在多种情况下,下面我们分别来看一下各种情况下的原因和解决办法。
没有正确传递输入数据
第一种常见情况是,我们在给神经网络输入数据时,没有按正确的方式传递数据。例如,下面的代码就会报出上述错误:
import torch
import torch.nn as nn
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.conv1 = nn.Conv2d(3, 32, 5)
self.fc1 = nn.Linear(32*5*5, 10)
def forward(self, x):
x = self.conv1(x)
x = self.fc1(x)
return x
net = SimpleNet()
x = torch.randn(1, 3, 32, 32)
output = net(x, 1)
解决方法是,将x作为参数传递给forward()函数即可:
output = net.forward(x)
或者直接省略掉forward()函数:
output = net(x)
在调用神经网络时,传递了多余的参数
第二种情况是,我们在调用神经网络的时候,传递了多余的参数。例如下面的代码:
import torch
import torch.nn as nn
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.conv1 = nn.Conv2d(3, 32, 5)
self.fc1 = nn.Linear(32*5*5, 10)
def forward(self, x):
x = self.conv1(x)
x = self.fc1(x)
return x
net = SimpleNet()
x = torch.randn(1, 3, 32, 32)
output = net(x, 1)
解决方法是,正确调用神经网络:
output = net(x)
在前向传播过程中,调用了其他的函数
第三种情况是,在前向传播的过程中,调用了其他的函数,而这些函数又需要多个参数。例如:
import torch
import torch.nn as nn
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.conv1 = nn.Conv2d(3, 32, 5)
self.fc1 = nn.Linear(32*5*5, 10)
def forward(self, x):
x = self.conv1(x)
x = self.other_function(x, 1)
x = self.fc1(x)
return x
def other_function(self, x, y):
return x + y
net = SimpleNet()
x = torch.randn(1, 3, 32, 32)
output = net(x)
这段代码会报错,因为在调用other_function()函数时,我们需要传递两个参数,但是forward()函数只接受一个参数。
解决方法是,在调用other_function()函数时,将x和y合成一个参数,然后再传递即可:
def forward(self, x):
x = self.conv1(x)
x = self.other_function(x, 1, 2)
x = self.fc1(x)
return x
def other_function(self, x, y, z):
return x + y + z
net = SimpleNet()
x = torch.randn(1, 3, 32, 32)
output = net(x)
总结
在使用PyTorch进行神经网络训练时,出现“TypeError: forward() takes 1 positional argument but 2 were given”的问题,常常是因为传参不正确所导致的。
要解决这个问题,通常需要检查代码中对神经网络的调用情况,确保传递的参数正确无误。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch报”TypeError: forward() takes 1 positional argument but 2 were given “的原因以及解决办法 - Python技术站