知识引
这次,我主要给大家分享Caffe中如何添加新的网络层。
我们的任务是一个图像分割任务,在Caffe官方的框架之中,并不包含图像分割的任务,所以我们需要添加本任务相关方面的一些代码,具体来说将包含三个方面的内容:
第一、添加一个新的图像分割数据层
添加这个新的数据层之后,我们才能按照图像分割这样一个任务,读取我们需要训练的文件,以及将我们需要训练的文件载入内存中进行训练。
第二、修改序列化文件
为什么要修改这样一个序列化文件呢?是因为图像分割这种任务,他本身比图像分类任务来说要更难一些,而且为了使我们训练出来的模型非常鲁棒,我们一般需要做比较多的数据增强的操作。而在Caffe框架之中,它原生只支持镜像操作跟图像crop这样的两个基本操作。所以我们需要修改这个序列化文件,增加一些新的数据增强操作的接口,也就是增加新的参数。
第三,添加新的datatranformer方法
与上面是类似的。图像分割任务本身,它数据读取的方法以及数据筛选的方法肯定与图像分类任务不同。 所以我们需要在data_tranformer.cpp这个CPP之中添加对应的函数,从而增加我们相应的数据操作。
下面我们开始分别分享。首先第一个,我们添加新的数据层头文件,我们在include/caffe/layers这个目录下添加了image_seg_data_layer.hpp,我们定义了一个新的数据层,也就是我们的图像分割任务数据层,他继承了Base Prefet chingDataLayer,而BasePrefetchingDataLayer本身又继承了BaseDataLayer和InternalThread这两个基本类。关键代码如下:
caffe添加新网络层:一站式caffe工程实践连载(4)
继承了这两个基本类之后,就可以实现数据的读取。我们的图像分割层继承了BasePrefetchingDataLayer,也就继承了它的数据接口,从而我们在CPP文件之中就可以重载相应的函数,实现图像分割这个任务本身的数据读取。
另外在.hpp文件之中还定义了一个lines这样的一个变量,如下所示:
caffe添加新网络层:一站式caffe工程实践连载(4)
变量详解
在图像分类任务中也有类似这个变量,但是与我们的分割任务中的lines变量有所不同。图像分类任务,它载入的是图片和一个标签。图片是一个string类型的变量,而标签是一个0123这样的一个int型的变量,所以它的基本数据格式是(string,int),我们分割任务有所不同,分割任务,它的标签也是一张跟原图大小相同的图像。所以我们载入的图像分割的数据文件格式,它(imagepath,imagelabelpath),也就是说都是string类型。 所以我们要定义一个(string,string)这样一个基本数据格式的一个变量。有的HPP文件之后,我们在CPP文件之中对他进行实现,主要包含两个函数,第一个是DataLayerSetup函数。这个函数它本身实现了对我们的训练文件的读取,以及做一些基本的初始化操作。这样一个函数在其中最重要的一部分代码就是如图所示,可以看到我们通过getline这样的一个函数,循环地读取我们输入的文件,也就是source,source就是我们的训练文件,每一行读取的文件就是两个string变量,分别是图像以及图像标签的变量,然后将其投入lines这个变量之中,随后我们就在load_batcj这个函数对lines变量之中的图片进行读取。load_batch这个函数,真正实现了数据的读取,以及将我们读取的图片数据塞入到我们的显存之中进行训练。所以它包含两个方面的操作。首先我们通过ReadImagetoCVMat这样一个函数,在这个函数的内部是调用了Opencv的接口,它实现了图像数据的读取。读取完数据之后,我们存到了cv mat这样的一个格式中,我们得到了cv_image以及cv_label这样的两个文件。随后我们通过data_tranformer中的TransformImageSeg这样的一个函数来实现,将我们刚刚读取的两张图片塞入到真正的内存之中,也就是prefetch_data和prefetch_label这两个变量。
工厂模式
如果大家真正的读过数据层的CPP就会看到在CBB的最后包含了以下两行代码,也就是INSTANTIATE_CLASS和REGISTER_LAYER_CLASS这样的两行代码。这两行代码是什么意思呢?这就涉及到了caffe里面的一个重要的设计模式,也就是工厂设计模式。通过这样的两行代码,我们真正实现了一个新的数据层的注册。那什么是工厂设计模式呢?大家知道Caffe是使用C++代码进行编写。C++标准中最重要的一个设计模式就是面向对象的设计模式。面向对象设计模式之中,工厂设计模式就是最常用的实例化对象的模式,也就是定义了一个用于创建对象的接口,从而让我们实现一个新的子类的时候,它本身可以决定实例化哪一个类。工厂设计模式可以大大提升代码的效率。工程设计模式是caffe整个框架中非常核心的设计模式和思想。大家可以在课下去细细地阅读。
caffe.proto序列文件
下面我们开始第二部分,caffe.proto序列文件。前面我说过了,我们这个新的任务可能需要做一些新的图像增强相关方面的操作。所以我们看TransformerParam这样的一个message。在这个message之中,我们定义了一些新的变量,包括constrast_brightness_adjustment,smooth_filtering等等。这其中包含了一些颜色扰动、对比度变化、旋转变化相关方面的一些变量。随后我们会实现这样的一些图像增强操作。在Caffe.proto中定义完这些变量之后,我们需要在data_tranformer.cpp之中使用它。
TransformImageSeg函数
具体的使用是在data_tranformer的TransformImageSeg这样的一个函数里面,在这个函数里首先我们可以通过param这样的一个变量获取它的成员变量,获取到它的成员变量之后,我们就可以进行相关的数据操作的控制。所有这些数据增强操作的参数,参数变量的控制,都是事先在train.prototxt中进行配置的。总结一句就是说我们首先在train.prototxt配置了我们需要使用的数据增强的操作,然后根据我们定义的参数去寻找到相关的数据操作。随后在image_data_layer.cpp,再去TransformImageSeg这个函数中进行最终的调用。
do_rotation变量
下面我们说说do_rotation这个变量,就是我们刚刚获取的一个是否进行图像旋转这样操作的一个控制变量,这个变量,它最终的控制是在train.prototxt中,也就是我们的网络配置文件中进行控制。如果为true,我们就要调用旋转操作,旋转操作这个函数,我们可以直接定义到data_tranformer.cpp文件之中,它是data_tranformer.cpp的局部函数,对其他的CPP来说它是不可见的。
caffe添加新网络层:一站式caffe工程实践连载(4)
rotate函数
在rotate的函数之中,我们做一些图像旋转的操作,更多的数据增强相关方面的操作,大家可以去我的github项目中进行仔细的阅读,此处不再一一赘述。
caffe添加新网络层:一站式caffe工程实践连载(4)
在经过数据增强操作之后,我们就需要将我们的数据真正地塞入到显存之中,这样才能够进行训练。主要包含两个方面。第一个是赋值,我们需要获取到mutable_cpu_data这样的一个数据指针,它是一个可擦写的数据指针。通过这个数据指针,我们就可以获tranformed_image_blob和tranformed_label_blob相应的数据指针。随后我们通过一个for循环来进行赋值操作。在这个赋值操作之中,就做一些常见的图像变化,包括是否做scale,是否减去均值,是否做一些图像尺度相关方面的操作,关键代码如下所示:
caffe添加新网络层:一站式caffe工程实践连载(4)
添加一个新的图像分割任务的网络相关操作,我们本次就说到这里。
完整内容及视频解读,请扫关注蜂口小程序~
参与内测,免费获取蜂口所有内容,请申请内测(1-8-8-1-1-2-1-7-5-9-5),更有其他优惠福利多多,欢迎大家多多参与,尽情挑刺,凡是好的建议,我们都会虚心采纳哒~
蜂口小程序将持续为你带来最新技术的落地方法,欢迎随时关注了解~