In essence, the depth of the residual shrinkage network belonging to a convolution neural network, is the depth of the residual network (deep residual network, ResNet) of a variant. Its core idea is that the process of learning in depth study of the feature, remove redundant information is very important; soft thresholding is a very flexible way to remove redundant information.
1. Residual depth network
First, in depth introduction residual shrinkage network, we often need to start with the depth of the residual network. The following figure shows the basic module of the depth of the residual network, including some nonlinear layer (residual path) and a constant cross-layer connection and the like. Identity connecting the core network of the residual depth, is a guarantee of its excellent properties.
2. The depth of the residual shrinkage network
Residual shrink network depth is the depth of the residual path A network is a network residual shrinkage. Here the "shrink" refers to the soft thresholding.
Soft thresholding step is the core of a number of signal noise reduction method, which is close to zero (or an absolute value below a certain threshold [tau]) is set to zero characteristic, which is characteristic in the [-τ, τ] interval set to 0, so that other, far away from the 0 also features shrink towards zero.
If the offset b and a convolution layer placed before watching, then set to zero in this interval becomes [-τ + b, τ + b]. Because τ and b are automatically learning parameters obtained in this perspective, then, the soft thresholding can actually be the features of any interval is set to zero, is a more flexible, delete a feature of the range of ways, It can also be understood as a more flexible non-linear mapping.
Viewed from another aspect, in front of the convolution of two layers, and two batch standardization two activation functions, the feature redundant information, converted to a value close to zero; the useful feature away from zero conversion value. Thereafter, a set of thresholds obtained by auto-learning, by a soft thresholding weed out the redundant features, a useful feature would be retained.
By stacking a number of basic modules may be a complete depth of residual shrinkage network, as shown below:
The image recognition program and Keras
Although the depth of the residual shrinkage network was originally applied to fault diagnosis based on vibration signals, but the depth of the residual shrinkage network is in fact a common feature learning method, we believe that in many tasks (computer vision, speech, text) are likely to have some use.
The following is based on a digital handwriting recognition program MNIST depth residual shrinkage network (procedure is very simple, for reference only):
#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Sat Dec 28 23:24:05 2019 Implemented using TensorFlow 1.0.1 and Keras 2.2.1 M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis, IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898 @author: me """ from __future__ import print_function import keras import numpy as np from keras.datasets import mnist from keras.layers import Dense, Conv2D, BatchNormalization, Activation from keras.layers import AveragePooling2D, Input, GlobalAveragePooling2D from keras.optimizers import Adam from keras.regularizers import l2 from keras import backend as K from keras.models import Model from keras.layers.core import Lambda K.set_learning_phase(1) # Input image dimensions img_rows, img_cols = 28, 28 # The data, split between train and test sets (x_train, y_train), (x_test, y_test) = mnist.load_data() if K.image_data_format() == 'channels_first': x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) input_shape = (1, img_rows, img_cols) else: x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) input_shape = (img_rows, img_cols, 1) # Noised data x_train = x_train.astype('float32') / 255. + 0.5*np.random.random([x_train.shape[0], img_rows, img_cols, 1]) x_test = x_test.astype('float32') / 255. + 0.5*np.random.random([x_test.shape[0], img_rows, img_cols, 1]) print('x_train shape:', x_train.shape) print(x_train.shape[0], 'train samples') print(x_test.shape[0], 'test samples') # convert class vectors to binary class matrices y_train = keras.utils.to_categorical(y_train, 10) y_test = keras.utils.to_categorical(y_test, 10) def abs_backend(inputs): return K.abs(inputs) def expand_dim_backend(inputs): return K.expand_dims(K.expand_dims(inputs,1),1) def sign_backend(inputs): return K.sign(inputs) def pad_backend(inputs, in_channels, out_channels): pad_dim = (out_channels - in_channels)//2 return K.spatial_3d_padding(inputs, padding = ((0,0),(0,0),(pad_dim,pad_dim))) # Residual Shrinakge Block def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False, downsample_strides=2): residual = incoming in_channels = incoming.get_shape().as_list()[-1] for i in range(nb_blocks): identity = residual if not downsample: downsample_strides = 1 residual = BatchNormalization()(residual) residual = Activation('relu')(residual) residual = Conv2D(out_channels, 3, strides=(downsample_strides, downsample_strides), padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(residual) residual = BatchNormalization()(residual) residual = Activation('relu')(residual) residual = Conv2D(out_channels, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(residual) # Calculate global means residual_abs = Lambda(abs_backend)(residual) abs_mean = GlobalAveragePooling2D()(residual_abs) # Calculate scaling coefficients scales = Dense(out_channels, activation=None, kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(abs_mean) scales = BatchNormalization()(scales) scales = Activation('relu')(scales) scales = Dense(out_channels, activation='sigmoid', kernel_regularizer=l2(1e-4))(scales) scales = Lambda(expand_dim_backend)(scales) # Calculate thresholds thres = keras.layers.multiply([abs_mean, scales]) # Soft thresholding sub = keras.layers.subtract([residual_abs, thres]) zeros = keras.layers.subtract([sub, sub]) n_sub = keras.layers.maximum([sub, zeros]) residual = keras.layers.multiply([Lambda(sign_backend)(residual), n_sub]) # Downsampling (it is important to use the pooL-size of (1, 1)) if downsample_strides > 1: identity = AveragePooling2D(pool_size=(1,1), strides=(2,2))(identity) # Zero_padding to match channels (it is important to use zero padding rather than 1by1 convolution) if in_channels != out_channels: identity = Lambda(pad_backend)(identity, in_channels, out_channels) residual = keras.layers.add([residual, identity]) return residual # define and train a model inputs = Input(shape=input_shape) net = Conv2D(8, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(inputs) net = residual_shrinkage_block(net, 1, 8, downsample=True) net = BatchNormalization()(net) net = Activation('relu')(net) net = GlobalAveragePooling2D()(net) outputs = Dense(10, activation='softmax', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(net) model = Model(inputs=inputs, outputs=outputs) model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy']) model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1, validation_data=(x_test, y_test)) # get results K.set_learning_phase(0) DRSN_train_score = model.evaluate(x_train, y_train, batch_size=100, verbose=0) print('Train loss:', DRSN_train_score[0]) print('Train accuracy:', DRSN_train_score[1]) DRSN_test_score = model.evaluate(x_test, y_test, batch_size=100, verbose=0) print('Test loss:', DRSN_test_score[0]) print('Test accuracy:', DRSN_test_score[1])
为方便对比,深度残差网络的代码如下:
#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Sat Dec 28 23:19:03 2019 Implemented using TensorFlow 1.0 and Keras 2.2.1 K. He, X. Zhang, S. Ren, J. Sun, Deep Residual Learning for Image Recognition, CVPR, 2016. @author: me """ from __future__ import print_function import numpy as np import keras from keras.datasets import mnist from keras.layers import Dense, Conv2D, BatchNormalization, Activation from keras.layers import AveragePooling2D, Input, GlobalAveragePooling2D from keras.optimizers import Adam from keras.regularizers import l2 from keras import backend as K from keras.models import Model from keras.layers.core import Lambda K.set_learning_phase(1) # input image dimensions img_rows, img_cols = 28, 28 # the data, split between train and test sets (x_train, y_train), (x_test, y_test) = mnist.load_data() if K.image_data_format() == 'channels_first': x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) input_shape = (1, img_rows, img_cols) else: x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) input_shape = (img_rows, img_cols, 1) # Noised data x_train = x_train.astype('float32') / 255. + 0.5*np.random.random([x_train.shape[0], img_rows, img_cols, 1]) x_test = x_test.astype('float32') / 255. + 0.5*np.random.random([x_test.shape[0], img_rows, img_cols, 1]) print('x_train shape:', x_train.shape) print(x_train.shape[0], 'train samples') print(x_test.shape[0], 'test samples') # convert class vectors to binary class matrices y_train = keras.utils.to_categorical(y_train, 10) y_test = keras.utils.to_categorical(y_test, 10) def pad_backend(inputs, in_channels, out_channels): pad_dim = (out_channels - in_channels)//2 return K.spatial_3d_padding(inputs, padding = ((0,0),(0,0),(pad_dim,pad_dim))) def residual_block(incoming, nb_blocks, out_channels, downsample=False, downsample_strides=2): residual = incoming in_channels = incoming.get_shape().as_list()[-1] for i in range(nb_blocks): identity = residual if not downsample: downsample_strides = 1 residual = BatchNormalization()(residual) residual = Activation('relu')(residual) residual = Conv2D(out_channels, 3, strides=(downsample_strides, downsample_strides), padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(residual) residual = BatchNormalization()(residual) residual = Activation('relu')(residual) residual = Conv2D(out_channels, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(residual) # Downsampling (it is important to use the pooL-size of (1, 1)) if downsample_strides > 1: identity = AveragePooling2D(pool_size=(1, 1), strides=(2, 2))(identity) # Zero_padding to match channels (it is important to use zero padding rather than 1by1 convolution) if in_channels != out_channels: identity = Lambda(pad_backend)(identity, in_channels, out_channels) residual = keras.layers.add([residual, identity]) return residual # define and train a model inputs = Input(shape=input_shape) net = Conv2D(8, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(inputs) net = residual_block(net, 1, 8, downsample=True) net = BatchNormalization()(net) net = Activation('relu')(net) net = GlobalAveragePooling2D()(net) outputs = Dense(10, activation='softmax', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(net) model = Model(inputs=inputs, outputs=outputs) model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy']) model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1, validation_data=(x_test, y_test)) # get results K.set_learning_phase(0) resnet_train_score = model.evaluate(x_train, y_train, batch_size=100, verbose=0) print('Train loss:', resnet_train_score[0]) print('Train accuracy:', resnet_train_score[1]) resnet_test_score = model.evaluate(x_test, y_test, batch_size=100, verbose=0) print('Test loss:', resnet_test_score[0]) print('Test accuracy:', resnet_test_score[1])
备注:
(1)深度残差收缩网络的结构比普通的深度残差网络复杂,也许更难训练。
(2)程序里只设置了一个基本模块,在更复杂的数据集上,可适当增加。
(3)如果遇到这个TypeError:softmax() got an unexpected keyword argument 'axis',就点开tensorflow_backend.py,将return tf.nn.softmax(x, axis=axis)中的第一个axis改成dim即可。
转载网址:
https://segmentfault.com/a/1190000021437510
参考文献:
M. Zhao, S. Zhong, X. Fu, et al., Deep residual shrinkage networks for fault diagnosis, IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898