model_dict = torch.load(save_path) fp = open('model_parameter.bin', 'wb') weight_count = 0 num=1 for k, v in model_dict.items(): print(k,num) num=num+1 if 'num_batches_tracked' in k: continue v = v.cpu().numpy().flatten() for d in v: fp.write(d) weight_count+=1 print('model_weight has Convert Completely!',weight_count)
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch中.pth文件转成.bin的二进制文件 - Python技术站