用Pytorch1.0进行半精度浮点型网络训练需要注意下问题:
1、网络要在GPU上跑,模型和输入样本数据都要cuda().half()
2、模型参数转换为half型,不必索引到每层,直接model.cuda().half()即可
3、对于半精度模型,优化算法,Adam我在使用过程中,在某些参数的梯度为0的时候,更新权重后,梯度为零的权重变成了NAN,这非常奇怪,但是Adam算法对于全精度数据类型却没有这个问题。
另外,SGD算法对于半精度和全精度计算均没有问题。
还有一个问题是不知道是不是网络结构比较小的原因,使用半精度的训练速度还没有全精度快。这个值得后续进一步探索。
对于上面的这个问题,的确是网络很小的情况下,在1080Ti上半精度浮点型没有很明显的优势,但是当网络变大之后,半精度浮点型要比全精度浮点型要快。但具体快多少和模型的大小以及输入样本大小有关系,我测试的是要快1/6,同时,半精度浮点型在占用内存上比较有优势,对于精度的影响尚未探究。
将网络再变大些,epoch的次数也增大,半精度和全精度的时间差就表现出来了,在训练的时候。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch半精度浮点型网络训练问题 - Python技术站