Text-Driven Semantic Image Segmentation: A New Approach to Reshaping Visual Semantic Understanding

Table of contents

Show results

1. What is text-driven image semantic segmentation?

2. Why use text-driven image semantic segmentation?

3. How to achieve text-driven image semantic segmentation?

4. Algorithm introduction

5. Dependency installation

6. Model building

6.1 Image Encoder

6.2 Text Encoder

6.3 Feature fusion network

6.4 Semantic Segmentation Model

7. Model inference

7.1 Visualization tools

7.2 Loading the model

7.3 Model Prediction


Show results

  • By controlling the text labels, the model can be driven to perform semantic segmentation on specific objects.

Introduction

Image semantic segmentation is a crucial task in computer vision; its goal is to understand the objects in an image and their interrelationships. In recent years, a new trend has emerged: fusing textual information into image semantic segmentation, which is called text-driven image semantic segmentation. This article discusses the principle, implementation, and future prospects of this approach in depth.

1. What is text-driven image semantic segmentation?

Text-driven image semantic segmentation is a method that combines textual descriptions with visual information to improve segmentation performance. In this approach, the model must understand not only the visual information but also the textual information, and integrate the two to make its decisions.

2. Why use text-driven image semantic segmentation?

In many cases, textual information can provide additional context that helps resolve ambiguities in visual information. For example, if there is an object in an image, we may not be able to determine whether it is a dog or a cat. But if we have a textual description that says "this is a dog", then we can be sure that the object is a dog. Therefore, textual information can provide useful prior knowledge to help us understand images better.

3. How to achieve text-driven image semantic segmentation?

The key to realizing text-driven image semantic segmentation is how to fuse textual and visual information. A common approach is to use deep learning models such as convolutional neural networks (CNNs) and long short-term memory networks (LSTMs).

  • Convolutional Neural Networks: CNNs are deep learning models especially suited to processing images. They automatically extract the important features of an image for subsequent decision-making.

  • Long Short-Term Memory Networks: LSTMs are deep learning models particularly well suited to processing sequential data such as text. They capture the semantics and context of the text and encode this information into a vector.

In this approach, a CNN is first used to extract image features and an LSTM is used to extract text features. The two sets of features are then fused and fed to a classifier for decision-making.
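
As a rough illustration of this idea (not the LSeg model built later in this article), the sketch below fuses a CNN image feature with an LSTM text feature for a simple classifier; all layer sizes here are arbitrary placeholders.

In [ ]

import paddle
import paddle.nn as nn


# Minimal sketch: fuse a CNN image feature with an LSTM text feature (illustrative only).
class NaiveFusionClassifier(nn.Layer):
    def __init__(self, num_classes=2, vocab_size=1000, embed_dim=64):
        super().__init__()
        # CNN branch: a global image feature vector
        self.cnn = nn.Sequential(
            nn.Conv2D(3, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2D(1), nn.Flatten())            # -> [B, 16]
        # LSTM branch: encode a token sequence into a text feature vector
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, 32)                    # last hidden state -> [B, 32]
        # Classifier over the concatenated image + text features
        self.fc = nn.Linear(16 + 32, num_classes)

    def forward(self, image, tokens):
        img_feat = self.cnn(image)                            # [B, 16]
        _, (h, _) = self.lstm(self.embed(tokens))             # h: [num_layers, B, 32]
        txt_feat = h[-1]                                      # [B, 32]
        return self.fc(paddle.concat([img_feat, txt_feat], axis=-1))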

4. Algorithm introduction

  • Model architecture diagram:

  • As shown in the model architecture diagram above, the LSeg model consists of three main parts:

    • Image Coding Network

      • Implemented with a CNN or Transformer model; the paper tests the ViT and CLIP models for encoding image features

    • Text Encoding Network

      • Implemented with a Transformer model; the paper tests the CLIP model for encoding text features

    • Feature Fusion Network

      • Several CNN modules fuse the image and text features and generate the segmentation result

  • Algorithm ideas:

    • Training is similar to that of a conventional image semantic segmentation model: labeled semantic segmentation data is used for supervised training

    • The difference is that the semantic labels of the image are used as an additional input during training; they are converted into text features of a fixed dimension, which control the categories (and the number of categories) of the segmentation output.

    • In this way, multiple semantic segmentation datasets can be combined for training; even if their label sets differ, the model can still be trained normally

    • Because the labeled semantic segmentation data currently available is small compared to the data used to train models such as CLIP, the CLIP parameters are not updated during training, to avoid degrading the model (see the sketch after this list).

    • By introducing text features through a text encoding network such as CLIP, a text-driven semantic segmentation model can be easily implemented.
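
A minimal sketch of how the CLIP text encoder can be kept frozen during training, assuming the LSeg model assembled in section 6.4 below (the training loop itself is not shown):

In [ ]

import paddle

# Sketch: freeze the CLIP text encoder so its weights are not updated
# (assumes the LSeg model defined in section 6.4).
model = LSeg()
for param in model.clip.parameters():
    param.stop_gradient = True          # no gradients are computed for CLIP weights

# Pass only the remaining trainable parameters to the optimizer.
trainable_params = [p for p in model.parameters() if not p.stop_gradient]
optimizer = paddle.optimizer.Adam(learning_rate=1e-4, parameters=trainable_params)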

5. Dependency installation

  • Install PaddleNLP and PaddleClas

In [ ]

!pip install paddleclas paddlenlp ftfy regex --upgrade
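
After installation, a quick import check can confirm that the dependencies are available (versions will vary with your environment):

In [ ]

# Quick sanity check that the installed packages can be imported.
import paddle
import paddlenlp
import paddleclas

print(paddle.__version__, paddlenlp.__version__)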

6. Model building

6.1 Image Encoder

  • Here the Vision Transformer model is used as the image encoder

  • To better extract the feature information of the image, a few minor modifications are made to the model:

    • Extract multi-level model output features

    • Add a small network for feature postprocessing

    • Remove the Norm and Linear layers from the original model's output head

In [2]

import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddleclas.ppcls.arch.backbone.model_zoo.vision_transformer import VisionTransformer


class Slice(nn.Layer):
    def __init__(self, start_index=1):
        super(Slice, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        return x[:, self.start_index:]


class AddReadout(nn.Layer):
    def __init__(self, start_index=1):
        super(AddReadout, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        if self.start_index == 2:
            readout = (x[:, 0] + x[:, 1]) / 2
        else:
            readout = x[:, 0]
        return x[:, self.start_index:] + readout.unsqueeze(1)


class Transpose(nn.Layer):
    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        prems = list(range(x.dim()))
        prems[self.dim0], prems[self.dim1] = prems[self.dim1], prems[self.dim0]
        x = x.transpose(prems)
        return x


class Unflatten(nn.Layer):
    def __init__(self, start_axis, shape):
        super(Unflatten, self).__init__()
        self.start_axis = start_axis
        self.shape = shape

    def forward(self, x):
        # self.shape is a list of target spatial dims, e.g. [H, W]
        return paddle.reshape(x, x.shape[:self.start_axis] + self.shape)


class ProjectReadout(nn.Layer):
    def __init__(self, in_features, start_index=1):
        super(ProjectReadout, self).__init__()
        self.start_index = start_index

        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
        features = paddle.concat((x[:, self.start_index :], readout), -1)

        return self.project(features)

class ViT(VisionTransformer):
    def __init__(self, img_size=384, patch_size=16, in_chans=3, class_num=1000,
                 embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
                 qk_scale=None, drop_rate=0, attn_drop_rate=0, drop_path_rate=0,
                 norm_layer='nn.LayerNorm', epsilon=1e-6, **kwargs):
        super().__init__(img_size, patch_size, in_chans, class_num, embed_dim,
                         depth, num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate,
                         attn_drop_rate, drop_path_rate, norm_layer, epsilon, **kwargs)
        self.patch_size = patch_size
        self.start_index = 1
        features = [256, 512, 1024, 1024]
        readout_oper = [
            ProjectReadout(embed_dim, self.start_index) for out_feat in features
        ]
        self.act_postprocess1 = nn.Sequential(
            readout_oper[0],
            Transpose(1, 2),
            Unflatten(2, [img_size // 16, img_size // 16]),
            nn.Conv2D(
                in_channels=embed_dim,
                out_channels=features[0],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.Conv2DTranspose(
                in_channels=features[0],
                out_channels=features[0],
                kernel_size=4,
                stride=4,
                padding=0,
                dilation=1,
                groups=1,
            ),
        )

        self.act_postprocess2 = nn.Sequential(
            readout_oper[1],
            Transpose(1, 2),
            Unflatten(2, [img_size // 16, img_size // 16]),
            nn.Conv2D(
                in_channels=embed_dim,
                out_channels=features[1],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.Conv2DTranspose(
                in_channels=features[1],
                out_channels=features[1],
                kernel_size=2,
                stride=2,
                padding=0,
                dilation=1,
                groups=1,
            ),
        )

        self.act_postprocess3 = nn.Sequential(
            readout_oper[2],
            Transpose(1, 2),
            Unflatten(2, [img_size // 16, img_size // 16]),
            nn.Conv2D(
                in_channels=embed_dim,
                out_channels=features[2],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
        )

        self.act_postprocess4 = nn.Sequential(
            readout_oper[3],
            Transpose(1, 2),
            Unflatten(2, [img_size // 16, img_size // 16]),
            nn.Conv2D(
                in_channels=embed_dim,
                out_channels=features[3],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.Conv2D(
                in_channels=features[3],
                out_channels=features[3],
                kernel_size=3,
                stride=2,
                padding=1,
            ),
        )

        self.norm = nn.Identity()
        self.head = nn.Identity()

    def _resize_pos_embed(self, posemb, gs_h, gs_w):
        posemb_tok, posemb_grid = (
            posemb[:, : self.start_index],
            posemb[0, self.start_index:],
        )

        gs_old = int(math.sqrt(len(posemb_grid)))

        posemb_grid = posemb_grid.reshape(
            (1, gs_old, gs_old, -1)).transpose((0, 3, 1, 2))
        posemb_grid = F.interpolate(
            posemb_grid, size=(gs_h, gs_w), mode="bilinear")
        posemb_grid = posemb_grid.transpose(
            (0, 2, 3, 1)).reshape((1, gs_h * gs_w, -1))

        posemb = paddle.concat([posemb_tok, posemb_grid], axis=1)

        return posemb

    def forward(self, x):
        b, c, h, w = x.shape

        pos_embed = self._resize_pos_embed(
            self.pos_embed, h // self.patch_size, w // self.patch_size
        )
        x = self.patch_embed.proj(x).flatten(2).transpose((0, 2, 1))

        cls_tokens = self.cls_token.expand(
            (b, -1, -1)
        )
        x = paddle.concat((cls_tokens, x), axis=1)

        x = x + pos_embed
        x = self.pos_drop(x)

        outputs = []
        for index, blk in enumerate(self.blocks):
            x = blk(x)
            if index in [5, 11, 17, 23]:
                outputs.append(x)

        layer_1 = self.act_postprocess1[0:2](outputs[0])
        layer_2 = self.act_postprocess2[0:2](outputs[1])
        layer_3 = self.act_postprocess3[0:2](outputs[2])
        layer_4 = self.act_postprocess4[0:2](outputs[3])

        shape = (-1, 1024, h // self.patch_size, w // self.patch_size)
        layer_1 = layer_1.reshape(shape)
        layer_2 = layer_2.reshape(shape)
        layer_3 = layer_3.reshape(shape)
        layer_4 = layer_4.reshape(shape)

        layer_1 = self.act_postprocess1[3: len(self.act_postprocess1)](layer_1)
        layer_2 = self.act_postprocess2[3: len(self.act_postprocess2)](layer_2)
        layer_3 = self.act_postprocess3[3: len(self.act_postprocess3)](layer_3)
        layer_4 = self.act_postprocess4[3: len(self.act_postprocess4)](layer_4)

        return layer_1, layer_2, layer_3, layer_4
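
As a quick sanity check (a sketch with randomly initialized weights), the shapes of the four multi-level feature maps can be inspected; with the default img_size=384 and patch_size=16 the patch grid is 24×24:

In [ ]

# Sketch: inspect the shapes of the four multi-level feature maps (random weights).
vit = ViT()
dummy = paddle.randn([1, 3, 384, 384])
for i, feat in enumerate(vit(dummy), start=1):
    print(f"layer_{i}: {feat.shape}")
# Expected with the defaults above:
# layer_1: [1, 256, 96, 96]    (1x1 conv + 4x transposed conv)
# layer_2: [1, 512, 48, 48]    (1x1 conv + 2x transposed conv)
# layer_3: [1, 1024, 24, 24]   (1x1 conv only)
# layer_4: [1, 1024, 12, 12]   (1x1 conv + stride-2 conv)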

6.2 Text Encoder

  • The CLIP model is used here as a text encoder
  • Since only textual information needs to be encoded, the image encoder included in CLIP does not need to be retained

In [3]

import paddle
import paddle.nn as nn
from paddlenlp.transformers.clip.modeling import TextTransformer


class CLIPText(nn.Layer):
    def __init__(
            self,
            max_text_length: int = 77,
            vocab_size: int = 49408,
            text_embed_dim: int = 512,
            text_heads: int = 8,
            text_layers: int = 12,
            text_hidden_act: str = "quick_gelu",
            projection_dim: int = 512):
        super().__init__()

        self.text_model = TextTransformer(context_length=max_text_length,
                                          transformer_width=text_embed_dim,
                                          transformer_heads=text_heads,
                                          transformer_layers=text_layers,
                                          vocab_size=vocab_size,
                                          activation=text_hidden_act,
                                          normalize_before=True)

        self.text_projection = paddle.create_parameter(
            (text_embed_dim, projection_dim), paddle.get_default_dtype())

    def get_text_features(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        text_outputs = self.text_model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict)
        pooled_output = text_outputs[1]
        text_features = paddle.matmul(pooled_output, self.text_projection)
        return text_features
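
A small sketch of the expected behavior: each tokenized prompt is mapped to a single 512-dimensional feature vector (random token ids stand in for the tokenizer output shown in section 7.2):

In [ ]

# Sketch: the text encoder maps each token sequence to one 512-d feature vector.
text_encoder = CLIPText()
fake_ids = paddle.randint(0, 49408, shape=[2, 77])    # two prompts, context length 77
text_feats = text_encoder.get_text_features(fake_ids)
print(text_feats.shape)                               # expected: [2, 512]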

6.3 Feature fusion network

  • This is a feature fusion model proposed in the paper

In [4]

import paddle
import paddle.nn as nn

import numpy as np


class Interpolate(nn.Layer):
    """Interpolation module."""

    def __init__(self, scale_factor, mode, align_corners=False):
        """Init.

        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: interpolated data
        """

        x = self.interp(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )

        return x


class ResidualConvUnit(nn.Layer):
    """Residual convolution module."""

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.conv1 = nn.Conv2D(
            features, features, kernel_size=3, stride=1, padding=1
        )

        self.conv2 = nn.Conv2D(
            features, features, kernel_size=3, stride=1, padding=1
        )

        self.relu = nn.ReLU()  # paddle.nn.ReLU takes no inplace argument

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.relu(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + x


class FeatureFusionBlock(nn.Layer):
    """Feature fusion block."""

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            output += self.resConfUnit1(xs[1])

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=True
        )

        return output


class ResidualConvUnit_custom(nn.Layer):
    """Residual convolution module."""

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn

        self.groups = 1

        self.conv1 = nn.Conv2D(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=not self.bn,
            groups=self.groups,
        )

        self.conv2 = nn.Conv2D(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=not self.bn,
            groups=self.groups,
        )

        if self.bn == True:
            self.bn1 = nn.BatchNorm2D(features)
            self.bn2 = nn.BatchNorm2D(features)

        self.activation = activation

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """

        out = self.activation(x)
        out = self.conv1(out)
        if self.bn == True:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn == True:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return out + x


class FeatureFusionBlock_custom(nn.Layer):
    """Feature fusion block."""

    def __init__(
        self,
        features,
        activation=nn.ReLU(),
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
    ):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock_custom, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features
        if self.expand == True:
            out_features = features // 2

        self.out_conv = nn.Conv2D(
            features,
            out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias_attr=True,
            groups=1,
        )

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output += res

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )

        output = self.out_conv(output)

        return output


class Scratch(nn.Layer):
    def __init__(self, in_channels=[256, 512, 1024, 1024], out_channels=256):
        super().__init__()
        self.out_c = 512
        self.logit_scale = paddle.to_tensor(np.exp(np.log([1 / 0.07])))
        self.layer1_rn = nn.Conv2D(
            in_channels[0],
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False,
            groups=1,
        )
        self.layer2_rn = nn.Conv2D(
            in_channels[1],
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False,
            groups=1,
        )
        self.layer3_rn = nn.Conv2D(
            in_channels[2],
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False,
            groups=1,
        )
        self.layer4_rn = nn.Conv2D(
            in_channels[3],
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False,
            groups=1,
        )

        self.refinenet1 = FeatureFusionBlock_custom(
            out_channels, bn=True
        )
        self.refinenet2 = FeatureFusionBlock_custom(
            out_channels, bn=True
        )
        self.refinenet3 = FeatureFusionBlock_custom(
            out_channels, bn=True
        )
        self.refinenet4 = FeatureFusionBlock_custom(
            out_channels, bn=True
        )

        self.head1 = nn.Conv2D(out_channels, self.out_c, kernel_size=1)

        self.output_conv = nn.Sequential(
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True)
        )

    def forward(self, layer_1, layer_2, layer_3, layer_4, text_features):

        layer_1_rn = self.layer1_rn(layer_1)
        layer_2_rn = self.layer2_rn(layer_2)
        layer_3_rn = self.layer3_rn(layer_3)
        layer_4_rn = self.layer4_rn(layer_4)

        path_4 = self.refinenet4(layer_4_rn)
        path_3 = self.refinenet3(path_4, layer_3_rn)
        path_2 = self.refinenet2(path_3, layer_2_rn)
        path_1 = self.refinenet1(path_2, layer_1_rn)

        image_features = self.head1(path_1)

        imshape = image_features.shape
        image_features = image_features.transpose(
            (0, 2, 3, 1)).reshape((-1, self.out_c))

        # normalized features
        image_features = image_features / \
            image_features.norm(axis=-1, keepdim=True)
        text_features = text_features / \
            text_features.norm(axis=-1, keepdim=True)

        logits_per_image = self.logit_scale * image_features @ text_features.t()

        out = logits_per_image.reshape(
            (imshape[0], imshape[2], imshape[3], -1)).transpose((0, 3, 1, 2))

        out = self.output_conv(out)

        return out
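
The key step is at the end of Scratch.forward: every pixel embedding is compared with every text embedding via a dot product of L2-normalized vectors, so the number of output channels equals the number of text labels. A standalone sketch of this step with made-up shapes:

In [ ]

# Sketch of the per-pixel similarity step (made-up shapes, random values).
pixel_feats = paddle.randn([8 * 8, 512])      # a hypothetical 8x8 feature map, flattened
text_feats = paddle.randn([5, 512])           # 5 text labels

pixel_feats = pixel_feats / pixel_feats.norm(axis=-1, keepdim=True)
text_feats = text_feats / text_feats.norm(axis=-1, keepdim=True)

logits = pixel_feats @ text_feats.t()         # cosine similarities, shape [64, 5]
print(logits.shape)                           # reshaped to [N, 5, 8, 8] in Scratch.forward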

6.4 Semantic Segmentation Model

  • Combining the above three modules together forms a text-driven semantic segmentation model

In [5]

class LSeg(nn.Layer):
    def __init__(self):
        super().__init__()
        self.clip = CLIPText()
        self.vit = ViT()
        self.scratch = Scratch()
    
    def forward(self, images, texts):
        layer_1, layer_2, layer_3, layer_4 = self.vit.forward(images)
        text_features = self.clip.get_text_features(texts)
        return self.scratch.forward(layer_1, layer_2, layer_3, layer_4, text_features)
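
A rough end-to-end sanity check with randomly initialized weights (a sketch; the pretrained parameters are loaded in section 7.2). With a 384×384 input and 5 text labels, the output should contain one score map per label at the input resolution:

In [ ]

# Sketch: end-to-end shape check with random weights (no pretrained parameters loaded).
model = LSeg()
images = paddle.randn([1, 3, 384, 384])
texts = paddle.randint(0, 49408, shape=[5, 77])   # ids for 5 candidate labels
with paddle.no_grad():
    out = model(images, texts)
print(out.shape)                                  # expected: [1, 5, 384, 384]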

7. Model inference

7.1 Visualization tools

In [ ]

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from PIL import Image


def get_new_pallete(num_cls):
    """Generate a PASCAL-VOC-style color palette for num_cls classes."""
    n = num_cls
    pallete = [0] * (n * 3)
    for j in range(n):
        lab = j
        pallete[j * 3 + 0] = 0
        pallete[j * 3 + 1] = 0
        pallete[j * 3 + 2] = 0
        i = 0
        while lab > 0:
            pallete[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
            pallete[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
            pallete[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
            i = i + 1
            lab >>= 3
    return pallete

def get_new_mask_pallete(npimg, new_palette, out_label_flag=False, labels=None):
    """Get image color pallete for visualizing masks"""
    # put colormap
    out_img = Image.fromarray(npimg.squeeze().astype('uint8'))
    out_img.putpalette(new_palette)

    patches = []
    if out_label_flag:
        assert labels is not None
        u_index = np.unique(npimg)
        for i, index in enumerate(u_index):
            label = labels[index]
            cur_color = [new_palette[index * 3] / 255.0, new_palette[index * 3 + 1] / 255.0, new_palette[index * 3 + 2] / 255.0]
            red_patch = mpatches.Patch(color=cur_color, label=label)
            patches.append(red_patch)
    return out_img, patches
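
For example, the palette generated for a five-class label set starts with black for class 0, followed by the usual PASCAL-style colors:

In [ ]

# Example: RGB triplets generated for 5 classes (class 0 is black).
palette = get_new_pallete(5)
print(palette[:15])   # [0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128]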

7.2 Loading the model

In [ ]

import paddle.vision.transforms as transforms

from paddlenlp.transformers.clip.tokenizer import CLIPTokenizer


model = LSeg()
state_dict = paddle.load('data/data169501/LSeg.pdparams')
model.set_state_dict(state_dict)
model.eval()

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5, 0.5, 0.5], 
            [0.5, 0.5, 0.5]
        ),
    ]
)

tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')
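
The tokenizer converts label strings into padded token-id sequences, which is the texts input expected by the model; a quick illustration (the exact sequence length depends on the labels):

In [ ]

# Quick illustration: the tokenizer turns label strings into padded id tensors.
sample_ids = tokenizer(['cat', 'grass'], padding=True, return_tensors="pd")['input_ids']
print(sample_ids.shape)   # [2, sequence_length]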

7.3 Model Prediction

In [8]

import cv2
import numpy as np

from PIL import Image


# Image path
img_path = 'images/cat.jpeg'

# Class labels
labels = ['plant', 'grass', 'cat', 'stone', 'other']

image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
h, w = image.shape[:2]
# Crop so that both sides are multiples of 32, keeping the multi-scale feature maps aligned
image = image[:-(h % 32) if h % 32 else None, :-(w % 32) if w % 32 else None]
images = transform(image).unsqueeze(0)
image = Image.fromarray(image).convert("RGBA")


texts = tokenizer(labels, padding=True, return_tensors="pd")['input_ids']


with paddle.no_grad():
    results = model.forward(images, texts)
    results = paddle.argmax(results, 1)
    results = results.numpy()

new_palette = get_new_pallete(len(labels))
mask, patches = get_new_mask_pallete(results, new_palette, out_label_flag=True, labels=labels)

seg = mask.convert("RGBA")
out = Image.blend(image, seg, alpha=0.5)
plt.axis('off')
plt.imshow(image)
plt.figure()
plt.axis('off')
plt.imshow(out)
plt.figure()
plt.legend(handles=patches, loc='upper right', bbox_to_anchor=(1.5, 1), prop={'size': 20})
plt.axis('off')
plt.imshow(seg)
(Output: three figures showing the original image, the blended segmentation overlay, and the color-coded mask with its legend.)

Origin: blog.csdn.net/m0_68036862/article/details/131359380