Implementing the CBAM Attention Mechanism in Code (Continued)

Click here to view the previous article.
GitHub: kobiso/CBAM-tensorflow
The following code implements the CBAM block. In the channel attention module, the descriptors obtained after average pooling and max pooling are each fed into an MLP; as the code shows, the two descriptors share the MLP weights (the mlp_0 and mlp_1 layers are reused for the max-pooling branch).

import tensorflow as tf
import tensorflow.contrib.slim as slim


def cbam_block(input_feature, index, reduction_ratio=8):
    # CBAM: apply channel attention first, then spatial attention.
    with tf.variable_scope('cbam_%s' % index):
        attention_feature = channel_attention(input_feature, index, reduction_ratio)
        attention_feature = spatial_attention(attention_feature, index)
        print("hello CBAM")  # debug print
    return attention_feature


def channel_attention(input_feature, index, reduction_ratio=8):
    # Channel attention: pooled descriptors pass through a shared two-layer MLP.
    kernel_initializer = tf.contrib.layers.variance_scaling_initializer()
    bias_initializer = tf.constant_initializer(value=0.0)

    with tf.variable_scope('ch_attention_%s' % index):
        feature_map_shape = input_feature.get_shape()
        channel = input_feature.get_shape()[-1]
        # Global average pooling over the spatial dimensions -> (N, 1, 1, C)
        avg_pool = tf.nn.avg_pool(value=input_feature,
                                  ksize=[1, feature_map_shape[1], feature_map_shape[2], 1],
                                  strides=[1, 1, 1, 1],
                                  padding='VALID')
        assert avg_pool.get_shape()[1:] == (1, 1, channel)
        avg_pool = tf.layers.dense(inputs=avg_pool,
                                   units=channel//reduction_ratio,
                                   activation=tf.nn.relu,
                                   kernel_initializer=kernel_initializer,
                                   bias_initializer=bias_initializer,
                                   name='mlp_0',
                                   reuse=None)
        assert avg_pool.get_shape()[1:] == (1, 1, channel//reduction_ratio)
        avg_pool = tf.layers.dense(inputs=avg_pool,
                                   units=channel,
                                   kernel_initializer=kernel_initializer,
                                   bias_initializer=bias_initializer,
                                   name='mlp_1',
                                   reuse=None)
        assert avg_pool.get_shape()[1:] == (1, 1, channel)

        # Global max pooling; the MLP weights (mlp_0 / mlp_1) are shared
        # with the average-pooling branch via reuse=True below.
        max_pool = tf.nn.max_pool(value=input_feature,
                                  ksize=[1, feature_map_shape[1], feature_map_shape[2], 1],
                                  strides=[1, 1, 1, 1],
                                  padding='VALID')
        assert max_pool.get_shape()[1:] == (1, 1, channel)
        max_pool = tf.layers.dense(inputs=max_pool,
                                   units=channel//reduction_ratio,
                                   activation=tf.nn.relu,
                                   name='mlp_0',
                                   reuse=True)
        assert max_pool.get_shape()[1:] == (1, 1, channel//reduction_ratio)
        max_pool = tf.layers.dense(inputs=max_pool,
                                   units=channel,
                                   name='mlp_1',
                                   reuse=True)
        assert max_pool.get_shape()[1:] == (1, 1, channel)
        # Combine both descriptors and squash to per-channel weights in (0, 1).
        scale = tf.nn.sigmoid(avg_pool + max_pool)
    return input_feature * scale


def spatial_attention(input_feature, index):
    # Spatial attention: channel-wise pooling followed by a 7x7 convolution.
    kernel_size = 7
    kernel_initializer = tf.contrib.layers.variance_scaling_initializer()
    with tf.variable_scope("sp_attention_%s" % index):
        # Channel-wise average and max pooling -> two (N, H, W, 1) maps
        avg_pool = tf.reduce_mean(input_feature, axis=3, keepdims=True)
        assert avg_pool.get_shape()[-1] == 1
        max_pool = tf.reduce_max(input_feature, axis=3, keepdims=True)
        assert max_pool.get_shape()[-1] == 1
        # Concatenate along the channel axis
        concat = tf.concat([avg_pool, max_pool], axis=3)
        assert concat.get_shape()[-1] == 2

        # 7x7 convolution with sigmoid produces the spatial attention map
        concat = slim.conv2d(concat, num_outputs=1,
                             kernel_size=[kernel_size, kernel_size],
                             padding='SAME',
                             activation_fn=tf.nn.sigmoid,
                             weights_initializer=kernel_initializer,
                             scope='conv')
        assert concat.get_shape()[-1] == 1

    return input_feature * concat
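
As a quick sanity check, here is a minimal usage sketch (assuming TensorFlow 1.x graph mode; the placeholder shape, index value, and session setup are illustrative and not part of the original repository):

import numpy as np
import tensorflow as tf

# Dummy NHWC feature map: batch 4, 32x32 spatial resolution, 64 channels.
feature_map = tf.placeholder(tf.float32, [4, 32, 32, 64], name='feature_map')

# Apply CBAM; the refined output keeps the input shape.
refined = cbam_block(feature_map, index=0, reduction_ratio=8)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(refined,
                   feed_dict={feature_map: np.random.rand(4, 32, 32, 64).astype(np.float32)})
    print(out.shape)  # (4, 32, 32, 64)

Because reuse=True is passed for the max-pooling branch, both descriptors run through the same mlp_0 / mlp_1 weights, which is the weight sharing described above.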


Reprinted from: blog.csdn.net/qq_43265072/article/details/106058693