CBAM attention mechanism code implementation (continued)

See the previous article for the theoretical background.
GitHub: kobiso/CBAM-tensorflow
The code below implements the channel attention module. The descriptors produced by average pooling and max pooling are each fed through an MLP, and, as the code shows, the two descriptors share the MLP weights: the max-pooling branch reuses the variables of the average-pooling branch via reuse=True.
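
For reference, this is the channel attention formula from the CBAM paper, where sigma is the sigmoid function and W0, W1 are the weights of the shared MLP:

    Mc(F) = sigma(MLP(AvgPool(F)) + MLP(MaxPool(F)))
          = sigma(W1(W0(F_avg)) + W1(W0(F_max)))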

import tensorflow as tf
import tensorflow.contrib.slim as slim


def cbam_block(input_feature, index, reduction_ratio=8):
    # Channel attention first, then spatial attention, as in the CBAM paper.
    with tf.variable_scope('cbam_%s' % index):
        attention_feature = channel_attention(input_feature, index, reduction_ratio)
        attention_feature = spatial_attention(attention_feature, index)
    return attention_feature


def channel_attention(input_feature, index, reduction_ratio=8):
    kernel_initializer = tf.contrib.layers.variance_scaling_initializer()
    bias_initializer = tf.constant_initializer(value=0.0)

    with tf.variable_scope('ch_attention_%s' % index):
        feature_map_shape = input_feature.get_shape()
        channel = input_feature.get_shape()[-1]

        # Global average pooling: squeeze each H x W feature map down to 1 x 1.
        avg_pool = tf.nn.avg_pool(value=input_feature,
                                  ksize=[1, feature_map_shape[1], feature_map_shape[2], 1],
                                  strides=[1, 1, 1, 1],
                                  padding='VALID')
        assert avg_pool.get_shape()[1:] == (1, 1, channel)
        # Shared MLP, first layer: reduce C to C / reduction_ratio.
        avg_pool = tf.layers.dense(inputs=avg_pool,
                                   units=channel // reduction_ratio,
                                   activation=tf.nn.relu,
                                   kernel_initializer=kernel_initializer,
                                   bias_initializer=bias_initializer,
                                   name='mlp_0',
                                   reuse=None)
        assert avg_pool.get_shape()[1:] == (1, 1, channel // reduction_ratio)
        # Shared MLP, second layer: expand back to C.
        avg_pool = tf.layers.dense(inputs=avg_pool,
                                   units=channel,
                                   kernel_initializer=kernel_initializer,
                                   bias_initializer=bias_initializer,
                                   name='mlp_1',
                                   reuse=None)
        assert avg_pool.get_shape()[1:] == (1, 1, channel)

        # Global max pooling over the same spatial extent.
        max_pool = tf.nn.max_pool(value=input_feature,
                                  ksize=[1, feature_map_shape[1], feature_map_shape[2], 1],
                                  strides=[1, 1, 1, 1],
                                  padding='VALID')
        assert max_pool.get_shape()[1:] == (1, 1, channel)
        # reuse=True makes this branch share the mlp_0/mlp_1 weights with the
        # average-pooling branch above.
        max_pool = tf.layers.dense(inputs=max_pool,
                                   units=channel // reduction_ratio,
                                   activation=tf.nn.relu,
                                   name='mlp_0',
                                   reuse=True)
        assert max_pool.get_shape()[1:] == (1, 1, channel // reduction_ratio)
        max_pool = tf.layers.dense(inputs=max_pool,
                                   units=channel,
                                   name='mlp_1',
                                   reuse=True)
        assert max_pool.get_shape()[1:] == (1, 1, channel)

        # Element-wise sum, then sigmoid, gives the per-channel weights.
        scale = tf.nn.sigmoid(avg_pool + max_pool)
    return input_feature * scale
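
The spatial attention module below is the counterpart along the channel axis: average pooling and max pooling across channels each produce an H x W map, the two maps are concatenated, and a 7 x 7 convolution followed by a sigmoid turns them into the spatial attention map. In the paper this is written as

    Ms(F) = sigma(f_7x7([AvgPool(F); MaxPool(F)]))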


def spatial_attention(input_feature, index):
    kernel_size = 7
    kernel_initializer = tf.contrib.layers.variance_scaling_initializer()
    with tf.variable_scope("sp_attention_%s" % index):
        # Average and max pooling along the channel axis give two H x W maps.
        avg_pool = tf.reduce_mean(input_feature, axis=3, keepdims=True)
        assert avg_pool.get_shape()[-1] == 1
        max_pool = tf.reduce_max(input_feature, axis=3, keepdims=True)
        assert max_pool.get_shape()[-1] == 1
        # Concatenate along the channel axis.
        concat = tf.concat([avg_pool, max_pool], axis=3)
        assert concat.get_shape()[-1] == 2

        # A 7x7 convolution with sigmoid produces the spatial attention map.
        concat = slim.conv2d(concat, num_outputs=1,
                             kernel_size=[kernel_size, kernel_size],
                             padding='SAME',
                             activation_fn=tf.nn.sigmoid,
                             weights_initializer=kernel_initializer,
                             scope='conv')
        assert concat.get_shape()[-1] == 1

    return input_feature * concat
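
To make the snippet self-contained, here is a minimal usage sketch (my own addition, not from the kobiso repo): it builds a small TF 1.x graph, inserts cbam_block after a convolution, and lists the trainable variables so you can verify that mlp_0 and mlp_1 each appear only once, i.e. the two pooling branches really do share weights. The 32 x 32 x 3 input shape is an arbitrary example.

# Usage sketch (TF 1.x): apply CBAM to a convolutional feature map.
inputs = tf.placeholder(tf.float32, [None, 32, 32, 3])
features = tf.layers.conv2d(inputs, filters=64, kernel_size=3,
                            padding='same', activation=tf.nn.relu)
refined = cbam_block(features, index=0, reduction_ratio=8)

# CBAM preserves the feature map shape, so it drops into any network block.
print(refined.get_shape())  # (?, 32, 32, 64)

# 'mlp_0' and 'mlp_1' each appear once per block: the MLP weights are shared.
for v in tf.trainable_variables():
    print(v.name)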

Source: blog.csdn.net/qq_43265072/article/details/106058693