RefineDet

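I. Background

Recent work from the Institute of Automation, Chinese Academy of Sciences, published at CVPR 2018: "Single-Shot Refinement Neural Network for Object Detection".

On the VOC2007 test set with 512*512 input images, RefineDet reaches 81.8% mAP at 24 FPS.

Paper link: https://arxiv.org/abs/1711.06897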

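II. Main Ideas

1. A single-stage object detection framework built from two interconnected modules: the ARM and the ODM.

2. A transfer connection block (TCB) passes ARM features to the ODM, which handles the harder task of predicting accurate object locations, sizes, and class labels.

3. An FPN-like feature fusion step is added to the detection network, which noticeably improves detection of small objects.

4. At the time of publication it achieved state-of-the-art detection accuracy on PASCAL VOC 2007, PASCAL VOC 2012, and MS COCO.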

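III. Network Structure

RefineDet builds on SSD and improves the single-stage design with two interconnected modules: the anchor refinement module (ARM) and the object detection module (ODM), as shown in the figure below. In the figure, the celadon parallelograms denote the refined anchors associated with different feature layers, and the stars mark the centers of the refined anchors, which are no longer laid out on a regular grid over the image.

[Figure: RefineDet architecture]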

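RefineDet is based on a feed-forward convolutional network that produces a fixed number of bounding boxes together with scores indicating the presence of different object classes in those boxes; non-maximum suppression (NMS) then yields the final result. RefineDet consists of two interconnected modules, the ARM and the ODM.

The ARM removes negative anchors to shrink the search space for the classifier, and coarsely adjusts anchor locations and sizes to give the subsequent regressor a better initialization.

In more detail: like the RPN in Faster R-CNN, the ARM mainly produces bboxes (analogous to the ROIs or proposals in Faster R-CNN) and discards some negative samples, since negatives far outnumber positives. The four feature layers therefore feed two branches: a bbox coordinate regression branch and a binary bbox classification branch. In Faster R-CNN the purpose of the RPN is to generate proposals (ROIs) that give the downstream detection network good initial information, and this is an important difference between one-stage and two-stage object detection algorithms. The ARM essentially plays the role of the RPN here; the main difference is that the ARM takes multi-layer features as input, whereas the RPN takes a single feature layer.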

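The ODM takes the refined anchors, regresses them to accurate object locations, and predicts multi-class labels.

In more detail: the ODM takes the refined anchors produced by the ARM as input and further improves the regression while predicting multi-class labels. The two interconnected modules imitate the two-stage structure and so inherit its three advantages, producing accurate detections while remaining efficient. This part is essentially SSD: it also combines features from different layers and then performs multi-class classification and regression. There are two main differences. First, its input anchors are the refined anchors from the ARM, analogous to the proposals output by an RPN. Second, as in FPN, the shallow feature maps (the larger blue rectangles) are fused with information from the higher-level feature maps; bboxes are then predicted on each fused feature map (each blue rectangle) and the per-layer results are merged. In SSD, by contrast, the shallow feature maps are used directly, without fusion with higher-level maps, before the per-layer prediction and merging. This is a very important difference. The benefit is better detection of small objects, as FPN, RON, and similar algorithms have already shown.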

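To link the ARM and the ODM, a transfer connection block (TCB) converts the features from different ARM layers into the form the ODM requires, so that the ODM can share the ARM features. The ODM is composed of the TCB outputs followed by prediction layers, which produce object class scores and shape offsets relative to the refined anchor coordinates. In other words, the TCB transfers features from the anchor refinement module so that the ODM can predict object locations, sizes, and class labels. A second function of the TCB is to integrate large-scale context by adding high-level features to the transferred features, which improves detection accuracy.

As in SSD, RefineDet uses a feed-forward convolutional network to generate bounding boxes and per-class scores, followed by non-maximum suppression. The ARM is built by removing the classification layers of VGG-16 or ResNet-101, both pretrained on ImageNet, and adding auxiliary structures.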

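Network construction: taking ResNet101 with 320*320 input as an example, the four gray rectangles (feature maps) in the Anchor Refinement Module are 40*40, 20*20, 10*10, and 5*5. The first three are outputs of ResNet101 itself, and the final 5*5 map comes from an additional residual block. With the feature-extraction backbone in place, the fusion begins. First, the 5*5 feature map passes through a transfer connection block (described later) to produce the blue rectangle of the same size (P6); on this branch the block is just three convolutional layers. Next, the 10*10 gray feature map passes through a transfer connection block to produce the blue rectangle of the same size (P5); compared with P6, this block adds a deconvolution branch whose input is an intermediate output of the P6 branch. P4 and P3 are generated in the same way as P5.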
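The sizes quoted above follow directly from the layer strides. A minimal sketch, assuming total strides of 8, 16, 32, and 64 (the function name is illustrative, not from the paper), that also checks that each deeper map, once upsampled by 2x in the deconvolution branch, lines up with the next shallower one:

```python
def feature_map_sizes(input_size, strides=(8, 16, 32, 64)):
    """Return the spatial size of each detection layer for a square input."""
    return [input_size // s for s in strides]


sizes = feature_map_sizes(320)
# The deconvolution branch doubles the deeper map before the element-wise sum,
# so neighbouring levels must differ by exactly a factor of two.
upsample_ok = all(deep * 2 == shallow for shallow, deep in zip(sizes, sizes[1:]))
print(sizes, upsample_ok)  # [40, 20, 10, 5] True
```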

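Overall, the network resembles a two-stage structure (both can be summarized as two-step cascaded regression): one sub-module does the job of the RPN and the other does the job of SSD. SSD regresses directly from the default boxes, whereas RefineDet first generates refined anchor boxes in the ARM (analogous to RPN proposals) and then regresses from those refined anchor boxes, which gives higher accuracy. Thanks to the feature fusion, the algorithm is also more effective on small objects.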

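The three core components of RefineDet are explained below: (1) the transfer connection block (TCB), which transfers ARM features to the ODM for detection; (2) two-step cascaded regression, which accurately regresses object locations and sizes; (3) negative anchor filtering, which rejects well-classified negative anchors early and mitigates class imbalance.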

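Transfer connection block: to match the dimensions of the two inputs, a deconvolution enlarges the high-level feature map, which is then summed element-wise with the transferred ARM feature. A convolutional layer is added after the summation to ensure the discriminability of the features used for detection.

See the figure below:

[Figure: transfer connection block (TCB)]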

 

 

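Two-step cascaded regression: a two-step cascaded regression strategy regresses object locations and sizes. The ARM first adjusts the anchor locations and sizes to give the ODM regression a better initialization. Specifically, n anchor boxes are associated with each regularly divided cell of a feature map, and the initial position of each anchor box relative to its cell is fixed. For each feature map cell, the ARM predicts four offsets of the refined anchor box relative to the original tiled anchor, plus two confidence scores indicating whether a foreground object is present in the box. Each feature map cell therefore yields n refined anchor boxes.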
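The two regression steps can be sketched in a few lines. This is an illustrative sketch, not the authors' implementation: it assumes boxes in (cy, cx, h, w) form, center offsets scaled by the anchor size, and log-space size offsets, which is the same decoding convention as the inference branch of the reference code at the end of this post. The names `decode` and `two_step_decode` are made up for the example.

```python
import math


def decode(anchor, offsets):
    """Apply (dy, dx, dh, dw) offsets to a (cy, cx, h, w) box."""
    cy, cx, h, w = anchor
    dy, dx, dh, dw = offsets
    return (cy + dy * h, cx + dx * w, h * math.exp(dh), w * math.exp(dw))


def two_step_decode(anchor, arm_offsets, odm_offsets):
    refined = decode(anchor, arm_offsets)  # step 1: ARM refines the tiled anchor
    return decode(refined, odm_offsets)    # step 2: ODM regresses from the refined anchor


anchor = (100.0, 100.0, 32.0, 32.0)
box = two_step_decode(anchor, (0.25, 0.0, 0.0, 0.0), (0.0, 0.5, 0.0, 0.0))
print(box)  # (108.0, 116.0, 32.0, 32.0)
```

The ODM offsets are interpreted relative to the refined anchor rather than the original one, which is what distinguishes this scheme from the single regression step in SSD.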

 

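IV. Training and Inference

1. Data augmentation

Several data augmentation strategies are used to generate training samples and build a model robust to object variation: random expansion, random cropping, random photometric distortion, and flipping.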

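2. Backbone network

VGG-16 and ResNet-101 pretrained on the ILSVRC CLS-LOC dataset serve as the backbone networks in RefineDet. RefineDet also works with other pretrained networks such as Inception v2, Inception ResNet, and ResNeXt101. As in DeepLab-LargeFOV, the fc6 and fc7 layers of VGG-16 are converted into convolutional layers conv_fc6 and conv_fc7 by subsampling their parameters. Because conv4_3 and conv5_3 have different feature scales from the other layers, L2 normalization is applied to scale their feature norms to 10 and 8, and the scales are then learned during backpropagation. To capture high-level information at multiple scales and guide object detection, extra convolutional layers (conv6_1 and conv6_2) and an extra residual block (res6) are added to the end of the truncated VGG-16 and ResNet101, respectively.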
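The L2 normalization step is easy to picture on a single spatial location. The sketch below is an assumption-laden illustration (plain Python lists instead of tensors, a hypothetical helper name): it normalizes the channel vector to unit L2 norm and multiplies it by a scale, which would start at 10 for conv4_3 and 8 for conv5_3 and then be learned.

```python
import math


def l2_normalize_and_scale(channel_values, scale, eps=1e-12):
    """L2-normalize one spatial location across channels, then rescale it."""
    norm = math.sqrt(sum(v * v for v in channel_values))
    return [scale * v / max(norm, eps) for v in channel_values]


out = l2_normalize_and_scale([3.0, 4.0], scale=10.0)
print(out)  # [6.0, 8.0]
```

After the operation the feature vector has norm equal to the scale, so layers with very different activation magnitudes feed comparable inputs to the prediction heads.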

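Detection-layer parameter table:

RefineDet takes conv4_3 as its first detection layer, with stride 8: moving one position on the feature map corresponds to moving 8 pixels in the original image, a setting poorly suited to detecting smaller objects. The face-detection variant quoted here instead uses conv3_3 as the first detection layer with a feature-map stride of 4, which is better for small faces. From conv3_3 to conv7_2 the anchor aspect ratio is 1:1, and the detection-layer parameters are listed in Table 1. Placing face detection boxes of different sizes on six convolutional feature maps effectively improves accuracy for multi-scale faces.

[Table 1: detection-layer parameter settings]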

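3. Anchor design and matching

To handle objects of different scales, four feature layers with total strides of 8, 16, 32, and 64 pixels are selected on VGG-16 and ResNet101 and associated with anchors of several scales for prediction. Each feature layer is associated with one anchor scale (4 times the total stride of that layer) and three aspect ratios (0.5, 1.0, and 2.0). This anchor scale design across layers ensures that anchors of different scales have the same tiling density on the image. During training, the correspondence between anchors and ground truth boxes is determined by their jaccard overlap (IoU), and the whole network is trained end to end. Specifically, each ground truth box is first matched to the anchor box with the best overlap score, and then each anchor box is matched to any ground truth box with which its overlap exceeds 0.5.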
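The anchor layout and the matching rule can be sketched together. This is a simplified illustration under stated assumptions: boxes are (cy, cx, h, w), anchors sit at cell centers, a ratio r gives height scale*sqrt(r) and width scale/sqrt(r) as in the reference code at the end of this post, and the helper names are made up for the example.

```python
def make_anchors(fmap_size, stride, ratios=(0.5, 1.0, 2.0)):
    """Tile (cy, cx, h, w) anchors over a square feature map."""
    scale = 4 * stride  # anchor scale is 4x the layer's total stride
    anchors = []
    for row in range(fmap_size):
        for col in range(fmap_size):
            cy, cx = (row + 0.5) * stride, (col + 0.5) * stride
            for r in ratios:
                anchors.append((cy, cx, scale * r ** 0.5, scale / r ** 0.5))
    return anchors


def iou(a, b):
    """Jaccard overlap of two (cy, cx, h, w) boxes."""
    ih = min(a[0] + a[2] / 2, b[0] + b[2] / 2) - max(a[0] - a[2] / 2, b[0] - b[2] / 2)
    iw = min(a[1] + a[3] / 2, b[1] + b[3] / 2) - max(a[1] - a[3] / 2, b[1] - b[3] / 2)
    inter = max(ih, 0.0) * max(iw, 0.0)
    return inter / (a[2] * a[3] + b[2] * b[3] - inter)


def match(anchors, gt_boxes, threshold=0.5):
    """Return {anchor_index: gt_index} for the positive anchors."""
    assigned = {}
    # An anchor is matched to the ground truth box it overlaps by more than the threshold.
    for ai, anchor in enumerate(anchors):
        overlaps = [iou(anchor, g) for g in gt_boxes]
        best_gt = max(range(len(gt_boxes)), key=overlaps.__getitem__)
        if overlaps[best_gt] > threshold:
            assigned[ai] = best_gt
    # Every ground truth box keeps its best anchor (applied last so it takes priority).
    for gi, g in enumerate(gt_boxes):
        best_anchor = max(range(len(anchors)), key=lambda ai: iou(anchors[ai], g))
        assigned[best_anchor] = gi
    return assigned


anchors = make_anchors(fmap_size=5, stride=64)  # deepest layer of a 320x320 input
matches = match(anchors, [(160.0, 160.0, 256.0, 256.0)])
print(len(anchors), sorted(matches))
```

The second loop guarantees that even a small ground truth box, whose overlap with every anchor stays below 0.5, still receives one positive anchor.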

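4. Hard negative mining

Positive and negative samples are defined much as in other object detection algorithms: a box whose IoU with a ground truth box exceeds the 0.5 threshold is a positive sample, that is, its label is 1. Most boxes therefore end up with the background label and become negative samples. The ARM filters out some of these negatives, but hard negative mining as in SSD is still used to fix the ratio of positives to negatives (typically 1:3). The negatives are not chosen at random: boxes are ranked by classification loss, and the negatives with the highest loss are selected up to the specified ratio.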
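A minimal sketch of the selection rule described above, assuming per-anchor classification losses and a boolean positive mask are already available (the function name and the toy numbers are illustrative):

```python
def hard_negative_mining(losses, is_positive, neg_pos_ratio=3):
    """Return (positive indices, selected hard-negative indices)."""
    positives = [i for i, pos in enumerate(is_positive) if pos]
    negatives = [i for i, pos in enumerate(is_positive) if not pos]
    num_keep = min(len(negatives), neg_pos_ratio * len(positives))
    # Rank negatives by classification loss and keep only the hardest ones.
    hardest = sorted(negatives, key=lambda i: losses[i], reverse=True)[:num_keep]
    return positives, hardest


losses = [0.1, 2.0, 0.3, 1.5, 0.05, 0.9, 0.7]
is_positive = [True, False, False, False, False, False, False]
pos, neg = hard_negative_mining(losses, is_positive)
print(pos, neg)  # [0] [1, 3, 5]
```

Only the selected negatives and the positives contribute to the classification loss; the remaining easy negatives are ignored.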

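5. Loss function

The loss has two parts, one for the ARM and one for the ODM. The ARM part consists of a binary classification loss Lb and a regression loss Lr; likewise, the ODM part consists of a multi-class classification loss Lm and a regression loss Lr. Although this method is roughly a combination of an RPN and SSD, note one difference: in Faster R-CNN the RPN and the detection network can be trained either separately or end to end, whereas here training is purely end to end, with the ARM and ODM losses backpropagated together.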

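The loss function is:

$$
\mathcal{L}(\{p_i\},\{x_i\},\{c_i\},\{t_i\}) = \frac{1}{N_{arm}}\Big(\sum_i L_b(p_i,[l_i^* \ge 1]) + \sum_i [l_i^* \ge 1]\,L_r(x_i,g_i^*)\Big) + \frac{1}{N_{odm}}\Big(\sum_i L_m(c_i,l_i^*) + \sum_i [l_i^* \ge 1]\,L_r(t_i,g_i^*)\Big)
$$

Here Narm and Nodm are the numbers of positive anchors in the ARM and the ODM, pi is the predicted confidence that anchor i is an object, xi is the refined coordinates of anchor i predicted by the ARM, ci is the object class of the bbox predicted by the ODM, ti is the bbox coordinates predicted by the ODM, li* is the ground truth class label of anchor i, and gi* is the ground truth location and size of anchor i. The indicator [li* >= 1] equals 1 when anchor i is a positive and 0 otherwise.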
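To make the structure of the loss concrete, here is a small pure-Python sketch. It is an illustration under assumptions, not the training code: label 0 is taken to be background (so positives have label >= 1), Lb and Lm are softmax cross-entropy, Lr is smooth L1, and a stage with no positive anchors contributes zero. The function names are hypothetical.

```python
import math


def cross_entropy(logits, label):
    """Softmax cross-entropy for one anchor."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[label]


def smooth_l1(pred, target):
    """Smooth L1 distance summed over the four box coordinates."""
    total = 0.0
    for p, t in zip(pred, target):
        d = abs(p - t)
        total += 0.5 * d * d if d < 1.0 else d - 0.5
    return total


def stage_loss(conf_logits, loc_preds, cls_labels, gt_labels, gt_boxes):
    """One stage of the loss: (sum of cls terms + sum of positive reg terms) / N_pos."""
    num_pos = sum(1 for l in gt_labels if l >= 1)
    if num_pos == 0:
        return 0.0
    cls = sum(cross_entropy(z, c) for z, c in zip(conf_logits, cls_labels))
    reg = sum(smooth_l1(x, g) for x, g, l in zip(loc_preds, gt_boxes, gt_labels) if l >= 1)
    return (cls + reg) / num_pos


def refinedet_loss(p, x, c, t, labels, g):
    arm = stage_loss(p, x, [int(l >= 1) for l in labels], labels, g)  # binary labels
    odm = stage_loss(c, t, labels, labels, g)                         # multi-class labels
    return arm + odm


labels = [2, 0]                       # anchor 0 is class 2, anchor 1 is background
g = [[0.5, 0.0, 0.0, 0.0], [0.0] * 4]
p = [[0.0, 0.0], [0.0, 0.0]]          # ARM logits
x = [[0.0] * 4, [0.0] * 4]            # ARM box predictions
c = [[0.0] * 3, [0.0] * 3]            # ODM logits
loss = refinedet_loss(p, x, c, g, labels, g)
print(loss)
```

In this toy case the ARM term is 2 ln 2 + 0.125 and the ODM term is 2 ln 3, because the ODM box predictions are set equal to the targets.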

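6. Optimizer

The parameters of the two extra convolutional layers (conv6_1 and conv6_2) of the VGG-16 based RefineDet are randomly initialized with the "xavier" method. For the ResNet-101 based RefineDet, the parameters of the extra residual block (res6) are drawn from a zero-mean Gaussian distribution with standard deviation 0.01.

default batch size: 32
momentum: 0.9 (speeds up convergence)
weight decay: 0.0005 (guards against overfitting)
initial learning rate: 0.001
learning rate decay: the schedule differs by dataset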
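As a rough illustration of how these hyperparameters interact, the sketch below applies one SGD-with-momentum step with L2 weight decay to a scalar weight and uses a piecewise-constant learning rate. The decay boundaries (80000 and 100000 iterations) and the factor 0.1 are placeholders chosen for the example, since the schedule differs by dataset; the function names are also hypothetical.

```python
TRAIN_CONFIG = {"batch_size": 32, "momentum": 0.9, "weight_decay": 0.0005, "initial_lr": 0.001}


def learning_rate(iteration, boundaries=(80000, 100000), base_lr=0.001, gamma=0.1):
    """Piecewise-constant decay; the boundaries here are hypothetical."""
    drops = sum(1 for b in boundaries if iteration >= b)
    return base_lr * gamma ** drops


def sgd_momentum_step(weight, grad, velocity, lr, momentum=0.9, weight_decay=0.0005):
    """One SGD step with momentum and L2 weight decay on a scalar weight."""
    grad = grad + weight_decay * weight      # weight decay adds a pull toward zero
    velocity = momentum * velocity - lr * grad
    return weight + velocity, velocity


w, v = sgd_momentum_step(weight=1.0, grad=0.5, velocity=0.0, lr=learning_rate(0))
print(w, v)
```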

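7. Inference

At inference time, the ARM first filters out anchors whose negative confidence score exceeds the threshold θ and refines the locations and sizes of the remaining anchors. The ODM then takes these refined anchors and outputs the top 400 high-confidence detections per image. Finally, NMS with a jaccard overlap threshold of 0.45 is applied and the top 200 high-confidence detections per image are kept as the final result.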
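The inference steps can be sketched as a small post-processing function. This is a simplified, class-agnostic illustration in pure Python: boxes are (y1, x1, y2, x2), `theta=0.99` mirrors the constant used in the reference code at the end of this post, and the function names are made up for the example.

```python
def iou_corners(a, b):
    """Jaccard overlap of two (y1, x1, y2, x2) boxes."""
    ih = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iw = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ih * iw
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def postprocess(boxes, scores, arm_neg_conf, theta=0.99,
                pre_nms_top_k=400, nms_iou=0.45, keep_top_k=200):
    # 1. Negative anchor filtering: drop anchors the ARM is confident are background.
    idx = [i for i in range(len(boxes)) if arm_neg_conf[i] <= theta]
    # 2. Keep the highest-scoring ODM detections before NMS.
    idx = sorted(idx, key=lambda i: scores[i], reverse=True)[:pre_nms_top_k]
    # 3. Greedy NMS: keep a box only if it does not overlap a kept box too much.
    kept = []
    for i in idx:
        if all(iou_corners(boxes[i], boxes[j]) <= nms_iou for j in kept):
            kept.append(i)
    # 4. Cap the number of final detections.
    return kept[:keep_top_k]


boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30), (0, 0, 10, 10)]
scores = [0.9, 0.8, 0.7, 0.95]
arm_neg_conf = [0.1, 0.2, 0.3, 0.995]
print(postprocess(boxes, scores, arm_neg_conf))  # [0, 2]
```

Box 3 has the highest ODM score but is removed by the ARM filter, and box 1 is suppressed because its overlap with box 0 is about 0.68.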

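8. Experimental results:

[Figure: experimental results]

Speed is close to SSD while accuracy is clearly higher. The accuracy gain needs little explanation; as for why speed does not drop noticeably despite the extra convolutional and deconvolutional layers, the authors give two reasons: fewer anchors, and fewer added layers after the base network together with fewer detection layers (4, versus 6 in SSD300 and 7 in SSD512).

1) Fewer anchors are used: at the 512 scale there are about 16K boxes in total, versus about 25K in SSD. High accuracy with fewer anchors is possible because of the two-step regression. Only 4 scales (32, 64, 128, 256) and 3 ratios (0.5, 1, 2) are preset, but after the first regression step the preset anchors become far more varied, so the anchors used for the second regression step cover a rich range of scales and ratios.

2) Because of GPU memory limits, only a few convolutional layers were added on top of the base network, and only 4 layers were chosen as detection layers. Adding more convolutional layers and selecting more detection layers should improve results further.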
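The anchor counts quoted above can be reproduced with simple arithmetic. The RefineDet figure follows from four layers at strides 8 to 64 with 3 anchors per cell. The SSD512 layout used below (seven feature maps with 4 or 6 default boxes per cell) is an assumption based on the commonly used SSD512 configuration and is not stated in this post.

```python
def refinedet_anchor_count(input_size, strides=(8, 16, 32, 64), anchors_per_cell=3):
    """Total anchors for a square input: one scale and three ratios per cell."""
    return sum((input_size // s) ** 2 * anchors_per_cell for s in strides)


def ssd_anchor_count(fmap_sizes, boxes_per_cell):
    """Total default boxes for a list of square feature maps."""
    return sum(f * f * k for f, k in zip(fmap_sizes, boxes_per_cell))


refinedet_512 = refinedet_anchor_count(512)
# Assumed SSD512 layout: 64, 32, 16, 8, 4, 2, 1 with 4/6/6/6/6/4/4 boxes per cell.
ssd_512 = ssd_anchor_count((64, 32, 16, 8, 4, 2, 1), (4, 6, 6, 6, 6, 4, 4))
print(refinedet_512, ssd_512)  # 16320 24564
```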

V. Improvements:

5.1 SRN

Shifeng Zhang, Xiangyu Zhu, Zhen Lei, Hailin Shi, Xiaobo Wang, Stan Z. Li, S3FD: Single Shot Scale-invariant Face Detector, ICCV, 2017
Shifeng Zhang, Longyin Wen, Hailin Shi, Zhen Lei, Siwei Lyu, Stan Z. Li, Single-Shot Scale-Aware Network for Real-Time Face Detection, IJCV

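Network overview

The network is laid out as follows; note the relationship among P5, P6, and P7. C2->C5 is the backbone, P5->P2 is the reverse (top-down) backbone, and C6, C7, P6, and P7 are all extra 3*3 convolutional layers added after the backbone.
[Figure: SRN network structure]
According to the authors, the classification and regression operations of the second stage are decoupled:
    a. Conduct the two-step classification only on the lower pyramid levels (P2, P3, P4)
    b. Perform the two-step regression only on the higher pyramid levels (P5, P6, P7)
The reason is as follows: when anchors are tiled, the shallow levels account for far more anchors than the deep levels because of their higher spatial resolution, so most negative samples are concentrated in the shallow levels and pre-classifying them is necessary. The deep levels already have large receptive fields, which makes classification comparatively easy, so classifying twice is unnecessary.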

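5.2 AlignDet

As mentioned earlier, one-stage detectors have four disadvantages relative to two-stage detectors. RefineDet addresses the first three; the last one, feature alignment, was left open and was later filled in by the authors, as shown in the figure below:
[Figure: AlignDet]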

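VI. Outlook

Faster speed
Higher accuracy
    a. Small-object detection: the main difficulty in face detection is detecting small objects
    b. Occlusion: the main problem in pedestrian detection is handling occlusion
Multi-task learning
    For example, detection + segmentation (end goal: instance segmentation, panoptic segmentation)
Video object detection
    Exploit the temporal continuity of video: better accuracy
    Exploit the temporal redundancy of video: faster speed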

Reference code:

   
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow as wrap
import sys
import os
import numpy as np


class RefineDet320:
    def __init__(self, config, data_provider):
        assert config['mode'] in ['train', 'test']
        assert config['data_format'] in ['channels_first', 'channels_last']
        self.config = config
        self.data_provider = data_provider
        self.input_size = config['input_size']
        if config['data_format'] == 'channels_last':
            self.data_shape = [self.input_size, self.input_size, 3]
        else:
            self.data_shape = [3, self.input_size, self.input_size]
        self.num_classes = config['num_classes'] + 1
        self.weight_decay = config['weight_decay']
        self.prob = 1. - config['keep_prob']
        self.data_format = config['data_format']
        self.mode = config['mode']
        self.batch_size = config['batch_size'] if config['mode'] == 'train' else 1
        self.anchor_ratios = [0.5, 1.0, 2.0]
        self.num_anchors = len(self.anchor_ratios)
        self.nms_score_threshold = config['nms_score_threshold']
        self.nms_max_boxes = config['nms_max_boxes']
        self.nms_iou_threshold = config['nms_iou_threshold']
        self.reader = wrap.NewCheckpointReader(config['pretraining_weight'])

        if self.mode == 'train':
            self.num_train = data_provider['num_train']
            self.num_val = data_provider['num_val']
            self.train_generator = data_provider['train_generator']
            self.train_initializer, self.train_iterator = self.train_generator
            if data_provider['val_generator'] is not None:
                self.val_generator = data_provider['val_generator']
                self.val_initializer, self.val_iterator = self.val_generator

        self.global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False)
        self.is_training = True

        self._define_inputs()
        self._build_graph()
        self._create_saver()
        if self.mode == 'train':
            self._create_summary()
        self._init_session()

    def _define_inputs(self):
        shape = [self.batch_size]
        shape.extend(self.data_shape)
        mean = tf.convert_to_tensor([123.68, 116.779, 103.979], dtype=tf.float32)
        if self.data_format == 'channels_last':
            mean = tf.reshape(mean, [1, 1, 1, 3])
        else:
            mean = tf.reshape(mean, [1, 3, 1, 1])
        if self.mode == 'train':
            self.images, self.ground_truth = self.train_iterator.get_next()
            self.images.set_shape(shape)
            self.images = self.images - mean
        else:
            self.images = tf.placeholder(tf.float32, shape, name='images')
            self.images = self.images - mean
            self.ground_truth = tf.placeholder(tf.float32, [self.batch_size, None, 5], name='labels')
        self.lr = tf.placeholder(dtype=tf.float32, shape=[], name='lr')

    def _build_graph(self):
        with tf.variable_scope('feature_extractor'):
            feat1, feat2, feat3, feat4, stride1, stride2, stride3, stride4 = self._feature_extractor(self.images)
            feat1 = tf.nn.l2_normalize(feat1, axis=3 if self.data_format == 'channels_last' else 1)
            feat1_norm_factor = tf.get_variable('feat1_l2_norm', initializer=tf.constant(10.))
            feat1 = feat1_norm_factor * feat1
            feat2 = tf.nn.l2_normalize(feat2, axis=3 if self.data_format == 'channels_last' else 1)
            feat2_norm_factor = tf.get_variable('feat2_l2_norm', initializer=tf.constant(8.))
            feat2 = feat2_norm_factor * feat2
        with tf.variable_scope('ARM'):
            arm1loc, arm1conf = self._arm(feat1, 'arm1')
            arm2loc, arm2conf = self._arm(feat2, 'arm2')
            arm3loc, arm3conf = self._arm(feat3, 'arm3')
            arm4loc, arm4conf = self._arm(feat4, 'arm4')
        with tf.variable_scope('TCB'):
            tcb4 = self._tcb(feat4, 'tcb4')
            tcb3 = self._tcb(feat3, 'tcb3', tcb4)
            tcb2 = self._tcb(feat2, 'tcb2', tcb3)
            tcb1 = self._tcb(feat1, 'tcb1', tcb2)
        with tf.variable_scope('ODM'):
            odm1loc, odm1conf = self._odm(tcb1, 'odm1')
            odm2loc, odm2conf = self._odm(tcb2, 'odm2')
            odm3loc, odm3conf = self._odm(tcb3, 'odm3')
            odm4loc, odm4conf = self._odm(tcb4, 'odm4')
        with tf.variable_scope('inference'):
            if self.data_format == 'channels_first':
                arm1loc = tf.transpose(arm1loc, [0, 2, 3, 1])
                arm1conf = tf.transpose(arm1conf, [0, 2, 3, 1])
                arm2loc = tf.transpose(arm2loc, [0, 2, 3, 1])
                arm2conf = tf.transpose(arm2conf, [0, 2, 3, 1])
                arm3loc = tf.transpose(arm3loc, [0, 2, 3, 1])
                arm3conf = tf.transpose(arm3conf, [0, 2, 3, 1])
                arm4loc = tf.transpose(arm4loc, [0, 2, 3, 1])
                arm4conf = tf.transpose(arm4conf, [0, 2, 3, 1])
                odm1loc = tf.transpose(odm1loc, [0, 2, 3, 1])
                odm1conf = tf.transpose(odm1conf, [0, 2, 3, 1])
                odm2loc = tf.transpose(odm2loc, [0, 2, 3, 1])
                odm2conf = tf.transpose(odm2conf, [0, 2, 3, 1])
                odm3loc = tf.transpose(odm3loc, [0, 2, 3, 1])
                odm3conf = tf.transpose(odm3conf, [0, 2, 3, 1])
                odm4loc = tf.transpose(odm4loc, [0, 2, 3, 1])
                odm4conf = tf.transpose(odm4conf, [0, 2, 3, 1])
            p1shape = tf.shape(arm1loc)
            p2shape = tf.shape(arm2loc)
            p3shape = tf.shape(arm3loc)
            p4shape = tf.shape(arm4loc)
            arm1pbbox_yx, arm1pbbox_hw, arm1pconf = self._get_armpbbox(arm1loc, arm1conf)
            arm2pbbox_yx, arm2pbbox_hw, arm2pconf = self._get_armpbbox(arm2loc, arm2conf)
            arm3pbbox_yx, arm3pbbox_hw, arm3pconf = self._get_armpbbox(arm3loc, arm3conf)
            arm4pbbox_yx, arm4pbbox_hw, arm4pconf = self._get_armpbbox(arm4loc, arm4conf)

            odm1pbbox_yx, odm1pbbox_hw, odm1pconf = self._get_odmpbbox(odm1loc, odm1conf)
            odm2pbbox_yx, odm2pbbox_hw, odm2pconf = self._get_odmpbbox(odm2loc, odm2conf)
            odm3pbbox_yx, odm3pbbox_hw, odm3pconf = self._get_odmpbbox(odm3loc, odm3conf)
            odm4pbbox_yx, odm4pbbox_hw, odm4pconf = self._get_odmpbbox(odm4loc, odm4conf)

            a1bbox_y1x1, a1bbox_y2x2, a1bbox_yx, a1bbox_hw = self._get_abbox(stride1*4, stride1, p1shape)
            a2bbox_y1x1, a2bbox_y2x2, a2bbox_yx, a2bbox_hw = self._get_abbox(stride2*4, stride2, p2shape)
            a3bbox_y1x1, a3bbox_y2x2, a3bbox_yx, a3bbox_hw = self._get_abbox(stride3*4, stride3, p3shape)
            a4bbox_y1x1, a4bbox_y2x2, a4bbox_yx, a4bbox_hw = self._get_abbox(stride4*4, stride4, p4shape)

            armpbbox_yx = tf.concat([arm1pbbox_yx, arm2pbbox_yx, arm3pbbox_yx, arm4pbbox_yx], axis=1)
            armpbbox_hw = tf.concat([arm1pbbox_hw, arm2pbbox_hw, arm3pbbox_hw, arm4pbbox_hw], axis=1)
            armpconf = tf.concat([arm1pconf, arm2pconf, arm3pconf, arm4pconf], axis=1)
            odmpbbox_yx = tf.concat([odm1pbbox_yx, odm2pbbox_yx, odm3pbbox_yx, odm4pbbox_yx], axis=1)
            odmpbbox_hw = tf.concat([odm1pbbox_hw, odm2pbbox_hw, odm3pbbox_hw, odm4pbbox_hw], axis=1)
            odmpconf = tf.concat([odm1pconf, odm2pconf, odm3pconf, odm4pconf], axis=1)
            abbox_y1x1 = tf.concat([a1bbox_y1x1, a2bbox_y1x1, a3bbox_y1x1, a4bbox_y1x1], axis=0)
            abbox_y2x2 = tf.concat([a1bbox_y2x2, a2bbox_y2x2, a3bbox_y2x2, a4bbox_y2x2], axis=0)
            abbox_yx = tf.concat([a1bbox_yx, a2bbox_yx, a3bbox_yx, a4bbox_yx], axis=0)
            abbox_hw = tf.concat([a1bbox_hw, a2bbox_hw, a3bbox_hw, a4bbox_hw], axis=0)
            if self.mode == 'train':
                i = 0.
                loss = 0.
                cond = lambda loss, i: tf.less(i, tf.cast(self.batch_size, tf.float32))
                body = lambda loss, i: (
                    tf.add(loss, self._compute_one_image_loss(
                        tf.squeeze(tf.gather(armpbbox_yx, tf.cast(i, tf.int32))),
                        tf.squeeze(tf.gather(armpbbox_hw, tf.cast(i, tf.int32))),
                        tf.squeeze(tf.gather(armpconf, tf.cast(i, tf.int32))),
                        tf.squeeze(tf.gather(odmpbbox_yx, tf.cast(i, tf.int32))),
                        tf.squeeze(tf.gather(odmpbbox_hw, tf.cast(i, tf.int32))),
                        tf.squeeze(tf.gather(odmpconf, tf.cast(i, tf.int32))),
                        abbox_y1x1,
                        abbox_y2x2,
                        abbox_yx,
                        abbox_hw,
                        tf.squeeze(tf.gather(self.ground_truth, tf.cast(i, tf.int32))),
                    )),
                    tf.add(i, 1.)
                )
                init_state = (loss, i)
                state = tf.while_loop(cond, body, init_state)
                total_loss, _ = state
                total_loss = total_loss / self.batch_size
                optimizer = tf.train.MomentumOptimizer(learning_rate=self.lr, momentum=.9)
                self.loss = total_loss + self.weight_decay * tf.add_n(
                    [tf.nn.l2_loss(var) for var in tf.trainable_variables()]
                )
                train_op = optimizer.minimize(self.loss, global_step=self.global_step)
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                self.train_op = tf.group([update_ops, train_op])

            else:
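                armconft = tf.nn.softmax(armpconf[0, ...])
                odmconft = tf.nn.softmax(odmpconf[0, ...])
                # Negative anchor filtering: keep anchors whose ARM negative confidence
                # (taken here to be index 1) is below the threshold theta = 0.99.
                armmask = armconft[:, 1] < 0.99
                # Drop anchors whose top ODM class is background (the last class index).
                odmmask = tf.argmax(odmconft, axis=-1) < self.num_classes - 1
                mask = tf.logical_and(armmask, odmmask)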
                armpbbox_yxt = tf.boolean_mask(armpbbox_yx[0, ...], mask)
                armpbbox_hwt = tf.boolean_mask(armpbbox_hw[0, ...], mask)
                odmpbbox_yxt = tf.boolean_mask(odmpbbox_yx[0, ...], mask)
                odmpbbox_hwt = tf.boolean_mask(odmpbbox_hw[0, ...], mask)
                abbox_yxt = tf.boolean_mask(abbox_yx, mask)
                abbox_hwt = tf.boolean_mask(abbox_hw, mask)
                odmconft = tf.boolean_mask(odmconft, mask)
                confidence = odmconft[..., :self.num_classes-1]

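                # Two-step cascaded regression: decode the ARM offsets against the tiled
                # anchors, then decode the ODM offsets against the refined anchors.
                arm_yx = armpbbox_yxt * abbox_hwt + abbox_yxt
                arm_hw = tf.exp(armpbbox_hwt) * abbox_hwt
                odm_yx = odmpbbox_yxt * arm_hw + arm_yx
                odm_hw = tf.exp(odmpbbox_hwt) * arm_hw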

                odm_y1x1 = odm_yx - odm_hw / 2.
                odm_y2x2 = odm_yx + odm_hw / 2.
                odm_y1x1y2x2 = tf.concat([odm_y1x1, odm_y2x2], axis=-1)
                filter_mask = tf.greater_equal(confidence, self.nms_score_threshold)
                scores = []
                class_id = []
                bbox = []
                for i in range(self.num_classes-1):
                    scoresi = tf.boolean_mask(confidence[:, i], filter_mask[:, i])
                    bboxi = tf.boolean_mask(odm_y1x1y2x2, filter_mask[:, i])
                    selected_indices = tf.image.non_max_suppression(
                        bboxi, scoresi, self.nms_max_boxes, self.nms_iou_threshold,
                    )
                    scores.append(tf.gather(scoresi, selected_indices))
                    bbox.append(tf.gather(bboxi, selected_indices))
                    class_id.append(tf.ones_like(tf.gather(scoresi, selected_indices), tf.int32) * i)
                bbox = tf.concat(bbox, axis=0)
                scores = tf.concat(scores, axis=0)
                class_id = tf.concat(class_id, axis=0)

                self.detection_pred = [scores, bbox, class_id]

    def _feature_extractor(self, images):
        conv1_1 = self._load_conv_layer(images,
                                        tf.get_variable(name='kernel_conv1_1',
                                                        initializer=self.reader.get_tensor("vgg_16/conv1/conv1_1/weights"),
                                                        trainable=True),
                                        tf.get_variable(name='bias_conv1_1',
                                                        initializer=self.reader.get_tensor("vgg_16/conv1/conv1_1/biases"),
                                                        trainable=True),
                                        name="conv1_1")
        conv1_2 = self._load_conv_layer(conv1_1,
                                        tf.get_variable(name='kernel_conv1_2',
                                                        initializer=self.reader.get_tensor("vgg_16/conv1/conv1_2/weights"),
                                                        trainable=True),
                                        tf.get_variable(name='bias_conv1_2',
                                                        initializer=self.reader.get_tensor("vgg_16/conv1/conv1_2/biases"),
                                                        trainable=True),
                                        name="conv1_2")
        pool1 = self._max_pooling(conv1_2, 2, 2, name="pool1")

        conv2_1 = self._load_conv_layer(pool1,
                                        tf.get_variable(name='kernel_conv2_1',
                                                        initializer=self.reader.get_tensor("vgg_16/conv2/conv2_1/weights"),
                                                        trainable=True),
                                        tf.get_variable(name='bias_conv2_1',
                                                        initializer=self.reader.get_tensor("vgg_16/conv2/conv2_1/biases"),
                                                        trainable=True),
                                        name="conv2_1")
        conv2_2 = self._load_conv_layer(conv2_1,
                                        tf.get_variable(name='kernel_conv2_2',
                                                        initializer=self.reader.get_tensor("vgg_16/conv2/conv2_2/weights"),
                                                        trainable=True),
                                        tf.get_variable(name='bias_conv2_2',
                                                        initializer=self.reader.get_tensor("vgg_16/conv2/conv2_2/biases"),
                                                        trainable=True),
                                        name="conv2_2")
        pool2 = self._max_pooling(conv2_2, 2, 2, name="pool2")
        conv3_1 = self._load_conv_layer(pool2,
                                        tf.get_variable(name='kernel_conv3_1',
                                                        initializer=self.reader.get_tensor("vgg_16/conv3/conv3_1/weights"),
                                                        trainable=True),
                                        tf.get_variable(name='bias_conv_3_1',
                                                        initializer=self.reader.get_tensor("vgg_16/conv3/conv3_1/biases"),
                                                        trainable=True),
                                        name="conv3_1")
        conv3_2 = self._load_conv_layer(conv3_1,
                                        tf.get_variable(name='kernel_conv3_2',
                                                        initializer=self.reader.get_tensor("vgg_16/conv3/conv3_2/weights"),
                                                        trainable=True),
                                        tf.get_variable(name='bias_conv3_2',
                                                        initializer=self.reader.get_tensor("vgg_16/conv3/conv3_2/biases"),
                                                        trainable=True),
                                        name="conv3_2")
        conv3_3 = self._load_conv_layer(conv3_2,
                                        tf.get_variable(name='kernel_conv3_3',
                                                        initializer=self.reader.get_tensor("vgg_16/conv3/conv3_3/weights"),
                                                        trainable=True),
                                        tf.get_variable(name='bias_conv3_3',
                                                        initializer=self.reader.get_tensor("vgg_16/conv3/conv3_3/biases"),
                                                        trainable=True),
                                        name="conv3_3")
        pool3 = self._max_pooling(conv3_3, 2, 2, name="pool3")

        conv4_1 = self._load_conv_layer(pool3,
                                        tf.get_variable(name='kernel_conv4_1',
                                                        initializer=self.reader.get_tensor("vgg_16/conv4/conv4_1/weights"),
                                                        trainable=True),
                                        tf.get_variable(name='bias_conv4_1',
                                                        initializer=self.reader.get_tensor("vgg_16/conv4/conv4_1/biases"),
                                                        trainable=True),
                                        name="conv4_1")
        conv4_2 = self._load_conv_layer(conv4_1,
                                        tf.get_variable(name='kernel_conv4_2',
                                                        initializer=self.reader.get_tensor("vgg_16/conv4/conv4_2/weights"),
                                                        trainable=True),
                                        tf.get_variable(name='bias_conv4_2',
                                                        initializer=self.reader.get_tensor("vgg_16/conv4/conv4_2/biases"),
                                                        trainable=True),
                                        name="conv4_2")
        conv4_3 = self._load_conv_layer(conv4_2,
                                        tf.get_variable(name='kernel_conv4_3',
                                                        initializer=self.reader.get_tensor("vgg_16/conv4/conv4_3/weights"),
                                                        trainable=True),
                                        tf.get_variable(name='bias_conv4_3',
                                                        initializer=self.reader.get_tensor("vgg_16/conv4/conv4_3/biases"),
                                                        trainable=True),
                                        name="conv4_3")
        pool4 = self._max_pooling(conv4_3, 2, 2, name="pool4")
        conv5_1 = self._load_conv_layer(pool4,
                                        tf.get_variable(name='kernel_conv5_1',
                                                        initializer=self.reader.get_tensor("vgg_16/conv5/conv5_1/weights"),
                                                        trainable=True),
                                        tf.get_variable(name='bias_conv5_1',
                                                        initializer=self.reader.get_tensor("vgg_16/conv5/conv5_1/biases"),
                                                        trainable=True),
                                        name="conv5_1")
        conv5_2 = self._load_conv_layer(conv5_1,
                                        tf.get_variable(name='kernel_conv5_2',
                                                        initializer=self.reader.get_tensor("vgg_16/conv5/conv5_2/weights"),
                                                        trainable=True),
                                        tf.get_variable(name='bias_conv5_2',
                                                        initializer=self.reader.get_tensor("vgg_16/conv5/conv5_2/biases"),
                                                        trainable=True),
                                        name="conv5_2")
        conv5_3 = self._load_conv_layer(conv5_2,
                                        tf.get_variable(name='kernel_conv5_3',
                                                        initializer=self.reader.get_tensor("vgg_16/conv5/conv5_3/weights"),
                                                        trainable=True),
                                        tf.get_variable(name='bias_conv5_3',
                                                        initializer=self.reader.get_tensor("vgg_16/conv5/conv5_3/biases"),
                                                        trainable=True),
                                        name="conv5_3")
        pool5 = self._max_pooling(conv5_3, 3, 1, 'pool5')
        conv6 = self._conv_layer(pool5, 1024, 3, 1, 'conv6', dilation_rate=2, activation=tf.nn.relu)
        conv7 = self._conv_layer(conv6, 1024, 1, 1, 'conv7', activation=tf.nn.relu)
        conv8_1 = self._conv_layer(conv7, 256, 1, 1, 'conv8_1', activation=tf.nn.relu)
        conv8_2 = self._conv_layer(conv8_1, 512, 3, 2, 'conv8_2', activation=tf.nn.relu)
        conv9_1 = self._conv_layer(conv8_2, 256, 1, 1, 'conv9_1', activation=tf.nn.relu)
        conv9_2 = self._conv_layer(conv9_1, 512, 3, 2, 'conv9_2', activation=tf.nn.relu)
        conv10_1 = self._conv_layer(conv9_2, 256, 1, 1, 'conv10_1', activation=tf.nn.relu)
        conv10_2 = self._conv_layer(conv10_1, 256, 3, 1, 'conv10_2', activation=tf.nn.relu)
        stride1 = 8
        stride2 = 16
        stride3 = 32
        stride4 = 64
        return conv4_3, conv5_3, conv8_2, conv10_2, stride1, stride2, stride3, stride4

    def _arm(self, bottom, scope):
        with tf.variable_scope(scope):
            conv1 = self._conv_layer(bottom, 256, 3, 1, activation=tf.nn.relu)
            conv2 = self._conv_layer(conv1, 256, 3, 1, activation=tf.nn.relu)
            conv3 = self._conv_layer(conv2, 256, 3, 1, activation=tf.nn.relu)
            conv4 = self._conv_layer(conv3, 256, 3, 1, activation=tf.nn.relu)
            ploc = self._conv_layer(conv4, 4*self.num_anchors, 3, 1)
            pconf = self._conv_layer(conv4, 2*self.num_anchors, 3, 1)
            return ploc, pconf

    def _tcb(self, bottom, scope, high_level_feat=None):
        with tf.variable_scope(scope):
            conv1 = self._conv_layer(bottom, 256, 3, 1, activation=tf.nn.relu)
            conv2 = self._conv_layer(conv1, 256, 3, 1)
            if high_level_feat is not None:
                dconv = self._dconv_layer(high_level_feat, 256, 4, 2)
                conv2 = tf.nn.relu(conv2 + dconv)
            conv3 = tf.nn.relu(conv2)
            return conv3

    def _odm(self, bottom, scope):
        with tf.variable_scope(scope):
            conv1 = self._conv_layer(bottom, 256, 3, 1, activation=tf.nn.relu)
            conv2 = self._conv_layer(conv1, 256, 3, 1, activation=tf.nn.relu)
            conv3 = self._conv_layer(conv2, 256, 3, 1, activation=tf.nn.relu)
            conv4 = self._conv_layer(conv3, 256, 3, 1, activation=tf.nn.relu)
            ploc = self._conv_layer(conv4, 4*self.num_anchors, 3, 1)
            pconf = self._conv_layer(conv4, self.num_classes*self.num_anchors, 3, 1)
            return ploc, pconf

    def _get_armpbbox(self, ploc, pconf):
        pconf = tf.reshape(pconf, [self.batch_size, -1, 2])
        ploc = tf.reshape(ploc, [self.batch_size, -1, 4])
        pbbox_yx = ploc[..., :2]
        pbbox_hw = ploc[..., 2:]
        return pbbox_yx, pbbox_hw, pconf

    def _get_odmpbbox(self, ploc, pconf):
        pconf = tf.reshape(pconf, [self.batch_size, -1, self.num_classes])
        ploc = tf.reshape(ploc, [self.batch_size, -1, 4])
        pbbox_yx = ploc[..., :2]
        pbbox_hw = ploc[..., 2:]
        return pbbox_yx, pbbox_hw, pconf

    def _get_abbox(self, size, stride, pshape):
        topleft_y = tf.range(0., tf.cast(pshape[1], tf.float32), dtype=tf.float32)
        topleft_x = tf.range(0., tf.cast(pshape[2], tf.float32), dtype=tf.float32)
        topleft_y = tf.reshape(topleft_y, [-1, 1, 1, 1]) + 0.5
        topleft_x = tf.reshape(topleft_x, [1, -1, 1, 1]) + 0.5
        topleft_y = tf.tile(topleft_y, [1, pshape[2], 1, 1]) * stride
        topleft_x = tf.tile(topleft_x, [pshape[1], 1, 1, 1]) * stride
        topleft_yx = tf.concat([topleft_y, topleft_x], -1)
        topleft_yx = tf.tile(topleft_yx, [1, 1, self.num_anchors, 1])

        priors = []
        for ratio in self.anchor_ratios:
            priors.append([size*(ratio**0.5), size/(ratio**0.5)])
        priors = tf.convert_to_tensor(priors, tf.float32)
        priors = tf.reshape(priors, [1, 1, -1, 2])

        abbox_y1x1 = tf.reshape(topleft_yx - priors / 2., [-1, 2])
        abbox_y2x2 = tf.reshape(topleft_yx + priors / 2., [-1, 2])
        abbox_yx = abbox_y1x1 / 2. + abbox_y2x2 / 2.
        abbox_hw = abbox_y2x2 - abbox_y1x1
        return abbox_y1x1, abbox_y2x2, abbox_yx, abbox_hw

    def _compute_one_image_loss(self, armpbbox_yx, armpbbox_hw, armpconf,
                                odmpbbox_yx, odmpbbox_hw, odmpconf,
                                abbox_y1x1, abbox_y2x2,
                                abbox_yx, abbox_hw,  ground_truth):
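        # ground_truth is padded to a fixed number of rows. argmin over the first column
        # returns the index of the first padding row (assuming padding values are smaller
        # than any real coordinate), so only the real boxes before it are kept.
        slice_index = tf.argmin(ground_truth, axis=0)[0]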
        ground_truth = tf.gather(ground_truth, tf.range(0, slice_index, dtype=tf.int64))
        gbbox_yx = ground_truth[..., 0:2]
        gbbox_hw = ground_truth[..., 2:4]
        gbbox_y1x1 = gbbox_yx - gbbox_hw / 2.
        gbbox_y2x2 = gbbox_yx + gbbox_hw / 2.
        class_id = tf.cast(ground_truth[..., 4:5], dtype=tf.int32)
        label = class_id

        abbox_hwti = tf.reshape(abbox_hw, [1, -1, 2])
        abbox_y1x1ti = tf.reshape(abbox_y1x1, [1, -1, 2])
        abbox_y2x2ti = tf.reshape(abbox_y2x2, [1, -1, 2])
        gbbox_hwti = tf.reshape(gbbox_hw, [-1, 1, 2])
        gbbox_y1x1ti = tf.reshape(gbbox_y1x1, [-1, 1, 2])
        gbbox_y2x2ti = tf.reshape(gbbox_y2x2, [-1, 1, 2])
        ashape = tf.shape(abbox_hwti)
        gshape = tf.shape(gbbox_hwti)
        abbox_hwti = tf.tile(abbox_hwti, [gshape[0], 1, 1])
        abbox_y1x1ti = tf.tile(abbox_y1x1ti, [gshape[0], 1, 1])
        abbox_y2x2ti = tf.tile(abbox_y2x2ti, [gshape[0], 1, 1])
        gbbox_hwti = tf.tile(gbbox_hwti, [1, ashape[1], 1])
        gbbox_y1x1ti = tf.tile(gbbox_y1x1ti, [1, ashape[1], 1])
        gbbox_y2x2ti = tf.tile(gbbox_y2x2ti, [1, ashape[1], 1])

        # Pairwise IoU between every ground-truth box (rows) and every anchor (columns).
        gaiou_y1x1ti = tf.maximum(abbox_y1x1ti, gbbox_y1x1ti)
        gaiou_y2x2ti = tf.minimum(abbox_y2x2ti, gbbox_y2x2ti)
        gaiou_area = tf.reduce_prod(tf.maximum(gaiou_y2x2ti - gaiou_y1x1ti, 0), axis=-1)
        aarea = tf.reduce_prod(abbox_hwti, axis=-1)
        garea = tf.reduce_prod(gbbox_hwti, axis=-1)
        gaiou_rate = gaiou_area / (aarea + garea - gaiou_area)

        # The anchor with the highest IoU for each ground-truth box is always a positive.
        best_raindex = tf.argmax(gaiou_rate, axis=1)
        best_armpbbox_yx = tf.gather(armpbbox_yx, best_raindex)
        best_armpbbox_hw = tf.gather(armpbbox_hw, best_raindex)
        best_armpconf = tf.gather(armpconf, best_raindex)
        best_odmpbbox_yx = tf.gather(odmpbbox_yx, best_raindex)
        best_odmpbbox_hw = tf.gather(odmpbbox_hw, best_raindex)
        best_odmpconf = tf.gather(odmpconf, best_raindex)
        best_abbox_yx = tf.gather(abbox_yx, best_raindex)
        best_abbox_hw = tf.gather(abbox_hw, best_raindex)

        bestmask, _ = tf.unique(best_raindex)
        bestmask = tf.contrib.framework.sort(bestmask)
        bestmask = tf.reshape(bestmask, [-1, 1])
        bestmask = tf.sparse.SparseTensor(tf.concat([bestmask, tf.zeros_like(bestmask)], axis=-1),
                                          tf.squeeze(tf.ones_like(bestmask)), dense_shape=[ashape[1], 1])
        bestmask = tf.reshape(tf.cast(tf.sparse.to_dense(bestmask), tf.float32), [-1])

        othermask = 1. - bestmask
        othermask = othermask > 0.
        other_armpbbox_yx = tf.boolean_mask(armpbbox_yx, othermask)
        other_armpbbox_hw = tf.boolean_mask(armpbbox_hw, othermask)
        other_armpconf = tf.boolean_mask(armpconf, othermask)
        other_odmpbbox_yx = tf.boolean_mask(odmpbbox_yx, othermask)
        other_odmpbbox_hw = tf.boolean_mask(odmpbbox_hw, othermask)
        other_odmpconf = tf.boolean_mask(odmpconf, othermask)
        other_abbox_yx = tf.boolean_mask(abbox_yx, othermask)
        other_abbox_hw = tf.boolean_mask(abbox_hw, othermask)

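        # Among the remaining anchors: IoU > 0.5 with some ground truth is positive,
        # IoU < 0.4 is negative, and anything in between is ignored.
        agiou_rate = tf.transpose(gaiou_rate)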
        other_agiou_rate = tf.boolean_mask(agiou_rate, othermask)
        max_agiou_rate = tf.reduce_max(other_agiou_rate, axis=1)
        pos_agiou_mask = max_agiou_rate > 0.5
        neg_agiou_mask = max_agiou_rate < 0.4
        rgindex = tf.argmax(other_agiou_rate, axis=1)
        pos_rgindex = tf.boolean_mask(rgindex, pos_agiou_mask)
        pos_armppox_yx = tf.boolean_mask(other_armpbbox_yx, pos_agiou_mask)
        pos_armppox_hw = tf.boolean_mask(other_armpbbox_hw, pos_agiou_mask)
        pos_armpconf = tf.boolean_mask(other_armpconf, pos_agiou_mask)
        pos_odmppox_yx = tf.boolean_mask(other_odmpbbox_yx, pos_agiou_mask)
        pos_odmppox_hw = tf.boolean_mask(other_odmpbbox_hw, pos_agiou_mask)
        pos_odmpconf = tf.boolean_mask(other_odmpconf, pos_agiou_mask)
        pos_abbox_yx = tf.boolean_mask(other_abbox_yx, pos_agiou_mask)
        pos_abbox_hw = tf.boolean_mask(other_abbox_hw, pos_agiou_mask)
        pos_odmlabel = tf.gather(label, pos_rgindex)
        pos_gbbox_yx = tf.gather(gbbox_yx, pos_rgindex)
        pos_gbbox_hw = tf.gather(gbbox_hw, pos_rgindex)
        neg_armpconf = tf.boolean_mask(other_armpconf, neg_agiou_mask)
        neg_armabbox_yx = tf.boolean_mask(other_abbox_yx, neg_agiou_mask)
        neg_armabbox_hw = tf.boolean_mask(other_abbox_hw, neg_agiou_mask)
        neg_armabbox_y1x1y2x2 = tf.concat([neg_armabbox_yx - neg_armabbox_hw/2., neg_armabbox_yx + neg_armabbox_hw/2.], axis=-1)
        neg_odmpconf = tf.boolean_mask(other_odmpconf, neg_agiou_mask)

        total_pos_armpbbox_yx = tf.concat([best_armpbbox_yx, pos_armppox_yx], axis=0)
        total_pos_armpbbox_hw = tf.concat([best_armpbbox_hw, pos_armppox_hw], axis=0)
        total_pos_armpconf = tf.concat([best_armpconf, pos_armpconf], axis=0)
        total_pos_odmpbbox_yx = tf.concat([best_odmpbbox_yx, pos_odmppox_yx], axis=0)
        total_pos_odmpbbox_hw = tf.concat([best_odmpbbox_hw, pos_odmppox_hw], axis=0)
        total_pos_odmpconf = tf.concat([best_odmpconf, pos_odmpconf], axis=0)
        total_pos_odmlabel = tf.concat([label, pos_odmlabel], axis=0)
        total_pos_gbbox_yx = tf.concat([gbbox_yx, pos_gbbox_yx], axis=0)
        total_pos_gbbox_hw = tf.concat([gbbox_hw, pos_gbbox_hw], axis=0)
        total_pos_abbox_yx = tf.concat([best_abbox_yx, pos_abbox_yx], axis=0)
        total_pos_abbox_hw = tf.concat([best_abbox_hw, pos_abbox_hw], axis=0)

        num_pos = tf.shape(total_pos_odmlabel)[0]
        num_armneg = tf.shape(neg_armpconf)[0]
        # Hard negative mining for the ARM: keep at most 3 negatives per positive.
        chosen_num_armneg = tf.minimum(num_armneg, 3*num_pos)
        # ARM labels: 0 = object, 1 = background.
        neg_armclass_id = tf.constant([1])
        pos_armclass_id = tf.constant([0])
        neg_armlabel = tf.tile(neg_armclass_id, [num_armneg])
        pos_armlabel = tf.tile(pos_armclass_id, [num_pos])
        total_neg_armloss = tf.losses.sparse_softmax_cross_entropy(neg_armlabel, neg_armpconf, reduction=tf.losses.Reduction.NONE)
        # NMS scored by loss keeps the highest-loss negatives while dropping near-duplicate anchors.
        selected_armindices = tf.image.non_max_suppression(
            neg_armabbox_y1x1y2x2, total_neg_armloss, chosen_num_armneg, iou_threshold=0.7
        )
        neg_armloss = tf.reduce_mean(tf.gather(total_neg_armloss, selected_armindices))

        chosen_neg_armpconf = tf.gather(neg_armpconf, selected_armindices)
        chosen_neg_odmpconf = tf.gather(neg_odmpconf, selected_armindices)

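        # Negative anchor filtering: drop negatives the ARM already rejects with background
        # probability >= 0.99. armpconf holds logits, so apply softmax before thresholding.
        neg_odm_mask = tf.nn.softmax(chosen_neg_armpconf)[:, 1] < 0.99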
        chosen_neg_odmpconf = tf.boolean_mask(chosen_neg_odmpconf, neg_odm_mask)
        chosen_num_odmneg = tf.shape(chosen_neg_odmpconf)[0]
        # The ODM treats the last class index as background.
        neg_odmclass_id = tf.constant([self.num_classes-1])
        neg_odmlabel = tf.tile(neg_odmclass_id, [chosen_num_odmneg])
        neg_odmloss = tf.losses.sparse_softmax_cross_entropy(neg_odmlabel, chosen_neg_odmpconf, reduction=tf.losses.Reduction.MEAN)

        pos_armconf_loss = tf.losses.sparse_softmax_cross_entropy(pos_armlabel, total_pos_armpconf, reduction=tf.losses.Reduction.MEAN)
        # Regression targets: center offsets normalised by anchor size, log-scaled height and width.
        pos_truth_armpbbox_yx = (total_pos_gbbox_yx - total_pos_abbox_yx) / total_pos_abbox_hw
        pos_truth_armpbbox_hw = tf.log(total_pos_gbbox_hw / total_pos_abbox_hw)
        pos_yx_armloss = tf.reduce_sum(self._smooth_l1_loss(total_pos_armpbbox_yx - pos_truth_armpbbox_yx), axis=-1)
        pos_hw_armloss = tf.reduce_sum(self._smooth_l1_loss(total_pos_armpbbox_hw - pos_truth_armpbbox_hw), axis=-1)
        pos_coord_armloss = tf.reduce_mean(pos_yx_armloss + pos_hw_armloss)

        # Decode the ARM predictions into refined anchors, which the ODM regresses against.
        arm_yx = total_pos_armpbbox_yx * total_pos_abbox_hw + total_pos_abbox_yx
        arm_hw = tf.exp(total_pos_armpbbox_hw) * total_pos_abbox_hw

        pos_odmconf_loss = tf.losses.sparse_softmax_cross_entropy(total_pos_odmlabel, total_pos_odmpconf, reduction=tf.losses.Reduction.MEAN)
        pos_truth_odmpbbox_yx = (total_pos_gbbox_yx - arm_yx) / arm_hw
        pos_truth_odmpbbox_hw = tf.log(total_pos_gbbox_hw / arm_hw)
        pos_yx_odmloss = tf.reduce_sum(self._smooth_l1_loss(total_pos_odmpbbox_yx - pos_truth_odmpbbox_yx), axis=-1)
        pos_hw_odmloss = tf.reduce_sum(self._smooth_l1_loss(total_pos_odmpbbox_hw - pos_truth_odmpbbox_hw), axis=-1)
        pos_coord_odmloss = tf.reduce_mean(pos_yx_odmloss + pos_hw_odmloss)

        armloss = neg_armloss + pos_armconf_loss + pos_coord_armloss
        odmloss = neg_odmloss + pos_odmconf_loss + pos_coord_odmloss
        return armloss + odmloss

    def _smooth_l1_loss(self, x):
        # Quadratic for |x| < 1, linear beyond, so large residuals do not dominate the gradient.
        return tf.where(tf.abs(x) < 1., 0.5*x*x, tf.abs(x)-0.5)

    def _init_session(self):
        self.sess = tf.InteractiveSession()
        self.sess.run(tf.global_variables_initializer())
        if self.mode == 'train':
            self.sess.run(self.train_initializer)

    def _create_saver(self):
        self.saver = tf.train.Saver()
        self.best_saver = tf.train.Saver()

    def _create_summary(self):
        with tf.variable_scope('summaries'):
            tf.summary.scalar('loss', self.loss)
            self.summary_op = tf.summary.merge_all()

    def train_one_epoch(self, lr):
        self.is_training = True
        self.sess.run(self.train_initializer)
        mean_loss = []
        num_iters = self.num_train // self.batch_size
        for i in range(num_iters):
            _, loss = self.sess.run([self.train_op, self.loss], feed_dict={self.lr: lr})
            sys.stdout.write('\r>> ' + 'iters '+str(i+1)+str('/')+str(num_iters)+' loss '+str(loss))
            sys.stdout.flush()
            mean_loss.append(loss)
        sys.stdout.write('\n')
        mean_loss = np.mean(mean_loss)
        return mean_loss

    def test_one_image(self, images):
        self.is_training = False
        pred = self.sess.run(self.detection_pred, feed_dict={self.images: images})
        return pred

    def save_weight(self, mode, path):
        assert(mode in ['latest', 'best'])
        if mode == 'latest':
            saver = self.saver
        else:
            saver = self.best_saver
        if not tf.gfile.Exists(os.path.dirname(path)):
            tf.gfile.MakeDirs(os.path.dirname(path))
            print(os.path.dirname(path), 'did not exist; created it')
        saver.save(self.sess, path, global_step=self.global_step)
        print('save', mode, 'model in', path, 'successfully')

    def load_weight(self, path):
        self.saver.restore(self.sess, path)
        print('load weight', path, 'successfully')

    def _bn(self, bottom):
        bn = tf.layers.batch_normalization(
            inputs=bottom,
            axis=3 if self.data_format == 'channels_last' else 1,
            training=self.is_training
        )
        return bn

    def _load_conv_layer(self, bottom, filters, bias, name):
        if self.data_format == 'channels_last':
            data_format = 'NHWC'
        else:
            data_format = 'NCHW'
        conv = tf.nn.conv2d(bottom, filter=filters, strides=[1, 1, 1, 1], name="kernel"+name, padding="SAME", data_format=data_format)
        conv_bias = tf.nn.bias_add(conv, bias=bias, name="bias"+name)
        return tf.nn.relu(conv_bias)

    def _conv_layer(self, bottom, filters, kernel_size, strides, name=None, dilation_rate=1, activation=None):
        conv = tf.layers.conv2d(
            inputs=bottom,
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding='same',
            name=name,
            data_format=self.data_format,
            dilation_rate=dilation_rate,
        )
        bn = self._bn(conv)
        if activation is not None:
            bn = activation(bn)
        return bn

    def _dconv_layer(self, bottom, filters, kernel_size, strides, name=None, activation=None):
        conv = tf.layers.conv2d_transpose(
            inputs=bottom,
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding='same',
            name=name,
            data_format=self.data_format,
        )
        bn = self._bn(conv)
        if activation is not None:
            bn = activation(bn)
        return bn

    def _max_pooling(self, bottom, pool_size, strides, name):
        return tf.layers.max_pooling2d(
            inputs=bottom,
            pool_size=pool_size,
            strides=strides,
            padding='same',
            data_format=self.data_format,
            name=name
        )

    def _avg_pooling(self, bottom, pool_size, strides, name):
        return tf.layers.average_pooling2d(
            inputs=bottom,
            pool_size=pool_size,
            strides=strides,
            padding='same',
            data_format=self.data_format,
            name=name
        )

    def _dropout(self, bottom, name):
        return tf.layers.dropout(
            inputs=bottom,
            rate=self.prob,
            training=self.is_training,
            name=name
        )

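
The regression targets and the smooth L1 penalty in `_compute_one_image_loss` follow the usual center-offset and log-size encoding. The sketch below restates them for a single box in plain Python so the arithmetic can be checked by hand. The names `smooth_l1`, `encode`, and `decode` are hypothetical and not part of the class. In the listing above, the ODM targets are encoded against the decoded ARM output (`arm_yx`, `arm_hw`) rather than against the original anchors.

```python
import math


def smooth_l1(x):
    """Quadratic for |x| < 1, linear otherwise (same rule as _smooth_l1_loss)."""
    ax = abs(x)
    return 0.5 * x * x if ax < 1.0 else ax - 0.5


def encode(gt_yx, gt_hw, anchor_yx, anchor_hw):
    """Hypothetical helper: turn a ground-truth box into targets relative to an anchor."""
    t_yx = [(g - a) / s for g, a, s in zip(gt_yx, anchor_yx, anchor_hw)]
    t_hw = [math.log(g / s) for g, s in zip(gt_hw, anchor_hw)]
    return t_yx, t_hw


def decode(t_yx, t_hw, anchor_yx, anchor_hw):
    """Hypothetical helper: invert encode(), giving a refined box from predicted offsets."""
    yx = [t * s + a for t, a, s in zip(t_yx, anchor_yx, anchor_hw)]
    hw = [math.exp(t) * s for t, s in zip(t_hw, anchor_hw)]
    return yx, hw


if __name__ == "__main__":
    anchor_yx, anchor_hw = (8.0, 8.0), (4.0, 4.0)
    gt_yx, gt_hw = (10.0, 10.0), (4.0, 4.0)

    # Stage 1 (ARM): targets relative to the original anchor.
    arm_t_yx, arm_t_hw = encode(gt_yx, gt_hw, anchor_yx, anchor_hw)

    # Pretend the ARM predicted half of the required shift.
    arm_pred_yx = [t / 2 for t in arm_t_yx]
    refined_yx, refined_hw = decode(arm_pred_yx, arm_t_hw, anchor_yx, anchor_hw)

    # Stage 2 (ODM): targets relative to the refined anchor, so the remaining shift is smaller.
    odm_t_yx, odm_t_hw = encode(gt_yx, gt_hw, refined_yx, refined_hw)
    print(arm_t_yx, odm_t_yx)
    print(sum(smooth_l1(p - t) for p, t in zip(arm_pred_yx, arm_t_yx)))
```

Because the second stage starts from the refined box, its targets are smaller in magnitude than the first-stage targets, which is the point of the two-step cascade.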