官方文档虽然有多输入多输出的例子[英文] [译文],但是作为使用者,对于keras多输入多输出存在一定疑惑
1 网络层能不能跨层(间隔)连接,也就是能不能构建Deep Residual Learning(残差网络)那样的跳跃连接?
2 网络连接的时候,使用merge层进行连接,能不能自定义merge层?
merge子类网络层有:Add、Subtract、Multiply、Average、Maximum、Minimum、Concatenate、Dot这八个网络层
merge源代码在github可查看
先分析merge父类代码:
class _Merge(Layer):
    """Generic merge layer for elementwise merge functions.

    Used to implement `Sum`, `Average`, etc. Subclasses provide the actual
    combination by overriding `_merge_function`; this base class handles
    shape validation, rank alignment of the inputs and mask merging.

    # Arguments
        **kwargs: standard layer keyword arguments.
    """

    def __init__(self, **kwargs):
        super(_Merge, self).__init__(**kwargs)
        self.supports_masking = True

    def _merge_function(self, inputs):
        """Combine the list of tensors `inputs` into a single tensor.

        # Raises
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError

    def _compute_elemwise_op_output_shape(self, shape1, shape2):
        """Computes the shape of the resultant of an elementwise operation.

        # Arguments
            shape1: tuple or None. Shape of the first tensor
            shape2: tuple or None. Shape of the second tensor

        # Returns
            expected output shape when an element-wise operation is
            carried out on 2 tensors with shapes shape1 and shape2.
            tuple or None.

        # Raises
            ValueError: if shape1 and shape2 are not compatible for
                element-wise operations.
        """
        if None in [shape1, shape2]:
            return None
        elif len(shape1) < len(shape2):
            # Always treat the longer shape as `shape1`.
            return self._compute_elemwise_op_output_shape(shape2, shape1)
        elif len(shape2) == 0:
            return shape1
        # Leading dimensions that only the longer shape has are kept as-is;
        # the trailing dimensions are matched pairwise.
        output_shape = list(shape1[:-len(shape2)])
        for i, j in zip(shape1[-len(shape2):], shape2):
            if i is None or j is None:
                output_shape.append(None)
            elif i == 1:
                output_shape.append(j)
            elif j == 1:
                output_shape.append(i)
            else:
                if i != j:
                    raise ValueError('Operands could not be broadcast '
                                     'together with shapes ' +
                                     str(shape1) + ' ' + str(shape2))
                output_shape.append(i)
        return tuple(output_shape)

    def _fold_sample_shapes(self, input_shape):
        """Fold the per-sample shapes (batch axis dropped) of all inputs.

        # Arguments
            input_shape: list of shape tuples (or None entries).

        # Returns
            The elementwise-merged per-sample shape as a tuple, or None as
            soon as any input shape is None.

        # Raises
            ValueError: if two shapes cannot be merged elementwise.
        """
        if input_shape[0] is None:
            output_shape = None
        else:
            output_shape = input_shape[0][1:]
        for i in range(1, len(input_shape)):
            if input_shape[i] is None:
                shape = None
            else:
                shape = input_shape[i][1:]
            output_shape = self._compute_elemwise_op_output_shape(output_shape,
                                                                  shape)
        return output_shape

    def build(self, input_shape):
        """Validate the input shapes and decide whether `call` must reshape.

        # Arguments
            input_shape: list of at least 2 shape tuples.

        # Raises
            ValueError: if `input_shape` is not a list, has fewer than 2
                entries, mixes different known batch sizes, or contains
                shapes that cannot be merged elementwise.
        """
        # Used purely for shape validation.
        if not isinstance(input_shape, list):
            raise ValueError('A merge layer should be called '
                             'on a list of inputs.')
        if len(input_shape) < 2:
            raise ValueError('A merge layer should be called '
                             'on a list of at least 2 inputs. '
                             'Got ' + str(len(input_shape)) + ' inputs.')
        batch_sizes = {s[0] for s in input_shape if s is not None} - {None}
        if len(batch_sizes) > 1:
            raise ValueError('Can not merge tensors with different '
                             'batch sizes. Got tensors with shapes : ' +
                             str(input_shape))
        # Called only for its ValueError on incompatible shapes.
        self._fold_sample_shapes(input_shape)
        # If the inputs have different ranks, we have to reshape them
        # to make them broadcastable.
        if None not in input_shape and len(set(map(len, input_shape))) == 1:
            self._reshape_required = False
        else:
            self._reshape_required = True

    def call(self, inputs):
        """Align the ranks of `inputs` if needed and apply `_merge_function`.

        # Arguments
            inputs: list of input tensors.

        # Returns
            The tensor returned by `_merge_function`.
        """
        if self._reshape_required:
            reshaped_inputs = []
            input_ndims = list(map(K.ndim, inputs))
            if None not in input_ndims:
                # If ranks of all inputs are available,
                # we simply expand each of them at axis=1
                # until all of them have the same rank.
                max_ndim = max(input_ndims)
                for x in inputs:
                    x_ndim = K.ndim(x)
                    for _ in range(max_ndim - x_ndim):
                        x = K.expand_dims(x, 1)
                    reshaped_inputs.append(x)
                return self._merge_function(reshaped_inputs)
            else:
                # Transpose all inputs so that batch size is the last dimension.
                # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... , batch_size)
                transposed = False
                for x in inputs:
                    x_ndim = K.ndim(x)
                    if x_ndim is None:
                        x_shape = K.shape(x)
                        batch_size = x_shape[0]
                        new_shape = K.concatenate([x_shape[1:], K.expand_dims(batch_size)])
                        x_transposed = K.reshape(x, K.stack([batch_size, K.prod(x_shape[1:])]))
                        x_transposed = K.permute_dimensions(x_transposed, (1, 0))
                        x_transposed = K.reshape(x_transposed, new_shape)
                        reshaped_inputs.append(x_transposed)
                        transposed = True
                    elif x_ndim > 1:
                        dims = list(range(1, x_ndim)) + [0]
                        reshaped_inputs.append(K.permute_dimensions(x, dims))
                        transposed = True
                    else:
                        # We don't transpose inputs if they are 1D vectors or scalars.
                        reshaped_inputs.append(x)
                y = self._merge_function(reshaped_inputs)
                y_ndim = K.ndim(y)
                if transposed:
                    # If inputs have been transposed, we have to transpose the output too.
                    if y_ndim is None:
                        y_shape = K.shape(y)
                        y_ndim = K.shape(y_shape)[0]
                        batch_size = y_shape[y_ndim - 1]
                        new_shape = K.concatenate([K.expand_dims(batch_size), y_shape[:y_ndim - 1]])
                        y = K.reshape(y, (-1, batch_size))
                        y = K.permute_dimensions(y, (1, 0))
                        y = K.reshape(y, new_shape)
                    elif y_ndim > 1:
                        dims = [y_ndim - 1] + list(range(y_ndim - 1))
                        y = K.permute_dimensions(y, dims)
                return y
        else:
            return self._merge_function(inputs)

    def compute_output_shape(self, input_shape):
        """Return the shape of the merged output.

        # Arguments
            input_shape: list of shape tuples (or None entries).

        # Returns
            Tuple `(batch_size,) + merged per-sample shape`, where
            `batch_size` is the single known batch size or None. Returns
            None when the per-sample shape cannot be determined because an
            input shape is None.

        # Raises
            ValueError: if the shapes cannot be merged elementwise.
        """
        output_shape = self._fold_sample_shapes(input_shape)
        if output_shape is None:
            # Concatenating a tuple with None below would raise TypeError.
            return None
        batch_sizes = {s[0] for s in input_shape if s is not None} - {None}
        if len(batch_sizes) == 1:
            output_shape = (list(batch_sizes)[0],) + output_shape
        else:
            output_shape = (None,) + output_shape
        return output_shape

    def compute_mask(self, inputs, mask=None):
        """Merge the input masks with a logical AND.

        # Arguments
            inputs: list of input tensors.
            mask: None, or a list of masks (entries may be None) with the
                same length as `inputs`.

        # Returns
            None if `mask` is None or every entry is None; otherwise the
            elementwise `all` over the non-None masks.

        # Raises
            ValueError: if `mask` or `inputs` is not a list, or their
                lengths differ.
        """
        if mask is None:
            return None
        if not isinstance(mask, list):
            raise ValueError('`mask` should be a list.')
        if not isinstance(inputs, list):
            raise ValueError('`inputs` should be a list.')
        if len(mask) != len(inputs):
            raise ValueError('The lists `inputs` and `mask` '
                             'should have the same length.')
        if all([m is None for m in mask]):
            return None
        # Stack the available masks on a new leading axis and reduce over it.
        masks = [K.expand_dims(m, 0) for m in mask if m is not None]
        return K.all(K.concatenate(masks, axis=0), axis=0, keepdims=False)
merge模块中调用各子类层的函数式接口(add、subtract等函数),其实就是直接实例化对应的子类,再作用于输入:
def add(inputs, **kwargs):
    """Functional interface to the `Add` layer.

    # Arguments
        inputs: A list of input tensors (at least 2).
        **kwargs: Standard layer keyword arguments.

    # Returns
        A tensor, the sum of the inputs.

    # Examples

    ```python
        import keras

        input1 = keras.layers.Input(shape=(16,))
        x1 = keras.layers.Dense(8, activation='relu')(input1)
        input2 = keras.layers.Input(shape=(32,))
        x2 = keras.layers.Dense(8, activation='relu')(input2)
        added = keras.layers.add([x1, x2])

        out = keras.layers.Dense(4)(added)
        model = keras.models.Model(inputs=[input1, input2], outputs=out)
    ```
    """
    layer = Add(**kwargs)
    return layer(inputs)


def subtract(inputs, **kwargs):
    """Functional interface to the `Subtract` layer.

    # Arguments
        inputs: A list of input tensors (exactly 2).
        **kwargs: Standard layer keyword arguments.

    # Returns
        A tensor, the difference of the inputs.

    # Examples

    ```python
        import keras

        input1 = keras.layers.Input(shape=(16,))
        x1 = keras.layers.Dense(8, activation='relu')(input1)
        input2 = keras.layers.Input(shape=(32,))
        x2 = keras.layers.Dense(8, activation='relu')(input2)
        subtracted = keras.layers.subtract([x1, x2])

        out = keras.layers.Dense(4)(subtracted)
        model = keras.models.Model(inputs=[input1, input2], outputs=out)
    ```
    """
    layer = Subtract(**kwargs)
    return layer(inputs)


def multiply(inputs, **kwargs):
    """Functional interface to the `Multiply` layer.

    # Arguments
        inputs: A list of input tensors (at least 2).
        **kwargs: Standard layer keyword arguments.

    # Returns
        A tensor, the element-wise product of the inputs.
    """
    layer = Multiply(**kwargs)
    return layer(inputs)


def average(inputs, **kwargs):
    """Functional interface to the `Average` layer.

    # Arguments
        inputs: A list of input tensors (at least 2).
        **kwargs: Standard layer keyword arguments.

    # Returns
        A tensor, the average of the inputs.
    """
    layer = Average(**kwargs)
    return layer(inputs)


def maximum(inputs, **kwargs):
    """Functional interface to the `Maximum` layer.

    # Arguments
        inputs: A list of input tensors (at least 2).
        **kwargs: Standard layer keyword arguments.

    # Returns
        A tensor, the element-wise maximum of the inputs.
    """
    layer = Maximum(**kwargs)
    return layer(inputs)


def minimum(inputs, **kwargs):
    """Functional interface to the `Minimum` layer.

    # Arguments
        inputs: A list of input tensors (at least 2).
        **kwargs: Standard layer keyword arguments.

    # Returns
        A tensor, the element-wise minimum of the inputs.
    """
    layer = Minimum(**kwargs)
    return layer(inputs)


def concatenate(inputs, axis=-1, **kwargs):
    """Functional interface to the `Concatenate` layer.

    # Arguments
        inputs: A list of input tensors (at least 2).
        axis: Concatenation axis.
        **kwargs: Standard layer keyword arguments.

    # Returns
        A tensor, the concatenation of the inputs alongside axis `axis`.
    """
    layer = Concatenate(axis=axis, **kwargs)
    return layer(inputs)


def dot(inputs, axes, normalize=False, **kwargs):
    """Functional interface to the `Dot` layer.

    # Arguments
        inputs: A list of input tensors (exactly 2; the `Dot` layer
            rejects any other number of inputs).
        axes: Integer or tuple of integers,
            axis or axes along which to take the dot product.
        normalize: Whether to L2-normalize samples along the
            dot product axis before taking the dot product.
            If set to True, then the output of the dot product
            is the cosine proximity between the two samples.
        **kwargs: Standard layer keyword arguments.

    # Returns
        A tensor, the dot product of the samples from the inputs.
    """
    layer = Dot(axes=axes, normalize=normalize, **kwargs)
    return layer(inputs)
简单的子类层只需要重写(override)_merge_function,其它函数都继承自父类
Add层:
class Add(_Merge):
    """Layer that adds a list of inputs.

    It takes as input a list of tensors,
    all of the same shape, and returns
    a single tensor (also of the same shape).

    # Examples

    ```python
        import keras

        input1 = keras.layers.Input(shape=(16,))
        x1 = keras.layers.Dense(8, activation='relu')(input1)
        input2 = keras.layers.Input(shape=(32,))
        x2 = keras.layers.Dense(8, activation='relu')(input2)
        added = keras.layers.Add()([x1, x2])  # equivalent to added = keras.layers.add([x1, x2])

        out = keras.layers.Dense(4)(added)
        model = keras.models.Model(inputs=[input1, input2], outputs=out)
    ```
    """

    def _merge_function(self, inputs):
        # Every remaining input is accumulated onto the first one, so the
        # layer accepts more than two inputs.
        total = inputs[0]
        for tensor in inputs[1:]:
            total += tensor
        return total
Subtract层:
class Subtract(_Merge):
    """Layer that subtracts two inputs.

    It takes as input a list of tensors of size 2,
    both of the same shape, and returns a single tensor,
    (inputs[0] - inputs[1]), also of the same shape.

    # Examples

    ```python
        import keras

        input1 = keras.layers.Input(shape=(16,))
        x1 = keras.layers.Dense(8, activation='relu')(input1)
        input2 = keras.layers.Input(shape=(32,))
        x2 = keras.layers.Dense(8, activation='relu')(input2)
        # Equivalent to subtracted = keras.layers.subtract([x1, x2])
        subtracted = keras.layers.Subtract()([x1, x2])

        out = keras.layers.Dense(4)(subtracted)
        model = keras.models.Model(inputs=[input1, input2], outputs=out)
    ```
    """

    def _merge_function(self, inputs):
        """Return `inputs[0] - inputs[1]`.

        # Arguments
            inputs: list of exactly 2 tensors of the same shape.

        # Raises
            ValueError: if `inputs` does not hold exactly 2 tensors, or
                their static shapes differ.
        """
        # Only two inputs are allowed: the second is subtracted from the first.
        if len(inputs) != 2:
            raise ValueError('`Subtract` layer should be called '
                             'on exactly 2 inputs')
        # Compare shapes through the backend accessor rather than the private
        # `_keras_shape` attribute, so a tensor lacking that attribute leads
        # to the ValueError below instead of an AttributeError.
        # NOTE(review): relies on K.int_shape accepting any backend tensor -
        # confirm for the backend in use.
        if K.int_shape(inputs[0]) != K.int_shape(inputs[1]):
            raise ValueError('`Subtract` layer should be called '
                             'on inputs of the same shape')
        return inputs[0] - inputs[1]
Multiply层:
class Multiply(_Merge):
    """Layer that multiplies (element-wise) a list of inputs.

    It takes as input a list of tensors,
    all of the same shape, and returns
    a single tensor (also of the same shape).
    """

    def _merge_function(self, inputs):
        # Every remaining input is multiplied into the first one; the number
        # of merged inputs is not limited here.
        product = inputs[0]
        for tensor in inputs[1:]:
            product *= tensor
        return product
Average层:多层求平均值
Maximum层:多层中的最大值
Minimum层:多层中的最小值
Concatenate层:
class Concatenate(_Merge):
    """Layer that concatenates a list of inputs.

    It takes as input a list of tensors,
    all of the same shape except for the concatenation axis,
    and returns a single tensor, the concatenation of all inputs.

    Concatenation needs its own shape, output-shape and mask logic, so
    this layer overrides `build`, `call`, `compute_output_shape` and
    `compute_mask` and joins the inputs along the configured axis.

    # Arguments
        axis: Axis along which to concatenate.
        **kwargs: standard layer keyword arguments.
    """

    def __init__(self, axis=-1, **kwargs):
        super(Concatenate, self).__init__(**kwargs)
        self.axis = axis
        self.supports_masking = True

    def build(self, input_shape):
        # Used purely for shape validation.
        if not isinstance(input_shape, list):
            raise ValueError('`Concatenate` layer should be called '
                             'on a list of inputs')
        if all(shape is None for shape in input_shape):
            return
        # Drop the concat axis from every shape; what is left must agree.
        trimmed_shapes = [list(shape) for shape in input_shape]
        distinct_shapes = set()
        for dims in trimmed_shapes:
            del dims[self.axis]
            distinct_shapes.add(tuple(dims))
        if len(distinct_shapes) > 1:
            raise ValueError('`Concatenate` layer requires '
                             'inputs with matching shapes '
                             'except for the concat axis. '
                             'Got inputs shapes: %s' % (input_shape))

    def call(self, inputs):
        # Delegates the actual concatenation to the backend.
        if not isinstance(inputs, list):
            raise ValueError('A `Concatenate` layer should be called '
                             'on a list of inputs.')
        return K.concatenate(inputs, axis=self.axis)

    def compute_output_shape(self, input_shape):
        # Output shape: first input's shape with the concat axis summed over
        # all inputs (None as soon as any size on that axis is unknown).
        if not isinstance(input_shape, list):
            raise ValueError('A `Concatenate` layer should be called '
                             'on a list of inputs.')
        axis = self.axis
        merged = list(input_shape[0])
        for other in input_shape[1:]:
            if merged[axis] is None or other[axis] is None:
                merged[axis] = None
                break
            merged[axis] += other[axis]
        return tuple(merged)

    def compute_mask(self, inputs, mask=None):
        # Builds the output mask from the per-input masks, if any.
        if mask is None:
            return None
        if not isinstance(mask, list):
            raise ValueError('`mask` should be a list.')
        if not isinstance(inputs, list):
            raise ValueError('`inputs` should be a list.')
        if len(mask) != len(inputs):
            raise ValueError('The lists `inputs` and `mask` '
                             'should have the same length.')
        if all(m is None for m in mask):
            return None
        # Make a list of masks while making sure
        # the dimensionality of each mask
        # is the same as the corresponding input.
        aligned_masks = []
        for tensor, tensor_mask in zip(inputs, mask):
            if tensor_mask is None:
                # Input is unmasked. Append all 1s to masks,
                # but cast it to bool first
                aligned_masks.append(K.cast(K.ones_like(tensor), 'bool'))
            elif K.ndim(tensor_mask) < K.ndim(tensor):
                # Mask is smaller than the input, expand it
                aligned_masks.append(K.expand_dims(tensor_mask))
            else:
                aligned_masks.append(tensor_mask)
        joined = K.concatenate(aligned_masks, axis=self.axis)
        return K.all(joined, axis=-1, keepdims=False)

    def get_config(self):
        # Start from the parent class configuration and add this layer's
        # own `axis` setting on top.
        merged_config = dict(super(Concatenate, self).get_config())
        merged_config.update({'axis': self.axis})
        return merged_config
Dot层:计算点积(内积),融合的层数只能为2
class Dot(_Merge):
    """Layer that computes a dot product between samples in two tensors.

    E.g. if applied to two tensors `a` and `b` of shape `(batch_size, n)`,
    the output will be a tensor of shape `(batch_size, 1)`
    where each entry `i` will be the dot product between
    `a[i]` and `b[i]`.

    # Arguments
        axes: Integer or tuple of integers,
            axis or axes along which to take the dot product.
        normalize: Whether to L2-normalize samples along the
            dot product axis before taking the dot product.
            If set to True, then the output of the dot product
            is the cosine proximity between the two samples.
        **kwargs: Standard layer keyword arguments.
    """

    def __init__(self, axes, normalize=False, **kwargs):
        super(Dot, self).__init__(**kwargs)
        if not isinstance(axes, int):
            # A non-int `axes` must be a pair of ints.
            if not isinstance(axes, (list, tuple)):
                raise TypeError('Invalid type for `axes` - '
                                'should be a list or an int.')
            if len(axes) != 2:
                raise ValueError('Invalid format for `axes` - '
                                 'should contain two elements.')
            if not (isinstance(axes[0], int) and isinstance(axes[1], int)):
                raise ValueError('Invalid format for `axes` - '
                                 'list elements should be "int".')
        self.axes = axes
        self.normalize = normalize
        self.supports_masking = True

    def build(self, input_shape):
        # Used purely for shape validation.
        if not isinstance(input_shape, list) or len(input_shape) != 2:
            raise ValueError('A `Dot` layer should be called '
                             'on a list of 2 inputs.')
        shape1, shape2 = input_shape[0], input_shape[1]
        if shape1 is None or shape2 is None:
            return
        if isinstance(self.axes, int):
            # A single int applies to both inputs; a negative value is
            # resolved against each input's own rank.
            axes = ([self.axes % len(shape1), self.axes % len(shape2)]
                    if self.axes < 0 else [self.axes] * 2)
        else:
            axes = self.axes
        if shape1[axes[0]] != shape2[axes[1]]:
            raise ValueError(
                'Dimension incompatibility '
                '%s != %s. ' % (shape1[axes[0]], shape2[axes[1]]) +
                'Layer shapes: %s, %s' % (shape1, shape2))

    def call(self, inputs):
        # Computes the dot product along the configured axes via
        # K.batch_dot(x1, x2, axes).
        x1 = inputs[0]
        x2 = inputs[1]
        if isinstance(self.axes, int):
            if self.axes < 0:
                axes = [self.axes % K.ndim(x1), self.axes % K.ndim(x2)]
            else:
                axes = [self.axes] * 2
        else:
            # Resolve each negative axis against the rank of its own input.
            axes = [axis % K.ndim(inputs[i]) if axis < 0 else axis
                    for i, axis in enumerate(self.axes)]
        if self.normalize:
            x1 = K.l2_normalize(x1, axis=axes[0])
            x2 = K.l2_normalize(x2, axis=axes[1])
        return K.batch_dot(x1, x2, axes)

    def compute_output_shape(self, input_shape):
        if not isinstance(input_shape, list) or len(input_shape) != 2:
            raise ValueError('A `Dot` layer should be called '
                             'on a list of 2 inputs.')
        dims1 = list(input_shape[0])
        dims2 = list(input_shape[1])
        if isinstance(self.axes, int):
            axes = ([self.axes % len(dims1), self.axes % len(dims2)]
                    if self.axes < 0 else [self.axes] * 2)
        else:
            axes = self.axes
        # Remove the contracted axis from each shape, then the leading axis
        # of the second shape, and join what remains.
        dims1.pop(axes[0])
        dims2.pop(axes[1])
        dims2.pop(0)
        output_shape = dims1 + dims2
        if len(output_shape) == 1:
            output_shape += [1]
        return tuple(output_shape)

    def compute_mask(self, inputs, mask=None):
        # This layer never produces an output mask.
        return None

    def get_config(self):
        # Parent class configuration extended with this layer's own settings.
        merged_config = dict(super(Dot, self).get_config())
        merged_config.update({
            'axes': self.axes,
            'normalize': self.normalize,
        })
        return merged_config
由于知道了各个融合层的实现原理,所以能够自定义融合层:
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:keras 多输入多输出实验,融合层 - Python技术站