Introduction and Inspiration
Since ViT, research on vision transformers has grown explosively. Broadly, it follows two directions: one improves ViT's performance on image classification; the other applies ViT to other vision tasks, such as segmentation and detection. The PVT (Pyramid Vision Transformer) introduced here belongs to the latter. Compared with ViT, PVT introduces a pyramid structure similar to that of a CNN, so that PVT can serve as a backbone for dense prediction tasks (segmentation, detection, etc.) just as a CNN does.
Design ideas
The idea behind PVT is this: in a CNN, multi-scale features are typically obtained with an FPN (Feature Pyramid Network), so can a Transformer do the same? The PVT (Pyramid Vision Transformer) is proposed as the answer. It can also be easily combined with DETR-like models to achieve end-to-end detection.
In fact, the idea of PVT is very simple: combine the Transformer with an FPN-style pyramid, and shrink the feature map through convolution to reduce computation.
From the figure above, its main innovations are:
- Compared with ViT, it uses fine-grained image patches (4×4 pixels in stage 1) and can learn high-resolution feature representations
- Drawing on CNNs, a pyramid structure is designed to learn multi-scale features
- The SRA (spatial-reduction attention) module is introduced in place of standard multi-head attention; it downsamples K and V spatially to reduce the cost of the attention computation
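To make the third point concrete, here is a minimal sketch that counts how many query-key scores one attention head computes with and without spatial reduction. The function name attention_scores is hypothetical; the numbers (a 56×56 token grid and a reduction ratio of 8) are the stage 1 values of pvt_small used later in this post.

```python
def attention_scores(h, w, sr_ratio=1):
    """Number of query-key scores for one head on an h x w token grid.

    Queries keep every token; keys/values are downsampled by sr_ratio
    along each spatial axis (sr_ratio=1 means standard attention).
    """
    num_queries = h * w
    num_keys = (h // sr_ratio) * (w // sr_ratio)
    return num_queries * num_keys


full = attention_scores(56, 56)                 # standard attention
reduced = attention_scores(56, 56, sr_ratio=8)  # spatial-reduction attention
print(full, reduced, full // reduced)           # 9834496 153664 64
```

With a reduction ratio of R, the score matrix shrinks by a factor of R², here 64.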
Model process
The model is initialized with the following parameters, which are explained later:
def pvt_small():
    model = PyramidVisionTransformer(
        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1])
    return model
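One quick way to read this configuration is to derive the per-stage quantities it implies. The sketch below uses only the numbers passed to pvt_small; the variable names head_dims and mlp_hidden are illustrative, and the hidden-size rule int(dim * mlp_ratio) is the one used by Block in the full listing.

```python
embed_dims = [64, 128, 320, 512]
num_heads = [1, 2, 5, 8]
mlp_ratios = [8, 8, 4, 4]

head_dims = []
mlp_hidden = []
for dim, heads, ratio in zip(embed_dims, num_heads, mlp_ratios):
    # Attention asserts that dim is divisible by num_heads.
    assert dim % heads == 0
    head_dims.append(dim // heads)
    mlp_hidden.append(int(dim * ratio))

print(head_dims)   # [64, 64, 64, 64]
print(mlp_hidden)  # [512, 1024, 1280, 2048]
```

So the number of heads grows with the embedding width in a way that keeps every head at 64 channels.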
First, the input image is torch.Size([2, 3, 224, 224]), that is, batch size = 2, channels = 3, H = W = 224. It is then sent to stage 1:
x, (H, W) = self.patch_embed1(x)
pos_embed1 = self._get_pos_embed(self.pos_embed1, self.patch_embed1, H, W)
x = x + pos_embed1
x = self.pos_drop1(x)
for blk in self.block1:
    x = blk(x, H, W)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
Patch operation
First, the patch operation is performed on the input image: it splits the image into patches and linearly maps each patch to an embedding (dimension conversion).
In stage 1, PatchEmbed is defined as:
PatchEmbed(
  (proj): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
  (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
)
def forward(self, x):
    B, C, H, W = x.shape
    x = self.proj(x).flatten(2).transpose(1, 2)
    x = self.norm(x)
    H, W = H // self.patch_size[0], W // self.patch_size[1]
    return x, (H, W)
As shown, the tensor sent to the patch_embed module first has its shape read as (B, C, H, W), and then self.proj is applied:
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
This is a two-dimensional convolution: Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
Convolution layer output size: o = ⌊(i + 2p - k) / s⌋ + 1
The padding defaults to 0, so o = ⌊(224 - 4) / 4⌋ + 1 = 56 and the output size is [2, 64, 56, 56].
The flatten operation then gives [2, 64, 56×56],
and transpose swaps the last two dimensions to give [2, 56×56, 64], which is torch.Size([2, 3136, 64]).
Finally, LayerNorm is applied and the new H and W are computed and returned; at this point H = W = 56.
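The shape changes described above can be checked with a small standalone sketch. It rebuilds only the two layers of the stage 1 patch embedding with freshly initialized weights, so only the shapes are meaningful; the variable names feat and tokens are illustrative.

```python
import torch
import torch.nn as nn

# Stage 1 patch embedding: a 4x4 convolution with stride 4, then LayerNorm.
proj = nn.Conv2d(3, 64, kernel_size=4, stride=4)
norm = nn.LayerNorm(64)

x = torch.randn(2, 3, 224, 224)
feat = proj(x)                            # [2, 64, 56, 56]
tokens = feat.flatten(2).transpose(1, 2)  # [2, 3136, 64]
tokens = norm(tokens)                     # shape unchanged

print(feat.shape, tokens.shape)
```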
Position encoding
PVT uses learnable position embeddings.
pos_embed is initialized as follows; its size at this point is torch.Size([1, 3136, 64]):
self.pos_embed1 = nn.Parameter(torch.zeros(1, self.patch_embed1.num_patches, embed_dims[0]))
It is then processed by:
pos_embed1 = self._get_pos_embed(self.pos_embed1, self.patch_embed1, H, W)
_get_pos_embed is defined below. When H * W equals the number of patches the model was built for, the embedding is returned unchanged; otherwise it is bilinearly interpolated to the new (H, W) grid:
def _get_pos_embed(self, pos_embed, patch_embed, H, W):
    if H * W == self.patch_embed1.num_patches:
        return pos_embed
    else:
        return F.interpolate(
            pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
            size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)
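As an illustration of the interpolation branch, the sketch below resizes a stage 1 position embedding to a different token grid. The 64×48 target grid (what a 256×192 input would give with patch size 4) is a made-up example, and the embedding is random rather than learned.

```python
import torch
import torch.nn.functional as F

# Stands in for an embedding learned on a 56x56 grid with 64 channels.
pos_embed = torch.randn(1, 56 * 56, 64)

# Hypothetical new token grid, e.g. from a 256x192 input with patch size 4.
H, W = 64, 48

grid = pos_embed.reshape(1, 56, 56, -1).permute(0, 3, 1, 2)  # [1, 64, 56, 56]
resized = F.interpolate(grid, size=(H, W), mode="bilinear", align_corners=False)
resized = resized.reshape(1, -1, H * W).permute(0, 2, 1)     # [1, H*W, 64]

print(resized.shape)
```

The embedding is treated as a 64-channel image, resized spatially, and flattened back into one vector per token.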
The patch features and the position encoding are then added directly. Note that pos_embed1 is torch.Size([1, 3136, 64]) while x is torch.Size([2, 3136, 64]); they can still be added, because the size-1 batch dimension is broadcast:
x = x + pos_embed1
The following test shows the broadcast mechanism expanding that dimension for the addition:
import torch
data = torch.randn((2, 1, 2, 2))
data1 = torch.randn((1, 1, 2, 2))
print(data)
print(data1)
print(data+data1)
Attention calculation
Next comes the attention module; each stage contains multiple attention layers:
for blk in self.block1:
    x = blk(x, H, W)
The attention calculation then begins.
Q construction: a linear layer produces q, which is then split across heads (multi-head attention):
self.q = nn.Linear(dim, dim, bias=qkv_bias)
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
q is then torch.Size([2, 1, 3136, 64]).
Next comes the key innovation: K and V are downsampled to reduce the number of tokens, with self.sr_ratio setting the reduction factor.
x is torch.Size([2, 3136, 64]); after permute it becomes torch.Size([2, 64, 3136]), and after reshape it becomes torch.Size([2, 64, 56, 56]):
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
The spatial size is then reduced by a convolution:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
Here self.sr is Conv2d(64, 64, kernel_size=(8, 8), stride=(8, 8)).
From the convolution output size formula o = ⌊(i + 2p - k) / s⌋ + 1, each side becomes ⌊(56 - 8) / 8⌋ + 1 = 7, i.e. 7 × 7 = 49 tokens. The corresponding tensor is torch.Size([2, 49, 64]).
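The output size formula is easy to check by hand. Here is a small helper (conv_out_size is a hypothetical name, not part of the PVT code) applied to the two convolutions seen so far in stage 1.

```python
def conv_out_size(i, k, s, p=0):
    """Output length of a convolution along one axis: floor((i + 2p - k) / s) + 1."""
    return (i + 2 * p - k) // s + 1


side_after_patch_embed = conv_out_size(224, k=4, s=4)            # stage 1 patch embedding
side_after_sr = conv_out_size(side_after_patch_embed, k=8, s=8)  # stage 1 SR convolution
print(side_after_patch_embed, side_after_sr, side_after_sr ** 2)  # 56 7 49
```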
x_ is then used to compute kv:
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
The resulting kv is torch.Size([2, 2, 1, 49, 64]); k and v are each torch.Size([2, 1, 49, 64]), holding different values.
The complete code for this step is as follows:
if self.sr_ratio > 1:
    x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
    x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
    x_ = self.norm(x_)
    kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
else:
    kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
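Attention is then computed: the scaled dot product of q and k is passed through softmax, and the result is used to weight v:
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
The @ operator is matrix multiplication: x @ y is equivalent to x.matmul(y).
The final output x of the attention mechanism is still torch.Size([2, 3136, 64]).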
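The shapes quoted in this section can be checked with a standalone sketch. It repeats the steps of the attention forward pass with freshly initialized layers (so the values are meaningless and only the shapes matter), and it leaves out dropout and the output projection. The names q_proj, kv_proj and out are illustrative.

```python
import torch
import torch.nn as nn

B, H, W, C, num_heads, sr_ratio = 2, 56, 56, 64, 1, 8
N = H * W
head_dim = C // num_heads

q_proj = nn.Linear(C, C)
kv_proj = nn.Linear(C, 2 * C)
sr = nn.Conv2d(C, C, kernel_size=sr_ratio, stride=sr_ratio)
norm = nn.LayerNorm(C)

x = torch.randn(B, N, C)
q = q_proj(x).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)  # [2, 1, 3136, 64]

# Spatial reduction: tokens -> feature map -> strided conv -> tokens.
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = norm(sr(x_).reshape(B, C, -1).permute(0, 2, 1))                   # [2, 49, 64]
kv = kv_proj(x_).reshape(B, -1, 2, num_heads, head_dim).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]                                                    # each [2, 1, 49, 64]

attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5                    # [2, 1, 3136, 49]
attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, N, C)                      # [2, 3136, 64]

print(q.shape, k.shape, attn.shape, out.shape)
```

Each of the 3136 queries attends over only 49 keys, yet the output keeps one vector per query token.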
The complete code of the attention mechanism module is as follows:
def forward(self, x, H, W):
    B, N, C = x.shape
    q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
    if self.sr_ratio > 1:
        x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
        x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
        x_ = self.norm(x_)
        kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    else:
        kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    k, v = kv[0], kv[1]
    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)
    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x
Finally, x is reshaped from a token sequence back into a 2D feature map of shape [B, C, H, W], which then enters the next stage:
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
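This reshape is the exact inverse of the flatten and transpose done in the patch embedding. A minimal round-trip check, using a random tensor in place of real features:

```python
import torch

B, C, H, W = 2, 64, 56, 56
feat = torch.randn(B, C, H, W)

tokens = feat.flatten(2).transpose(1, 2)                                 # [B, H*W, C]
restored = tokens.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()  # [B, C, H, W]

print(tokens.shape, restored.shape, torch.equal(feat, restored))
```

The call to contiguous() only changes the memory layout after the permute, not the values.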
The feature map is then downsampled by the next stage's patch-embedding convolution, and the loss of spatial resolution is compensated by a larger channel dimension. In the next stage x is:
torch.Size([2, 784, 128])
Repeating this process stage by stage shrinks the feature map and yields multi-scale features along the way. Note that only stage 1 uses patch_size=4; the following three stages all use patch_size=2, so, as in a CNN, each stage halves the resolution of the previous one.
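The whole pyramid can be traced with plain arithmetic. The sketch below assumes a 224×224 input and the pvt_small settings; trace is an illustrative name, and each entry holds (query tokens, channels, key/value tokens after spatial reduction).

```python
img_size = 224
patch_sizes = [4, 2, 2, 2]          # stage 1 uses 4, later stages use 2
embed_dims = [64, 128, 320, 512]
sr_ratios = [8, 4, 2, 1]

trace = []
side = img_size
for p, dim, sr in zip(patch_sizes, embed_dims, sr_ratios):
    side //= p                       # patch embedding: kernel = stride = p
    tokens = side * side             # query tokens in this stage
    kv_tokens = (side // sr) ** 2    # key/value tokens after spatial reduction
    trace.append((tokens, dim, kv_tokens))

for stage, (tokens, dim, kv_tokens) in enumerate(trace, start=1):
    print(f"stage {stage}: tokens={tokens}, channels={dim}, kv tokens={kv_tokens}")
```

With these settings every stage attends over 49 key/value tokens, which is why sr_ratios shrinks as the feature map does.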
Overall, the design of the Pyramid Vision Transformer is relatively easy to understand, and the walkthrough above follows the code step by step to explain how PVT is built.
The complete code is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.sr_ratio = sr_ratio
        # Spatial reduction: a strided convolution that shrinks the K/V token grid
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)
    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
            f"img_size {img_size} should be divided by patch_size {patch_size}."
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)
    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        H, W = H // self.patch_size[0], W // self.patch_size[1]
        return x, (H, W)
class PyramidVisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], F4=False):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.F4 = F4
        # patch_embed
        self.patch_embed1 = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
                                       embed_dim=embed_dims[0])
        self.patch_embed2 = PatchEmbed(img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0],
                                       embed_dim=embed_dims[1])
        self.patch_embed3 = PatchEmbed(img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1],
                                       embed_dim=embed_dims[2])
        self.patch_embed4 = PatchEmbed(img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2],
                                       embed_dim=embed_dims[3])
        # pos_embed
        self.pos_embed1 = nn.Parameter(torch.zeros(1, self.patch_embed1.num_patches, embed_dims[0]))
        self.pos_drop1 = nn.Dropout(p=drop_rate)
        self.pos_embed2 = nn.Parameter(torch.zeros(1, self.patch_embed2.num_patches, embed_dims[1]))
        self.pos_drop2 = nn.Dropout(p=drop_rate)
        self.pos_embed3 = nn.Parameter(torch.zeros(1, self.patch_embed3.num_patches, embed_dims[2]))
        self.pos_drop3 = nn.Dropout(p=drop_rate)
        self.pos_embed4 = nn.Parameter(torch.zeros(1, self.patch_embed4.num_patches + 1, embed_dims[3]))
        self.pos_drop4 = nn.Dropout(p=drop_rate)
        # transformer encoder
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0
        self.block1 = nn.ModuleList([Block(
            dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])
            for i in range(depths[0])])
        cur += depths[0]
        self.block2 = nn.ModuleList([Block(
            dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[1])
            for i in range(depths[1])])
        cur += depths[1]
        self.block3 = nn.ModuleList([Block(
            dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[2])
            for i in range(depths[2])])
        cur += depths[2]
        self.block4 = nn.ModuleList([Block(
            dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[3])
            for i in range(depths[3])])
        # init weights
        trunc_normal_(self.pos_embed1, std=.02)
        trunc_normal_(self.pos_embed2, std=.02)
        trunc_normal_(self.pos_embed3, std=.02)
        trunc_normal_(self.pos_embed4, std=.02)
        self.apply(self._init_weights)
    def init_weights(self, pretrained=True):
        # Checkpoint loading is commented out in this standalone version, so this is a no-op.
        # if isinstance(pretrained, str):
        #     logger = get_root_logger()
        #     load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
        pass
    def reset_drop_path(self, drop_path_rate):
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        for i in range(self.depths[0]):
            self.block1[i].drop_path.drop_prob = dpr[cur + i]
        cur += self.depths[0]
        for i in range(self.depths[1]):
            self.block2[i].drop_path.drop_prob = dpr[cur + i]
        cur += self.depths[1]
        for i in range(self.depths[2]):
            self.block3[i].drop_path.drop_prob = dpr[cur + i]
        cur += self.depths[2]
        for i in range(self.depths[3]):
            self.block4[i].drop_path.drop_prob = dpr[cur + i]
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        if H * W == self.patch_embed1.num_patches:
            return pos_embed
        else:
            return F.interpolate(
                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
                size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)
    def forward_features(self, x):
        outs = []
        B = x.shape[0]
        # stage 1
        x, (H, W) = self.patch_embed1(x)
        pos_embed1 = self._get_pos_embed(self.pos_embed1, self.patch_embed1, H, W)
        x = x + pos_embed1
        x = self.pos_drop1(x)
        for blk in self.block1:
            x = blk(x, H, W)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)
        # stage 2
        x, (H, W) = self.patch_embed2(x)
        pos_embed2 = self._get_pos_embed(self.pos_embed2, self.patch_embed2, H, W)
        x = x + pos_embed2
        x = self.pos_drop2(x)
        for blk in self.block2:
            x = blk(x, H, W)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)
        # stage 3
        x, (H, W) = self.patch_embed3(x)
        pos_embed3 = self._get_pos_embed(self.pos_embed3, self.patch_embed3, H, W)
        x = x + pos_embed3
        x = self.pos_drop3(x)
        for blk in self.block3:
            x = blk(x, H, W)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)
        # stage 4
        x, (H, W) = self.patch_embed4(x)
        pos_embed4 = self._get_pos_embed(self.pos_embed4[:, 1:], self.patch_embed4, H, W)
        x = x + pos_embed4
        x = self.pos_drop4(x)
        for blk in self.block4:
            x = blk(x, H, W)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)
        return outs
    def forward(self, x):
        x = self.forward_features(x)
        if self.F4:
            x = x[3:4]
        return x
def _conv_filter(state_dict, patch_size=16):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v
    return out_dict
def pvt_small():
    model = PyramidVisionTransformer(
        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1])
    return model
model = pvt_small()
data = torch.randn((2, 3, 224, 224))
feature = model(data)
print(model)
for out in feature:
    print(out.shape)