今天使用hiddenlayer测试了下retinanet网络的可视化。
首先,安装hiddlayer,直接pip pip install git+https://github.com/waleedka/hiddenlayer.git
然后在终端加载模型并显示:
import model, torch
import hiddenlayer as hl
retinanet = model.resnet18(num_classes=100, pretrained=True).cuda()
x = torch.rand((1, 3, 224, 224)).cuda().float()
ann = torch.tensor([[[20.0, 30.0, 53.2, 33.3, 32.0]]]).cuda().float()
hl.build_graph(retinanet, [x, ann])
hl.save('/home/willer/model.pdf')
模型太复杂了,放在这里了。
昨天晚上对比着模型结构的pdf和代码又看了下,发现还是很有用的,起码对网络的数据流动的认识更加清晰了。
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch 网络可视化 - Python技术站