[Deep residual shrinkage network] Deep-Residual-Shrinkage-Networks model + code

Code link: https://github.com/zhao62/Deep-Residual-Shrinkage-Networks

1. Code

1.1 DRSN_keras.py

Python version: 3.6

Install tensorflow 1.15.0.

Use the Keras that ships with TensorFlow (tensorflow.keras) directly.

Change the import section to:

from __future__ import print_function
import numpy as np
from tensorflow import keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Conv2D, BatchNormalization, Activation
from tensorflow.keras.layers import AveragePooling2D, Input, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Lambda

Body code:

K.set_learning_phase(1)

# Input image dimensions
img_rows, img_cols = 28, 28

# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

# Add uniform noise in [0, 0.5) to the normalized images.
# Using x.shape keeps this correct for both channels_first and channels_last.
x_train = x_train.astype('float32') / 255. + 0.5*np.random.random(x_train.shape)
x_test = x_test.astype('float32') / 255. + 0.5*np.random.random(x_test.shape)
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)


def abs_backend(inputs):
    return K.abs(inputs)

def expand_dim_backend(inputs):
    return K.expand_dims(K.expand_dims(inputs,1),1)

def sign_backend(inputs):
    return K.sign(inputs)

def pad_backend(inputs, in_channels, out_channels):
    pad_dim = (out_channels - in_channels)//2
    inputs = K.expand_dims(inputs,-1)
    inputs = K.spatial_3d_padding(inputs, ((0,0),(0,0),(pad_dim,pad_dim)), 'channels_last')
    return K.squeeze(inputs, -1)

# Residual Shrinkage Block
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                             downsample_strides=2):
    
    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]
    
    for i in range(nb_blocks):
        
        identity = residual
        
        if not downsample:
            downsample_strides = 1
        
        residual = BatchNormalization()(residual)
        residual = Activation('relu')(residual)
        residual = Conv2D(out_channels, 3, strides=(downsample_strides, downsample_strides), 
                          padding='same', kernel_initializer='he_normal', 
                          kernel_regularizer=l2(1e-4))(residual)
        
        residual = BatchNormalization()(residual)
        residual = Activation('relu')(residual)
        residual = Conv2D(out_channels, 3, padding='same', kernel_initializer='he_normal', 
                          kernel_regularizer=l2(1e-4))(residual)
        
        # Calculate global means
        residual_abs = Lambda(abs_backend)(residual)
        abs_mean = GlobalAveragePooling2D()(residual_abs)
        
        # Calculate scaling coefficients
        scales = Dense(out_channels, activation=None, kernel_initializer='he_normal', 
                       kernel_regularizer=l2(1e-4))(abs_mean)
        scales = BatchNormalization()(scales)
        scales = Activation('relu')(scales)
        scales = Dense(out_channels, activation='sigmoid', kernel_regularizer=l2(1e-4))(scales)
        scales = Lambda(expand_dim_backend)(scales)
        
        # Calculate thresholds
        thres = keras.layers.multiply([abs_mean, scales])
        
        # Soft thresholding
        sub = keras.layers.subtract([residual_abs, thres])
        zeros = keras.layers.subtract([sub, sub])
        n_sub = keras.layers.maximum([sub, zeros])
        residual = keras.layers.multiply([Lambda(sign_backend)(residual), n_sub])
        
        # Downsampling using a pool size of (1, 1)
        if downsample_strides > 1:
            identity = AveragePooling2D(pool_size=(1,1), strides=(downsample_strides, downsample_strides))(identity)
            
        # Zero-padding to match channels
        if in_channels != out_channels:
            identity = Lambda(pad_backend, arguments={'in_channels':in_channels,'out_channels':out_channels})(identity)
            # Later blocks in this loop already have out_channels channels
            in_channels = out_channels
        
        residual = keras.layers.add([residual, identity])
    
    return residual


# define and train a model
inputs = Input(shape=input_shape)
net = Conv2D(8, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(inputs)
net = residual_shrinkage_block(net, 1, 8, downsample=True)
net = BatchNormalization()(net)
net = Activation('relu')(net)
net = GlobalAveragePooling2D()(net)
outputs = Dense(10, activation='softmax', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(net)
model = Model(inputs=inputs, outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1, validation_data=(x_test, y_test))

# get results
K.set_learning_phase(0)
DRSN_train_score = model.evaluate(x_train, y_train, batch_size=100, verbose=0)
print('Train loss:', DRSN_train_score[0])
print('Train accuracy:', DRSN_train_score[1])
DRSN_test_score = model.evaluate(x_test, y_test, batch_size=100, verbose=0)
print('Test loss:', DRSN_test_score[0])
print('Test accuracy:', DRSN_test_score[1])

Experimental results:

1.2 DRSN_TFLearn.py

from __future__ import division, print_function, absolute_import

import tflearn
import numpy as np
import tensorflow as tf
from tflearn.layers.conv import conv_2d

# Data loading
from tflearn.datasets import cifar10

(X, Y), (testX, testY) = cifar10.load_data()

# Add noise
X = X + np.random.random((50000, 32, 32, 3)) * 0.1
testX = testX + np.random.random((10000, 32, 32, 3)) * 0.1

# Transform labels to one-hot format
Y = tflearn.data_utils.to_categorical(Y, 10)
testY = tflearn.data_utils.to_categorical(testY, 10)


def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                             downsample_strides=2, activation='relu', batch_norm=True,
                             bias=True, weights_init='variance_scaling',
                             bias_init='zeros', regularizer='L2', weight_decay=0.0001,
                             trainable=True, restore=True, reuse=False, scope=None,
                             name="ResidualBlock"):
    # residual shrinkage blocks with channel-wise thresholds

    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]

    # Variable Scope fix for older TF
    try:
        vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
                                   reuse=reuse)
    except Exception:
        vscope = tf.compat.v1.variable_op_scope([incoming], scope, name, reuse=reuse)

    with vscope as scope:
        name = scope.name  # TODO

        for i in range(nb_blocks):

            identity = residual

            if not downsample:
                downsample_strides = 1

            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3,
                               downsample_strides, 'same', 'linear',
                               bias, weights_init, bias_init,
                               regularizer, weight_decay, trainable,
                               restore)

            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3, 1, 'same',
                               'linear', bias, weights_init,
                               bias_init, regularizer, weight_decay,
                               trainable, restore)

            # get thresholds and apply thresholding
            abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual), axis=2, keepdims=True), axis=1, keepdims=True)
            scales = tflearn.fully_connected(abs_mean, out_channels // 4, activation='linear', regularizer='L2',
                                             weight_decay=0.0001, weights_init='variance_scaling')
            scales = tflearn.batch_normalization(scales)
            scales = tflearn.activation(scales, 'relu')
            scales = tflearn.fully_connected(scales, out_channels, activation='linear', regularizer='L2',
                                             weight_decay=0.0001, weights_init='variance_scaling')
            scales = tf.expand_dims(tf.expand_dims(scales, axis=1), axis=1)
            thres = tf.multiply(abs_mean, tflearn.activations.sigmoid(scales))
            # soft thresholding
            residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual) - thres, 0))

            # Downsampling
            if downsample_strides > 1:
                identity = tflearn.avg_pool_2d(identity, 1,
                                               downsample_strides)

            # Projection to new dimension
            if in_channels != out_channels:
                if (out_channels - in_channels) % 2 == 0:
                    ch = (out_channels - in_channels) // 2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch]])
                else:
                    ch = (out_channels - in_channels) // 2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch + 1]])
                in_channels = out_channels

            residual = residual + identity

    return residual


# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)

# Build a Deep Residual Shrinkage Network with 3 blocks
net = tflearn.input_data(shape=[None, 32, 32, 3],
                         data_preprocessing=img_prep,
                         data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1, 16)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_cifar10',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)

model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10')

training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]

Experimental results

2. Model

2.1 Residual network

The deep residual shrinkage network is a "deep residual network" whose residual path is shrunk, that is, soft-thresholded.

Design idea: during feature learning, removing redundant information is also very important.

Residual module:

In the figure for this module, each cuboid represents a feature map with C channels, width W, and height 1. A residual module can contain two batch normalization (BN) layers, two rectified linear unit (ReLU) activations, two convolutional layers, and an identity shortcut. The identity shortcut is the core contribution of the deep residual network and greatly reduces the difficulty of training deep neural networks. K denotes the number of convolution kernels in a convolutional layer.

In figure (a), the input and output feature maps have the same size. A residual module can also change the width of the output feature map: in figure (b), the stride of the convolution kernel is set to 2 (indicated by /2), so the width of the output feature map is halved to 0.5W. The number of output channels can be changed as well: in figure (c), the number of convolution kernels is set to 2C, so the output feature map has 2C channels, twice as many as the input. Figure (d) is an overall schematic diagram of a deep residual network.
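The shape changes described for figures (b) and (c) can be sketched in plain Python. This is only an illustration under assumptions: the helper names conv_same_output_width and shortcut are hypothetical, the feature map is modelled as a list of W positions with C channel values each, and the shortcut handling (keep every second position, then zero-pad the channels) mirrors what the two scripts above do with a pool size of 1 and channel padding.

```python
import math

def conv_same_output_width(width, stride):
    """Output width of a 'same'-padded convolution: ceil(width / stride)."""
    return math.ceil(width / stride)

def shortcut(feature_map, stride, out_channels):
    """Adapt the identity path so it can be added to the residual path.

    feature_map: list of W positions, each a list of C channel values.
    """
    in_channels = len(feature_map[0])
    # A pooling window of size 1 with stride s keeps every s-th position.
    subsampled = feature_map[::stride]
    # Zero-pad the channel axis; any odd leftover channel goes at the end.
    extra = out_channels - in_channels
    before, after = extra // 2, extra - extra // 2
    return [[0.0] * before + list(pos) + [0.0] * after for pos in subsampled]

W, C = 8, 2
x = [[float(w), float(-w)] for w in range(W)]   # W positions, C channels
y = shortcut(x, stride=2, out_channels=2 * C)

print(conv_same_output_width(W, 2))  # 4
print(len(y), len(y[0]))             # 4 4
print(y[1])                          # [0.0, 2.0, -2.0, 0.0]
```

Zero-padding adds no trainable parameters to the shortcut, so the identity path stays a pure pass-through of the original values.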

2.2 Deep residual shrinkage network:

The deep residual shrinkage network targets signals that contain noise. It inserts "soft thresholding" into the residual module as a "shrinkage layer" and provides a way to set the threshold adaptively. Here, noise means feature information that is irrelevant to the current task, that is, interference.

Soft Thresholding:
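For reference, the textbook definition of soft thresholding with a threshold $\tau \ge 0$ can be written as below. The symbols $x$ (input), $y$ (output), and $\tau$ are notation assumed here; the right-hand form is the sign(x) · max(|x| − τ, 0) expression that both scripts above implement.

```latex
y =
\begin{cases}
x - \tau, & x > \tau \\
0, & -\tau \le x \le \tau \\
x + \tau, & x < -\tau
\end{cases}
\qquad\Longleftrightarrow\qquad
y = \operatorname{sign}(x)\,\max\bigl(|x| - \tau,\ 0\bigr)
```

The derivative of $y$ with respect to $x$ is either 1 or 0, so gradients pass through unchanged wherever the input survives the threshold.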

 

Network structure:

The threshold set by this sub-network is (the mean of the absolute values of the feature map) × (a coefficient α). Because α comes out of a sigmoid function, it lies between 0 and 1. The threshold is therefore always positive and never too large, that is, it cannot make the output all zero.
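A small numeric sketch of this rule, under assumptions: the feature values are made up, and z is a hypothetical stand-in for the raw output of the sub-network before the sigmoid. It shows that the threshold stays strictly between 0 and the mean absolute value, so the largest feature always survives.

```python
import math

def sigmoid(z):
    """Squash any real number into the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

features = [0.1, -0.2, 1.5, -0.05, 0.3]                     # made-up feature values
abs_mean = sum(abs(v) for v in features) / len(features)   # mean of |x|

for z in (-5.0, 0.0, 5.0):        # hypothetical pre-sigmoid outputs
    alpha = sigmoid(z)            # scaling coefficient in (0, 1)
    tau = alpha * abs_mean        # adaptive threshold
    survivors = sum(1 for v in features if abs(v) > tau)
    print(f"z={z:+.1f}  alpha={alpha:.4f}  tau={tau:.4f}  surviving values={survivors}")
```

Because the mean of |x| can never exceed the maximum of |x|, a threshold below the mean always leaves at least one non-zero output, unless the feature map is entirely zero to begin with.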

Figure (a) shows RSBU-CS, an improved residual module in which all channels share one threshold.

Figure (c) shows RSBU-CW, an improved residual module in which each channel has its own threshold.
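The difference between the two variants can be sketched as follows. This is an illustration under assumptions, not the repository's code: the function names are hypothetical, the feature map is a list of channels, the α values are fixed stand-ins for learned sigmoid outputs, and the shared variant is assumed to average |x| over all channels.

```python
def abs_mean(values):
    """Mean of the absolute values."""
    return sum(abs(v) for v in values) / len(values)

def thresholds_channel_shared(feature_map, alpha):
    """RSBU-CS style: a single threshold, reused for every channel."""
    all_values = [v for channel in feature_map for v in channel]
    tau = alpha * abs_mean(all_values)
    return [tau] * len(feature_map)

def thresholds_channel_wise(feature_map, alphas):
    """RSBU-CW style: one threshold per channel, from that channel's own mean."""
    return [a * abs_mean(channel) for a, channel in zip(alphas, feature_map)]

# Two channels of width 4: a weak channel and a strong one.
feature_map = [[0.1, -0.1, 0.2, -0.2],
               [2.0, -1.0, 3.0, -2.0]]

print(thresholds_channel_shared(feature_map, alpha=0.5))
print(thresholds_channel_wise(feature_map, alphas=[0.5, 0.5]))
```

With these numbers the shared threshold (about 0.54) is larger than every value in the weak channel, so that channel would be zeroed completely, while the channel-wise thresholds (about 0.075 and 1.0) scale with each channel.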

3. Takeaways

Soft thresholding is a very common technique in signal denoising. It shrinks the values of a signal toward zero.
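As a minimal sketch of what shrinking toward zero means, the function below applies soft thresholding to a single value. The sample signal and the threshold 0.5 are made up for illustration.

```python
def soft_threshold(x, tau):
    """Move x toward zero by tau; anything within [-tau, tau] becomes 0."""
    if x > tau:
        return x - tau
    if x < -tau:
        return x + tau
    return 0.0

signal = [-1.2, -0.3, 0.0, 0.2, 0.9]
print([round(soft_threshold(v, 0.5), 2) for v in signal])  # [-0.7, 0.0, 0.0, 0.0, 0.4]
```

Small values are removed outright, and large values are kept but reduced in magnitude by the threshold.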

This denoising method assumes that the part of the signal close to zero is noise. For many signals, however, the part close to zero may carry a lot of useful information and cannot simply be removed, so soft thresholding is usually not applied to the original signal directly. The traditional approach is to first transform the original signal into another representation in which, ideally, the part close to zero is useless noise. Soft thresholding is then applied to that representation, and the result is transformed back to obtain the denoised signal.
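The transform, threshold, and reconstruct steps described above can be sketched with a one-level Haar wavelet transform. The choice of Haar is an assumption made for this example (the text does not name a transform), and the signal, the noise values, and the threshold 0.2 are all made up.

```python
import math

SQRT2 = math.sqrt(2.0)

def haar_forward(signal):
    """One-level Haar transform of an even-length signal.

    Returns (approx, detail): scaled pairwise sums and pairwise differences.
    """
    pairs = list(zip(signal[0::2], signal[1::2]))
    approx = [(a + b) / SQRT2 for a, b in pairs]
    detail = [(a - b) / SQRT2 for a, b in pairs]
    return approx, detail

def haar_inverse(approx, detail):
    """Exact inverse of haar_forward."""
    signal = []
    for a, d in zip(approx, detail):
        signal.extend([(a + d) / SQRT2, (a - d) / SQRT2])
    return signal

def soft_threshold(values, tau):
    """sign(v) * max(|v| - tau, 0) for every value."""
    return [math.copysign(max(abs(v) - tau, 0.0), v) for v in values]

def denoise(signal, tau):
    approx, detail = haar_forward(signal)   # 1. transform the signal
    detail = soft_threshold(detail, tau)    # 2. shrink the near-zero coefficients
    return haar_inverse(approx, detail)     # 3. reconstruct

clean = [1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0]
noise = [0.1, -0.08, -0.05, 0.07, 0.08, -0.1, -0.06, 0.1]   # fixed, made-up noise
noisy = [c + n for c, n in zip(clean, noise)]
denoised = denoise(noisy, tau=0.2)

print("max error before:", max(abs(n - c) for n, c in zip(noisy, clean)))
print("max error after: ", max(abs(d - c) for d, c in zip(denoised, clean)))
```

In the raw signal, the values near zero would be part of the signal itself. In the Haar detail coefficients, the values near zero are mostly noise, which is why thresholding there removes noise while the step from 1 to 5 is preserved. For this example the reconstruction stays within about 0.02 of the clean signal, compared with up to 0.1 for the noisy input.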

Origin blog.csdn.net/weixin_48983346/article/details/126320929