最近,在玩语义分割的模型。利用GPU训练的时候,每次跑几个epochs之后,程序崩溃,输出我说我的generator不是线程安全的。查看 trace back发现model.fit_generator在调用自己写的generator出现问题,需要将自己的generator写成线程安全的。
参考keras的#1638 issue找到解决方案。如下图,需要在代码中添加以下内容,然后在自己写的生成器函数上面加上修饰符 @threadsafe_generator:
添加以上内容后,程序报错 threadsafe_iter不是一个iterator。这是因为,我是用的是python 3,在python3中迭代器类self.next()的定义需要改为self.__next__(),将这个函数改为如下,问题解决(python 2 不需要修改):
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:[Keras 模型训练] Thread Safe Generator - Python技术站