from keras.engine import Layer, InputSpec
from keras import initializers
from keras import regularizers
from keras import constraints
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects
class GroupNormalization(Layer):
    """Group normalization layer.

    Splits the channels along ``axis`` into ``groups`` equal-sized groups of
    contiguous channels. Each group of each sample is normalized to zero mean
    and unit variance. The statistics are taken over the group's channels and
    all remaining non-batch axes. An optional per-channel scale (``gamma``)
    and offset (``beta``) is then applied.

    # Arguments
        groups: Integer, number of channel groups. The channel count must be
            a multiple of it.
        axis: Integer, the channel axis (``-1`` for channels-last, ``1`` for
            channels-first). Axis 0 is treated as the batch axis.
        epsilon: Small float added to the variance before the square root.
        center: If True, add the learned per-channel offset ``beta``.
        scale: If True, multiply by the learned per-channel scale ``gamma``.
        beta_initializer: Initializer for ``beta``.
        gamma_initializer: Initializer for ``gamma``.
        beta_regularizer: Optional regularizer for ``beta``.
        gamma_regularizer: Optional regularizer for ``gamma``.
        beta_constraint: Optional constraint for ``beta``.
        gamma_constraint: Optional constraint for ``gamma``.

    # Input shape
        Arbitrary. The size of ``axis`` must be known when the layer is built.

    # Output shape
        Same as the input shape.
    """

    def __init__(self,
                 groups=32,
                 axis=-1,
                 epsilon=1e-5,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(GroupNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.groups = groups
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Validate the channel dimension and create ``gamma``/``beta``.

        # Raises
            ValueError: if the channel dimension is undefined, smaller than
                ``groups``, or not a multiple of ``groups``.
        """
        dim = input_shape[self.axis]

        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')

        if dim < self.groups:
            raise ValueError('Number of groups (' + str(self.groups) +
                             ') cannot be '
                             'more than the number of channels (' +
                             str(dim) + ').')

        if dim % self.groups != 0:
            # The channel count must be divisible by the group count
            # (not the other way round).
            raise ValueError('Number of channels (' + str(dim) +
                             ') must be a '
                             'multiple of the number of groups (' +
                             str(self.groups) + ').')

        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None

        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None

        self.built = True

    def call(self, inputs, **kwargs):
        """Normalize ``inputs`` per sample and per channel group.

        The channel axis (size C) is split in place into two adjacent axes
        ``(groups, C // groups)``. A row-major reshape therefore maps channel
        ``c`` to ``(c // group_size, c % group_size)``, so each group holds a
        contiguous block of channels. This holds for any ``axis``, including
        channels-last. Statistics are reduced over every axis except the
        batch axis and the group axis.
        """
        input_shape = K.int_shape(inputs)
        ndim = len(input_shape)
        # Resolve a negative axis so the group axis can be inserted by index.
        axis = self.axis % ndim
        group_size = input_shape[axis] // self.groups

        # Use the dynamic shape: batch/spatial sizes may be unknown statically.
        tensor_input_shape = K.shape(inputs)
        group_shape = [tensor_input_shape[i] for i in range(ndim)]
        group_shape[axis] = group_size
        # Put the group axis directly before the shrunk channel axis.
        group_shape.insert(axis, self.groups)
        grouped = K.reshape(inputs, K.stack(group_shape))

        # Reduce over everything except batch (0) and the group axis.
        reduction_axes = [i for i in range(1, ndim + 1) if i != axis]
        mean = K.mean(grouped, axis=reduction_axes, keepdims=True)
        variance = K.var(grouped, axis=reduction_axes, keepdims=True)
        outputs = (grouped - mean) / K.sqrt(variance + self.epsilon)

        # gamma/beta have shape (C,). Reshape them to (groups, group_size) at
        # the grouped channel position so they line up with the same split.
        broadcast_shape = [1] * (ndim + 1)
        broadcast_shape[axis] = self.groups
        broadcast_shape[axis + 1] = group_size

        if self.scale:
            outputs = outputs * K.reshape(self.gamma, broadcast_shape)
        if self.center:
            outputs = outputs + K.reshape(self.beta, broadcast_shape)

        # Merge the group axes back into the original input shape.
        return K.reshape(outputs, tensor_input_shape)

    def get_config(self):
        """Return the layer configuration, including base ``Layer`` fields."""
        config = {
            'groups': self.groups,
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint),
        }
        base_config = super(GroupNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        """Normalization does not change the shape."""
        return input_shape
# Register the layer class under its name in Keras' global custom-object
# registry, so the name 'GroupNormalization' can be resolved when a model
# config is deserialized.
get_custom_objects()['GroupNormalization'] = GroupNormalization
transfer
from keras.layers import Input
from keras.models import Model
from GroupNorm import GroupNormalization
image_input = Input(shape=(None,None,3))
x = GroupNormalization(groups=2, axis=-1, epsilon=0.1)(image_input)
model = Model(image_input, x)
model.summary()//GroupNormalization()