PyTorch报”RuntimeError: Expected object of scalar type Int but got scalar type Double for argument #2 ‘other’ “的原因以及解决办法

问题描述

在使用PyTorch进行神经网络训练时,有时会遇到如下报错:

RuntimeError: Expected object of scalar type Int but got scalar type Double for argument #2 'other'

这个问题通常发生在计算loss或者metrics时。

原因分析

该报错由于tensor数据类型不匹配所致。一些loss函数或者metrics函数要求输入的参数类型必须是int类型,但是有时候我们会不小心传入了float类型的参数,导致出现这个报错。

解决办法

解决该问题的方法非常简单。只需要将参数强制转换为int类型即可,例如:

loss = criterion(output, target.float().long())

在上述代码中,先将target转换成float类型,再转换成long类型,就可以解决该问题。

另外,如果在计算metrics时遇到类似的问题,也可以通过类似的方式进行转换:

accuracy = accuracy_score(target.cpu().numpy(), output.cpu().argmax(axis=1).numpy())

在这个例子中,将target和output都转换成cpu上的numpy array,并对output的axis=1进行argmax操作,得到模型的预测结果。最终调用scikit-learn的accuracy_score函数计算准确率。

总结

在使用PyTorch进行神经网络训练时,有时会遇到报错"RuntimeError: Expected object of scalar type Int but got scalar type Double for argument #2 'other'",这个问题源于参数类型不匹配。为了解决这个问题,只需将参数强制转换为int类型。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch报”RuntimeError: Expected object of scalar type Int but got scalar type Double for argument #2 ‘other’ “的原因以及解决办法 - Python技术站

(0)
上一篇 2023年3月19日
下一篇 2023年3月19日

相关文章

合作推广
合作推广
分享本页
返回顶部