PyTorch报”TypeError: forward() takes 1 positional argument but 2 were given “的原因以及解决办法

yizhihongxing

问题描述

在使用PyTorch进行神经网络训练时,经常会遇到如下的报错信息:

TypeError: forward() takes 1 positional argument but 2 were given

这个报错信息的意思是,我们在调用神经网络的forward()函数时,给了多余的参数,而forward()函数只接收一个参数。

这个问题可能出现在多种情况下,下面我们分别来看一下各种情况下的原因和解决办法。

没有正确传递输入数据

第一种常见情况是,我们在给神经网络输入数据时,没有按正确的方式传递数据。例如,下面的代码就会报出上述错误:

import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 5)
        self.fc1 = nn.Linear(32*5*5, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.fc1(x)
        return x

net = SimpleNet()
x = torch.randn(1, 3, 32, 32)
output = net(x, 1)

解决方法是,将x作为参数传递给forward()函数即可:

output = net.forward(x)

或者直接省略掉forward()函数:

output = net(x)

在调用神经网络时,传递了多余的参数

第二种情况是,我们在调用神经网络的时候,传递了多余的参数。例如下面的代码:

import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 5)
        self.fc1 = nn.Linear(32*5*5, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.fc1(x)
        return x

net = SimpleNet()
x = torch.randn(1, 3, 32, 32)
output = net(x, 1)

解决方法是,正确调用神经网络:

output = net(x)

在前向传播过程中,调用了其他的函数

第三种情况是,在前向传播的过程中,调用了其他的函数,而这些函数又需要多个参数。例如:

import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 5)
        self.fc1 = nn.Linear(32*5*5, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.other_function(x, 1)
        x = self.fc1(x)
        return x

    def other_function(self, x, y):
        return x + y

net = SimpleNet()
x = torch.randn(1, 3, 32, 32)
output = net(x)

这段代码会报错,因为在调用other_function()函数时,我们需要传递两个参数,但是forward()函数只接受一个参数。

解决方法是,在调用other_function()函数时,将x和y合成一个参数,然后再传递即可:

def forward(self, x):
    x = self.conv1(x)
    x = self.other_function(x, 1, 2)
    x = self.fc1(x)
    return x

def other_function(self, x, y, z):
    return x + y + z

net = SimpleNet()
x = torch.randn(1, 3, 32, 32)
output = net(x)

总结

在使用PyTorch进行神经网络训练时,出现“TypeError: forward() takes 1 positional argument but 2 were given”的问题,常常是因为传参不正确所导致的。

要解决这个问题,通常需要检查代码中对神经网络的调用情况,确保传递的参数正确无误。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch报”TypeError: forward() takes 1 positional argument but 2 were given “的原因以及解决办法 - Python技术站

(0)
上一篇 2023年3月19日
下一篇 2023年3月19日

相关文章

合作推广
合作推广
分享本页
返回顶部