Table of contents
1. What is text-driven image semantic segmentation?
2. Why use text-driven image semantic segmentation?
3. How to achieve text-driven image semantic segmentation?
4. Algorithm introduction
5. Dependency installation
6. Model building
7. Model inference
Show results
- By controlling the text labels, the model can be driven to perform semantic segmentation on specific objects.
Introduction
Image semantic segmentation is a core task in computer vision: it assigns a semantic class to every pixel, so that the objects in an image and their relationships can be understood. A recent trend is to bring text information into this task, which is known as text-driven image semantic segmentation. This article discusses the principle and implementation of this approach and its future prospects.
1. What is text-driven image semantic segmentation?
Text-driven image semantic segmentation is a method that combines textual descriptions with visual information to improve semantic segmentation performance. In this approach, the model needs to understand not only visual information, but also textual information, and integrate the two kinds of information to make decisions.
2. Why use text-driven image semantic segmentation?
In many cases, textual information provides additional context that helps resolve ambiguities in the visual information. For example, from the pixels alone we may not be able to tell whether an object in an image is a dog or a cat. If a textual description says "this is a dog", we can be confident that the object is a dog. Textual information therefore acts as useful prior knowledge that helps us understand images better.
3. How to achieve text-driven image semantic segmentation?
The key to text-driven image semantic segmentation is how to integrate text information with visual information. A common approach is to use deep learning models such as convolutional neural networks (CNNs) and long short-term memory networks (LSTMs).
- Convolutional neural networks: CNNs are deep learning models especially suited to processing images. They automatically extract important image features for subsequent decision-making.
- Long short-term memory networks: LSTMs are deep learning models well suited to sequential data such as text. They capture the semantics and context of the text and encode this information into a vector.
In this approach, a CNN first extracts image features and an LSTM extracts text features. The two sets of features are then combined and fed to a classifier for decision-making.
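The extract-then-fuse pipeline described above can be sketched in a few lines. This is only a minimal illustration under stated assumptions, not the method used later in this article: the two feature vectors stand in for CNN and LSTM outputs, fusion is plain concatenation, and the classifier is a single linear layer followed by softmax. The function name `fuse_and_classify` and all numbers are hypothetical.

```python
import math


def fuse_and_classify(image_feat, text_feat, weights, bias):
    """Concatenate two feature vectors and apply a linear softmax classifier."""
    fused = list(image_feat) + list(text_feat)  # fusion by concatenation
    logits = [
        sum(w * x for w, x in zip(row, fused)) + b
        for row, b in zip(weights, bias)
    ]
    # Numerically stable softmax over the class logits.
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical stand-ins for a CNN image feature and an LSTM text feature.
image_feat = [0.5, 0.5]  # visually ambiguous: could be a cat or a dog
text_feat = [0.0, 1.0]   # encodes the description "this is a dog"
# One weight row per class (cat, dog) over the four fused features.
weights = [[1.0, 0.0, 2.0, 0.0],
           [0.0, 1.0, 0.0, 2.0]]
bias = [0.0, 0.0]

probs = fuse_and_classify(image_feat, text_feat, weights, bias)
print(probs)  # the "dog" probability dominates once the text feature is added
```

With the image feature alone the two classes would tie; the text feature is what tips the decision, which is the role the prior knowledge plays in section 2.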
4. Algorithm introduction
- Model architecture diagram:
- As can be seen from the model architecture diagram, the LSeg model is divided into three main parts:
  - Image encoding network: a CNN or Transformer model that encodes image features. The paper tests the ViT and CLIP models for this role.
  - Text encoding network: a Transformer model that encodes text features. The paper tests the CLIP model for this role.
  - Feature fusion network: a set of CNN modules that fuse the image and text features and generate the segmentation result.
- Algorithm ideas:
  - Training is similar to that of a conventional image semantic segmentation model: it is supervised training on labeled semantic segmentation data.
  - The difference is that the semantic labels of the image are used as an additional input during training. They are converted into text features of a fixed dimension, which control the categories and the number of categories in the segmentation output.
  - Because of this, several different semantic segmentation datasets can be combined for training, and the model trains normally even if their label sets differ.
  - The labeled semantic segmentation data available today is still small compared with the data used to train models such as CLIP, so the CLIP parameters are not updated during training to avoid degrading the model.
  - Introducing text features through a text encoding network such as CLIP makes it straightforward to implement a text-driven semantic segmentation model.
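How the text labels control the output categories can be illustrated with a toy example. The sketch below assumes that per-pixel image features and per-label text features of the same dimension are already available (in the real model they would come from the image branch and the CLIP text encoder). It compares every pixel with every label by scaled cosine similarity and keeps the best match. The name `segment_with_labels` and the toy vectors are hypothetical.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]


def segment_with_labels(pixel_features, text_features, logit_scale=1 / 0.07):
    """Assign each pixel the index of the most similar text label.

    pixel_features: H x W x C nested lists, text_features: N x C nested lists.
    """
    text_unit = [l2_normalize(t) for t in text_features]
    label_map = []
    for row in pixel_features:
        out_row = []
        for pixel in row:
            p = l2_normalize(pixel)
            # One logit per label: scaled cosine similarity.
            logits = [
                logit_scale * sum(a * b for a, b in zip(p, t))
                for t in text_unit
            ]
            out_row.append(max(range(len(logits)), key=logits.__getitem__))
        label_map.append(out_row)
    return label_map


# Toy 2 x 2 "image" with 3-dimensional features per pixel.
pixels = [[[1.0, 0.1, 0.0], [0.0, 1.0, 0.2]],
          [[0.9, 0.0, 0.1], [0.0, 0.1, 1.0]]]
two_labels = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
three_labels = two_labels + [[0.0, 0.0, 1.0]]

mask_two = segment_with_labels(pixels, two_labels)
mask_three = segment_with_labels(pixels, three_labels)
print(mask_two)    # classes limited to the two given labels
print(mask_three)  # adding a label adds an output class, with no retraining
```

The number of output classes is simply the number of text features passed in, which is why datasets with different label sets can share one model.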
5. Dependency installation
- Install PaddleNLP and PaddleClas
In [ ]
!pip install paddleclas paddlenlp ftfy regex --upgrade
6. Model building
6.1 Image Encoder
- The Vision Transformer model is used here as the image encoder.
- To better extract the feature information of the image, a few minor modifications are made to the model:
  - Extract output features from multiple levels of the model
  - Add a small network for feature post-processing
  - Remove the Norm and Linear layers at the output of the original model
In [2]
import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleclas.ppcls.arch.backbone.model_zoo.vision_transformer import VisionTransformer


class Slice(nn.Layer):
    """Drop the leading readout (class) token(s)."""

    def __init__(self, start_index=1):
        super(Slice, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        return x[:, self.start_index:]


class AddReadout(nn.Layer):
    """Add the readout token to every patch token."""

    def __init__(self, start_index=1):
        super(AddReadout, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        if self.start_index == 2:
            readout = (x[:, 0] + x[:, 1]) / 2
        else:
            readout = x[:, 0]
        return x[:, self.start_index:] + readout.unsqueeze(1)


class Transpose(nn.Layer):
    """Swap two axes of the input tensor."""

    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        perm = list(range(x.dim()))
        perm[self.dim0], perm[self.dim1] = perm[self.dim1], perm[self.dim0]
        x = x.transpose(perm)
        return x


class Unflatten(nn.Layer):
    """Expand the trailing axis into the given shape."""

    def __init__(self, start_axis, shape):
        super(Unflatten, self).__init__()
        self.start_axis = start_axis
        self.shape = shape

    def forward(self, x):
        # self.shape is a list such as [h, w]; extend the target shape with it
        # instead of nesting it as a single element.
        return paddle.reshape(x, x.shape[:self.start_axis] + list(self.shape))


class ProjectReadout(nn.Layer):
    """Concatenate the readout token to every patch token and project back."""

    def __init__(self, in_features, start_index=1):
        super(ProjectReadout, self).__init__()
        self.start_index = start_index
        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
        features = paddle.concat((x[:, self.start_index:], readout), -1)
        return self.project(features)


class ViT(VisionTransformer):
    def __init__(self, img_size=384, patch_size=16, in_chans=3, class_num=1000,
                 embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
                 qk_scale=None, drop_rate=0, attn_drop_rate=0, drop_path_rate=0,
                 norm_layer='nn.LayerNorm', epsilon=1e-6, **kwargs):
        super().__init__(img_size, patch_size, in_chans, class_num, embed_dim,
                         depth, num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate,
                         attn_drop_rate, drop_path_rate, norm_layer, epsilon, **kwargs)
        self.patch_size = patch_size
        self.start_index = 1
        features = [256, 512, 1024, 1024]
        readout_oper = [
            ProjectReadout(embed_dim, self.start_index) for out_feat in features
        ]
        # Branch 1: 1x1 conv, then 4x upsampling.
        self.act_postprocess1 = nn.Sequential(
            readout_oper[0],
            Transpose(1, 2),
            Unflatten(2, [img_size // 16, img_size // 16]),
            nn.Conv2D(
                in_channels=embed_dim,
                out_channels=features[0],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.Conv2DTranspose(
                in_channels=features[0],
                out_channels=features[0],
                kernel_size=4,
                stride=4,
                padding=0,
                dilation=1,
                groups=1,
            ),
        )
        # Branch 2: 1x1 conv, then 2x upsampling.
        self.act_postprocess2 = nn.Sequential(
            readout_oper[1],
            Transpose(1, 2),
            Unflatten(2, [img_size // 16, img_size // 16]),
            nn.Conv2D(
                in_channels=embed_dim,
                out_channels=features[1],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.Conv2DTranspose(
                in_channels=features[1],
                out_channels=features[1],
                kernel_size=2,
                stride=2,
                padding=0,
                dilation=1,
                groups=1,
            ),
        )
        # Branch 3: 1x1 conv only, resolution unchanged.
        self.act_postprocess3 = nn.Sequential(
            readout_oper[2],
            Transpose(1, 2),
            Unflatten(2, [img_size // 16, img_size // 16]),
            nn.Conv2D(
                in_channels=embed_dim,
                out_channels=features[2],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
        )
        # Branch 4: 1x1 conv, then a stride-2 conv (2x downsampling).
        self.act_postprocess4 = nn.Sequential(
            readout_oper[3],
            Transpose(1, 2),
            Unflatten(2, [img_size // 16, img_size // 16]),
            nn.Conv2D(
                in_channels=embed_dim,
                out_channels=features[3],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.Conv2D(
                in_channels=features[3],
                out_channels=features[3],
                kernel_size=3,
                stride=2,
                padding=1,
            ),
        )
        # Remove the Norm and Linear layers of the original model output.
        self.norm = nn.Identity()
        self.head = nn.Identity()

    def _resize_pos_embed(self, posemb, gs_h, gs_w):
        # Split off the class-token embedding, then bilinearly resize the grid part.
        posemb_tok, posemb_grid = (
            posemb[:, : self.start_index],
            posemb[0, self.start_index:],
        )
        gs_old = int(math.sqrt(len(posemb_grid)))
        posemb_grid = posemb_grid.reshape(
            (1, gs_old, gs_old, -1)).transpose((0, 3, 1, 2))
        posemb_grid = F.interpolate(
            posemb_grid, size=(gs_h, gs_w), mode="bilinear")
        posemb_grid = posemb_grid.transpose(
            (0, 2, 3, 1)).reshape((1, gs_h * gs_w, -1))
        posemb = paddle.concat([posemb_tok, posemb_grid], axis=1)
        return posemb

    def forward(self, x):
        b, c, h, w = x.shape
        pos_embed = self._resize_pos_embed(
            self.pos_embed, h // self.patch_size, w // self.patch_size
        )
        x = self.patch_embed.proj(x).flatten(2).transpose((0, 2, 1))
        cls_tokens = self.cls_token.expand(
            (b, -1, -1)
        )
        x = paddle.concat((cls_tokens, x), axis=1)
        x = x + pos_embed
        x = self.pos_drop(x)
        # Collect the outputs of transformer blocks 6, 12, 18 and 24.
        outputs = []
        for index, blk in enumerate(self.blocks):
            x = blk(x)
            if index in [5, 11, 17, 23]:
                outputs.append(x)
        layer_1 = self.act_postprocess1[0:2](outputs[0])
        layer_2 = self.act_postprocess2[0:2](outputs[1])
        layer_3 = self.act_postprocess3[0:2](outputs[2])
        layer_4 = self.act_postprocess4[0:2](outputs[3])
        # The Unflatten at index 2 is skipped; this reshape takes its place.
        shape = (-1, 1024, h // self.patch_size, w // self.patch_size)
        layer_1 = layer_1.reshape(shape)
        layer_2 = layer_2.reshape(shape)
        layer_3 = layer_3.reshape(shape)
        layer_4 = layer_4.reshape(shape)
        layer_1 = self.act_postprocess1[3: len(self.act_postprocess1)](layer_1)
        layer_2 = self.act_postprocess2[3: len(self.act_postprocess2)](layer_2)
        layer_3 = self.act_postprocess3[3: len(self.act_postprocess3)](layer_3)
        layer_4 = self.act_postprocess4[3: len(self.act_postprocess4)](layer_4)
        return layer_1, layer_2, layer_3, layer_4
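
The four post-processing branches of the image encoder turn the same patch grid into a small feature pyramid. The following sketch only does the size arithmetic, using the standard output-size formulas for convolution and transposed convolution and the kernel, stride and padding values from the code above. The helper names are hypothetical and no Paddle code is executed.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out(size, kernel, stride, padding=0, dilation=1):
    """Output size of a transposed convolution along one spatial axis."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1


def encoder_feature_shapes(img_size=384, patch_size=16,
                           features=(256, 512, 1024, 1024)):
    """Return (channels, height, width) for layer_1 .. layer_4."""
    grid = img_size // patch_size              # patches per side
    sizes = [
        deconv_out(grid, kernel=4, stride=4),  # branch 1: 4x upsampling
        deconv_out(grid, kernel=2, stride=2),  # branch 2: 2x upsampling
        grid,                                  # branch 3: 1x1 conv only
        conv_out(grid, kernel=3, stride=2, padding=1),  # branch 4: 2x downsampling
    ]
    return [(c, s, s) for c, s in zip(features, sizes)]


shapes = encoder_feature_shapes()
for name, shape in zip(["layer_1", "layer_2", "layer_3", "layer_4"], shapes):
    print(name, shape)
```

For the default 384 x 384 input this gives 96, 48, 24 and 12 pixels per side, so each level is half the resolution of the previous one.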
6.2 Text Encoder
- The CLIP model is used here as a text encoder
- Since only textual information needs to be encoded, the image encoder included in CLIP does not need to be retained
In [3]
import paddle
import paddle.nn as nn
from paddlenlp.transformers.clip.modeling import TextTransformer


class CLIPText(nn.Layer):
    def __init__(
            self,
            max_text_length: int = 77,
            vocab_size: int = 49408,
            text_embed_dim: int = 512,
            text_heads: int = 8,
            text_layers: int = 12,
            text_hidden_act: str = "quick_gelu",
            projection_dim: int = 512):
        super().__init__()
        self.text_model = TextTransformer(context_length=max_text_length,
                                          transformer_width=text_embed_dim,
                                          transformer_heads=text_heads,
                                          transformer_layers=text_layers,
                                          vocab_size=vocab_size,
                                          activation=text_hidden_act,
                                          normalize_before=True)
        self.text_projection = paddle.create_parameter(
            (text_embed_dim, projection_dim), paddle.get_default_dtype())

    def get_text_features(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        text_outputs = self.text_model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict)
        # Project the pooled text output into the shared embedding space.
        pooled_output = text_outputs[1]
        text_features = paddle.matmul(pooled_output, self.text_projection)
        return text_features
6.3 Feature fusion network
- This is a feature fusion model proposed in the paper
In [4]
import paddle
import paddle.nn as nn
import numpy as np


class Interpolate(nn.Layer):
    """Interpolation module."""

    def __init__(self, scale_factor, mode, align_corners=False):
        """Init.
        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
        """
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.
        Args:
            x (tensor): input
        Returns:
            tensor: interpolated data
        """
        x = self.interp(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )
        return x


class ResidualConvUnit(nn.Layer):
    """Residual convolution module."""

    def __init__(self, features):
        """Init.
        Args:
            features (int): number of features
        """
        super().__init__()
        self.conv1 = nn.Conv2D(
            features, features, kernel_size=3, stride=1, padding=1
        )
        self.conv2 = nn.Conv2D(
            features, features, kernel_size=3, stride=1, padding=1
        )
        # paddle.nn.ReLU takes no `inplace` argument (unlike its PyTorch counterpart).
        self.relu = nn.ReLU()

    def forward(self, x):
        """Forward pass.
        Args:
            x (tensor): input
        Returns:
            tensor: output
        """
        out = self.relu(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)
        return out + x


class FeatureFusionBlock(nn.Layer):
    """Feature fusion block."""

    def __init__(self, features):
        """Init.
        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()
        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Forward pass.
        Returns:
            tensor: output
        """
        output = xs[0]
        if len(xs) == 2:
            output += self.resConfUnit1(xs[1])
        output = self.resConfUnit2(output)
        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=True
        )
        return output


class ResidualConvUnit_custom(nn.Layer):
    """Residual convolution module."""

    def __init__(self, features, activation, bn):
        """Init.
        Args:
            features (int): number of features
        """
        super().__init__()
        self.bn = bn
        self.groups = 1
        self.conv1 = nn.Conv2D(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=not self.bn,
            groups=self.groups,
        )
        self.conv2 = nn.Conv2D(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=not self.bn,
            groups=self.groups,
        )
        if self.bn:
            self.bn1 = nn.BatchNorm2D(features)
            self.bn2 = nn.BatchNorm2D(features)
        self.activation = activation

    def forward(self, x):
        """Forward pass.
        Args:
            x (tensor): input
        Returns:
            tensor: output
        """
        out = self.activation(x)
        out = self.conv1(out)
        if self.bn:
            out = self.bn1(out)
        out = self.activation(out)
        out = self.conv2(out)
        if self.bn:
            out = self.bn2(out)
        # groups is fixed to 1 above, so this branch is never taken
        # (conv_merge is not defined in this class).
        if self.groups > 1:
            out = self.conv_merge(out)
        return out + x


class FeatureFusionBlock_custom(nn.Layer):
    """Feature fusion block."""

    def __init__(
        self,
        features,
        activation=nn.ReLU(),
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
    ):
        """Init.
        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock_custom, self).__init__()
        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = 1
        self.expand = expand
        out_features = features
        if self.expand:
            out_features = features // 2
        self.out_conv = nn.Conv2D(
            features,
            out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias_attr=True,
            groups=1,
        )
        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

    def forward(self, *xs):
        """Forward pass.
        Returns:
            tensor: output
        """
        output = xs[0]
        if len(xs) == 2:
            # Fuse the skip feature from the encoder into the decoder path.
            res = self.resConfUnit1(xs[1])
            output += res
        output = self.resConfUnit2(output)
        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )
        output = self.out_conv(output)
        return output


class Scratch(nn.Layer):
    def __init__(self, in_channels=[256, 512, 1024, 1024], out_channels=256):
        super().__init__()
        self.out_c = 512
        self.logit_scale = paddle.to_tensor(np.exp(np.log([1 / 0.07])))
        self.layer1_rn = nn.Conv2D(
            in_channels[0],
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False,
            groups=1,
        )
        self.layer2_rn = nn.Conv2D(
            in_channels[1],
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False,
            groups=1,
        )
        self.layer3_rn = nn.Conv2D(
            in_channels[2],
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False,
            groups=1,
        )
        self.layer4_rn = nn.Conv2D(
            in_channels[3],
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False,
            groups=1,
        )
        self.refinenet1 = FeatureFusionBlock_custom(
            out_channels, bn=True
        )
        self.refinenet2 = FeatureFusionBlock_custom(
            out_channels, bn=True
        )
        self.refinenet3 = FeatureFusionBlock_custom(
            out_channels, bn=True
        )
        self.refinenet4 = FeatureFusionBlock_custom(
            out_channels, bn=True
        )
        self.head1 = nn.Conv2D(out_channels, self.out_c, kernel_size=1)
        self.output_conv = nn.Sequential(
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True)
        )

    def forward(self, layer_1, layer_2, layer_3, layer_4, text_features):
        # Bring all encoder levels to the same channel count.
        layer_1_rn = self.layer1_rn(layer_1)
        layer_2_rn = self.layer2_rn(layer_2)
        layer_3_rn = self.layer3_rn(layer_3)
        layer_4_rn = self.layer4_rn(layer_4)
        # Fuse from the coarsest level to the finest.
        path_4 = self.refinenet4(layer_4_rn)
        path_3 = self.refinenet3(path_4, layer_3_rn)
        path_2 = self.refinenet2(path_3, layer_2_rn)
        path_1 = self.refinenet1(path_2, layer_1_rn)
        image_features = self.head1(path_1)
        imshape = image_features.shape
        # Flatten to one feature vector per pixel.
        image_features = image_features.transpose(
            (0, 2, 3, 1)).reshape((-1, self.out_c))
        # normalized features
        image_features = image_features / \
            image_features.norm(axis=-1, keepdim=True)
        text_features = text_features / \
            text_features.norm(axis=-1, keepdim=True)
        # Scaled cosine similarity between every pixel and every text label.
        logits_per_image = self.logit_scale * image_features @ text_features.t()
        out = logits_per_image.reshape(
            (imshape[0], imshape[2], imshape[3], -1)).transpose((0, 3, 1, 2))
        out = self.output_conv(out)
        return out
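
To see how the fusion network brings the features back to the input resolution, the sketch below follows only the spatial sizes through the four fusion blocks and the final upsampling. It assumes the encoder sizes for a 384 x 384 input (96, 48, 24 and 12 pixels per side) and that every fusion block doubles the resolution, as in the code above. The function name `trace_decoder` is hypothetical.

```python
def trace_decoder(encoder_sizes, num_labels):
    """Follow spatial sizes through the fusion blocks and the final upsampling.

    encoder_sizes: spatial sizes of layer_1 .. layer_4 (finest to coarsest).
    """
    path = None
    path_sizes = []
    for size in reversed(encoder_sizes):  # layer_4 first, layer_1 last
        if path is not None and path != size:
            # The element-wise addition in a fusion block needs equal sizes.
            raise ValueError(
                f"cannot add a {path}x{path} path to a {size}x{size} skip feature")
        path = size * 2  # every fusion block ends with 2x bilinear upsampling
        path_sizes.append(path)
    logits_shape = (num_labels, path, path)          # one channel per text label
    output_shape = (num_labels, path * 2, path * 2)  # output_conv upsamples 2x more
    return path_sizes, logits_shape, output_shape


path_sizes, logits_shape, output_shape = trace_decoder((96, 48, 24, 12), num_labels=5)
print(path_sizes)    # sizes of path_4, path_3, path_2, path_1
print(logits_shape)  # similarity map before the final upsampling
print(output_shape)  # matches the 384 x 384 input
```

An odd patch grid breaks this chain: a 400 x 400 input gives encoder sizes (100, 50, 25, 13), and the upsampled 26 x 26 path no longer matches the 25 x 25 skip feature.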
6.4 Semantic Segmentation Model
- Combining the above three modules together forms a text-driven semantic segmentation model
In [5]
class LSeg(nn.Layer):
    def __init__(self):
        super().__init__()
        self.clip = CLIPText()
        self.vit = ViT()
        self.scratch = Scratch()

    def forward(self, images, texts):
        layer_1, layer_2, layer_3, layer_4 = self.vit.forward(images)
        text_features = self.clip.get_text_features(texts)
        return self.scratch.forward(layer_1, layer_2, layer_3, layer_4, text_features)
7. Model inference
7.1 Visualization tools
In [ ]
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from PIL import Image


def get_new_pallete(num_cls):
    """Build a flat [r, g, b, r, g, b, ...] palette with one color per class."""
    n = num_cls
    pallete = [0]*(n*3)
    for j in range(0, n):
        lab = j
        pallete[j*3+0] = 0
        pallete[j*3+1] = 0
        pallete[j*3+2] = 0
        i = 0
        # Spread the bits of the class index over the high bits of r, g and b.
        while (lab > 0):
            pallete[j*3+0] |= (((lab >> 0) & 1) << (7-i))
            pallete[j*3+1] |= (((lab >> 1) & 1) << (7-i))
            pallete[j*3+2] |= (((lab >> 2) & 1) << (7-i))
            i = i + 1
            lab >>= 3
    return pallete


def get_new_mask_pallete(npimg, new_palette, out_label_flag=False, labels=None):
    """Get image color pallete for visualizing masks"""
    # put colormap
    out_img = Image.fromarray(npimg.squeeze().astype('uint8'))
    out_img.putpalette(new_palette)
    # Start with an empty legend so the return value is defined in both cases.
    patches = []
    if out_label_flag:
        assert labels is not None
        u_index = np.unique(npimg)
        for i, index in enumerate(u_index):
            label = labels[index]
            cur_color = [new_palette[index * 3] / 255.0, new_palette[index * 3 + 1] / 255.0, new_palette[index * 3 + 2] / 255.0]
            red_patch = mpatches.Patch(color=cur_color, label=label)
            patches.append(red_patch)
    return out_img, patches
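
The palette function above gives every class index a distinct color by spreading the bits of the index over the high bits of the red, green and blue channels, a scheme that also appears in the commonly used PASCAL VOC color map. The compact sketch below reproduces the same bit arithmetic for a single label so that the resulting colors are easy to inspect. The name `label_color` is hypothetical, and it assumes class indices small enough to fit the eight bits per channel.

```python
def label_color(label):
    """Return the (r, g, b) color for one class index."""
    r = g = b = 0
    shift = 7  # start at the most significant bit of each channel
    while label > 0:
        r |= (label & 1) << shift
        g |= ((label >> 1) & 1) << shift
        b |= ((label >> 2) & 1) << shift
        shift -= 1
        label >>= 3  # consume three bits of the index per round
    return r, g, b


colors = [label_color(i) for i in range(5)]
for index, color in enumerate(colors):
    print(index, color)
```

Class 0 stays black, and the first few classes receive strongly separated colors, which keeps small label sets easy to tell apart in the overlay.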
7.2 Loading the model
In [ ]
import paddle.vision.transforms as transforms
from paddlenlp.transformers.clip.tokenizer import CLIPTokenizer

model = LSeg()
state_dict = paddle.load('data/data169501/LSeg.pdparams')
model.set_state_dict(state_dict)
model.eval()
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5, 0.5, 0.5],
            [0.5, 0.5, 0.5]
        ),
    ]
)
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')
7.3 Model Prediction
In [8]
import cv2
import numpy as np
from PIL import Image

# Path to the input image
img_path = 'images/cat.jpeg'
# Class labels (text prompts) that drive the segmentation
labels = ['plant', 'grass', 'cat', 'stone', 'other']

image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
h, w = image.shape[:2]
# Crop height and width down to multiples of 32.
image = image[:-(h%32) if h%32 else None, :-(w%32) if w%32 else None]
images = transform(image).unsqueeze(0)
image = Image.fromarray(image).convert("RGBA")
texts = tokenizer(labels, padding=True, return_tensors="pd")['input_ids']

with paddle.no_grad():
    results = model.forward(images, texts)
    # Pick the best-matching label for every pixel.
    results = paddle.argmax(results, 1)
    results = results.numpy()

new_palette = get_new_pallete(len(labels))
mask, patches = get_new_mask_pallete(results, new_palette, out_label_flag=True, labels=labels)
seg = mask.convert("RGBA")
out = Image.blend(image, seg, alpha=0.5)

# Original image
plt.axis('off')
plt.imshow(image)
# Image blended with the segmentation mask
plt.figure()
plt.axis('off')
plt.imshow(out)
# Segmentation mask with legend
plt.figure()
plt.legend(handles=patches, loc='upper right', bbox_to_anchor=(1.5, 1), prop={'size': 20})
plt.axis('off')
plt.imshow(seg)
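
The prediction cell crops the image so that its height and width are multiples of 32 before running the model. The article does not state the reason. A plausible reading is that the ViT patch size of 16, combined with the extra stride-2 convolution in the fourth encoder branch, requires an even patch grid. The sketch below isolates that cropping step with hypothetical helper names and plain Python lists, so it runs without an image.

```python
def crop_to_multiple(height, width, multiple=32):
    """Largest (height, width) not exceeding the input that is divisible by `multiple`."""
    return height - height % multiple, width - width % multiple


def crop_rows(rows, multiple=32):
    """Apply the same slicing idiom as the prediction cell to a plain list."""
    remainder = len(rows) % multiple
    # A slice end of None keeps everything when the length already fits.
    return rows[:-remainder if remainder else None]


cropped = crop_to_multiple(500, 375)
print(cropped)
print(len(crop_rows(list(range(500)))), len(crop_rows(list(range(384)))))
```

The conditional in the slice matters: without it, an image whose size is already a multiple of 32 would be sliced with `[:-0]`, which yields an empty result.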