SwinTransformer is a Transformer-based backbone for computer vision, proposed by Microsoft Research Asia in 2021.
Its name is short for Shifted Window Transformer, and its main innovations are as follows.
1. Compute self-attention inside local windows (W-MSA). For a fixed window size, this reduces the cost of self-attention from quadratic to linear in the input size.
2. Shift the window partition between consecutive blocks (Shifted Window MSA, SW-MSA) so that information can flow between neighboring windows.
3. The shift leaves windows of different sizes at the borders. A cyclic shift (a "jigsaw puzzle" rearrangement) combined with an attention mask lets them be computed as a batch of regular windows, which keeps the computation efficient.
4. Add a relative position bias term to the classic QKV attention formula, which injects positional information in a very natural way.
5. Use Patch Merging to downsample the feature map. Like pooling, it halves the spatial resolution, but it is less prone to losing information.
6. Extract and fuse features at different levels. The window size stays fixed in tokens, but as Patch Merging shrinks the feature map, each window covers a progressively larger region of the image.
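Item 1 can be made concrete with the complexity formulas from the Swin paper. Consider an h×w map of tokens with C channels and M×M windows. Global attention costs Ω(MSA) = 4hwC² + 2(hw)²C, and window attention costs Ω(W-MSA) = 4hwC² + 2M²hwC, so the second term drops from quadratic to linear in hw. The sketch below simply evaluates both. The function names are my own, and the stage-1 setting (56×56 tokens, C = 128, M = 7) is an assumption chosen to match the swin_base_patch4_window7_224 model used later.

```python
def msa_flops(h: int, w: int, c: int) -> int:
    """Cost of global multi-head self-attention over h*w tokens (Swin paper formula)."""
    return 4 * h * w * c**2 + 2 * (h * w) ** 2 * c


def wmsa_flops(h: int, w: int, c: int, m: int) -> int:
    """Cost of self-attention restricted to non-overlapping m x m windows (Swin paper formula)."""
    return 4 * h * w * c**2 + 2 * m**2 * h * w * c


# Assumed stage-1 setting: 224x224 image, 4x4 patches -> 56x56 tokens, C=128, window M=7
h = w = 56
c, m = 128, 7
global_cost = msa_flops(h, w, c)
window_cost = wmsa_flops(h, w, c, m)
print(f"MSA:   {global_cost:,}")
print(f"W-MSA: {window_cost:,}")
print(f"ratio: {global_cost / window_cost:.1f}x")
```

At this resolution, W-MSA comes out roughly 11× cheaper. Doubling the side length of the token map makes W-MSA exactly 4× more expensive because it is linear in hw, while MSA grows by far more than 4×.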
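Items 1 to 3 rely on two tensor operations. The first splits the feature map into non-overlapping windows. The second cyclically shifts the map by half a window before splitting it again. Below is a minimal PyTorch sketch on a toy 8×8 map with 4×4 windows. It is written in the spirit of the official implementation but not copied from it, and window_partition and cyclic_shift are illustrative names. The sketch omits the attention mask. In the real model, tokens that share a window only because of the wrap-around are masked so they cannot attend to each other, and the map is rolled back after attention.

```python
import torch


def window_partition(x: torch.Tensor, m: int) -> torch.Tensor:
    """Split a (B, H, W, C) map into non-overlapping m x m windows.

    Returns (B * H//m * W//m, m*m, C): one token sequence per window, so attention
    can run independently, and batched, inside each window.
    """
    b, h, w, c = x.shape
    x = x.view(b, h // m, m, w // m, m, c)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()  # (B, nH, nW, m, m, C)
    return x.view(-1, m * m, c)


def cyclic_shift(x: torch.Tensor, m: int) -> torch.Tensor:
    """Roll the map by m//2 tokens toward the top-left, wrapping around the borders."""
    s = m // 2
    return torch.roll(x, shifts=(-s, -s), dims=(1, 2))


# Toy map: one 8x8 map with a single channel holding each token's flat index
x = torch.arange(64).view(1, 8, 8, 1)
regular = window_partition(x, m=4)                     # windows used by W-MSA
shifted = window_partition(cyclic_shift(x, m=4), m=4)  # windows used by SW-MSA
print(regular.shape)
print(regular[0, :, 0].tolist())
print(shifted[0, :, 0].tolist())
```

The first shifted window contains tokens 18, 20, 34 and 36. These tokens belonged to four different regular windows, so attention in the shifted block mixes information across the old window borders.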
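For item 4, the paper writes window attention as Attention(Q, K, V) = SoftMax(QK^T / sqrt(d) + B) V. Here the M²×M² bias matrix B is gathered, per head, from a learnable table of (2M-1)² entries, indexed by the relative offset between the two tokens. The sketch below builds that index and runs a single-head attention with it. It assumes M = 7 and a zero-initialised table, and the helper names are hypothetical.

```python
import torch


def relative_position_index(m: int) -> torch.Tensor:
    """(m*m, m*m) index into a (2m-1)*(2m-1) bias table, one entry per token pair."""
    coords = torch.stack(torch.meshgrid(torch.arange(m), torch.arange(m), indexing="ij"))  # (2, m, m)
    coords = coords.flatten(1)                                    # (2, m*m) row/col of each token
    rel = (coords[:, :, None] - coords[:, None, :]).permute(1, 2, 0).contiguous()  # offsets in [-(m-1), m-1]
    rel = rel + (m - 1)                                           # shift offsets to [0, 2m-2]
    return rel[..., 0] * (2 * m - 1) + rel[..., 1]                # flatten (dy, dx) into one index


def window_attention_with_bias(q, k, v, bias_table, index):
    """softmax(q @ k^T / sqrt(d) + B) @ v for one head; q, k, v have shape (windows, m*m, d)."""
    d = q.shape[-1]
    n = index.shape[0]
    bias = bias_table[index.reshape(-1)].view(n, n)  # B: same for every window
    attn = q @ k.transpose(-2, -1) / d ** 0.5 + bias
    return attn.softmax(dim=-1) @ v


m, d = 7, 32
index = relative_position_index(m)
bias_table = torch.zeros((2 * m - 1) ** 2)  # an nn.Parameter in a real model; zeros here
q = k = v = torch.randn(4, m * m, d)        # 4 windows, one head
out = window_attention_with_bias(q, k, v, bias_table, index)
print(index.shape, out.shape)
```

Every token pair with the same (dy, dx) offset shares one table entry, so the bias depends only on relative position. The diagonal of the index points at the middle entry of the 13×13 table, which is the zero-offset bias.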
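Item 5, Patch Merging, concatenates the features of each 2×2 group of neighbouring tokens (C to 4C) and then applies a linear layer to reduce them to 2C. A 2×2 max pooling keeps only one value per channel out of four. Here nothing is thrown away before the learned projection. The sketch below follows the paper's description. It is close to, but not guaranteed identical to, timm's implementation; for example, the order of the four slices is an implementation detail.

```python
import torch
from torch import nn


class PatchMerging(nn.Module):
    """Downsample (B, H, W, C) -> (B, H/2, W/2, 2C) by concatenating each 2x2 neighbourhood."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x0 = x[:, 0::2, 0::2, :]  # top-left token of every 2x2 block
        x1 = x[:, 1::2, 0::2, :]  # bottom-left
        x2 = x[:, 0::2, 1::2, :]  # top-right
        x3 = x[:, 1::2, 1::2, :]  # bottom-right
        x = torch.cat([x0, x1, x2, x3], dim=-1)  # (B, H/2, W/2, 4C): all values kept
        return self.reduction(self.norm(x))      # learned projection 4C -> 2C


x = torch.randn(2, 56, 56, 128)
y = PatchMerging(128)(x)
print(y.shape)
```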
Although SwinTransformer is built from Transformer blocks, its overall design borrows many features of convolutional networks:
locality, translation invariance, feature maps that shrink while the number of channels grows stage by stage, multi-scale feature fusion, and so on.
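The CNN-like schedule of shrinking feature maps and growing channels is easy to tabulate. The sketch below assumes the Swin-B configuration used in this article: a 224×224 input, 4×4 patches, a base channel count of C = 128 and a window size of 7. Each stage halves the resolution and doubles the channels. Because the window stays at 7×7 tokens, one window covers a larger image region at every later stage.

```python
def swin_stage_shapes(img_size: int = 224, patch_size: int = 4, embed_dim: int = 128,
                      num_stages: int = 4) -> list:
    """(height, width, channels) of each stage's output feature map."""
    side = img_size // patch_size
    shapes = []
    for stage in range(num_stages):
        shapes.append((side, side, embed_dim * 2 ** stage))
        side //= 2  # Patch Merging halves the resolution before the next stage
    return shapes


WINDOW = 7  # window size in tokens, fixed across stages
window_spans = []
for stage, (h, w, c) in enumerate(swin_stage_shapes(), start=1):
    window_px = WINDOW * (224 // h)  # image pixels spanned by one side of a window
    window_spans.append(window_px)
    print(f"stage {stage}: {h}x{w}x{c}, one window spans {window_px}x{window_px} pixels")
```

In the last stage, the 7×7 window covers the whole 7×7 map, and therefore the whole image.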
It also uses a number of tricks to compensate for weaknesses of the Transformer, such as computational cost and the weak representation of positional information.
One uploader on Bilibili described SwinTransformer as "a CNN in Transformer skin". But its core computation is still Transformer self-attention, so I think it is better described as a Transformer with a convolutional buff stacked on top.
The SwinTransformer backbone is highly expressive and widely applicable. It can be used for image classification, segmentation, detection and other tasks. Its architecture design and experiments are also very solid, and it won the ICCV 2021 best paper award.
In the following example, we fine-tune the SwinTransformer model from the timm library on a cat vs. dog image classification task.
To get this article's notebook source code and the dataset download link, send the keyword torchkeras to the WeChat public account Algorithm Gourmet House.
#!pip install -U timm torchkeras torchmetrics
0. Pretrained model
from urllib.request import urlopen
from PIL import Image
import timm
import torch
img = Image.open(urlopen(
'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))
img
model = timm.create_model("swin_base_patch4_window7_224.ms_in22k_ft_in1k", pretrained=True)
model = model.eval()
# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
output = model(transforms(img).unsqueeze(0)) # unsqueeze single image into batch of 1
top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1), k=5)
info = timm.data.ImageNetInfo()
# map each top-5 ImageNet class index to its human-readable label via the public ImageNetInfo API
{info.index_to_description(i): v for i, v in zip(top5_class_indices.tolist()[0],
                                                 top5_probabilities.tolist()[0])}
{'espresso': 0.1655443161725998,
'cup': 0.12100766599178314,
'chocolate sauce, chocolate syrup': 0.11809349805116653,
'eggnog': 0.06144588068127632,
'tray': 0.03965265676379204}
The top predictions are mainly espresso, cup and similar classes, which match the picture well, so the pretrained model works as expected.
1. Prepare data
import torch
import os
data_path = './datasets/cats_vs_dogs'
train_cats = os.listdir(os.path.join(data_path, "train", "cats"))
img = Image.open(os.path.join(data_path, "train", "cats", train_cats[0]))
img
train_dogs = os.listdir(os.path.join(data_path, "train", "dogs"))
img = Image.open(os.path.join(data_path, "train", "dogs", train_dogs[0]))
img
from torchvision.datasets import ImageFolder
ds_train = ImageFolder(os.path.join(data_path,"train"),transforms)
ds_val = ImageFolder(os.path.join(data_path,"val"),transforms)
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=4, shuffle=True)
dl_val = torch.utils.data.DataLoader(ds_val, batch_size=2, shuffle=True)
class_names = ds_train.classes
print(len(ds_train))
print(len(ds_val))
2000
995
for batch in dl_val:
break
batch[1]
tensor([0, 1])
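batch[1] holds the integer class labels. ImageFolder derives them from the sub-folder names sorted alphabetically. With train/cats and train/dogs (and the same under val), label 0 therefore means cats and 1 means dogs, matching ds_train.classes. The stdlib sketch below reproduces that rule on a throwaway directory. folder_classes is a hypothetical helper, not torchvision's own code.

```python
import os
import tempfile


def folder_classes(root: str):
    """Mimic ImageFolder's labelling: sorted sub-folder names -> 0, 1, 2, ..."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return classes, {name: idx for idx, name in enumerate(classes)}


with tempfile.TemporaryDirectory() as root:
    for name in ("dogs", "cats"):  # creation order does not matter, only sorted names do
        os.makedirs(os.path.join(root, name))
    classes, class_to_idx = folder_classes(root)

print(classes, class_to_idx)
```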
2. Define the model
model.reset_classifier(num_classes=2)
model(batch[0])
tensor([[ 0.1698, -0.3366],
[ 0.4805, 0.1415]], grad_fn=<AddmmBackward0>)
model.cuda();
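reset_classifier(num_classes=2) discards the 1000-class ImageNet head and attaches a new, randomly initialised 2-class layer while keeping the pretrained backbone weights, so the logits above carry no meaning yet. The plain-PyTorch sketch below mimics that idea with a hypothetical ToyNet. timm's real method handles more (pooling options, head variants), so treat this only as an illustration. The 1024-dim feature size is assumed to match Swin-B's final stage.

```python
import torch
from torch import nn


class ToyNet(nn.Module):
    """Hypothetical stand-in for a pretrained backbone followed by a linear head."""

    def __init__(self, feat_dim: int = 1024, num_classes: int = 1000):
        super().__init__()
        self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, feat_dim), nn.GELU())
        self.head = nn.Linear(feat_dim, num_classes)

    def reset_head(self, num_classes: int) -> None:
        # Replace only the head with a fresh layer; backbone weights stay untouched
        self.head = nn.Linear(self.head.in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.backbone(x))


net = ToyNet()
backbone_before = net.backbone[1].weight.clone()
net.reset_head(num_classes=2)
logits = net(torch.randn(2, 3, 8, 8))
print(logits.shape)
```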
3. Train the model
from torchkeras import KerasModel
from torchmetrics import Accuracy
loss_fn = torch.nn.CrossEntropyLoss()
metrics_dict = {"acc":Accuracy(task='multiclass',num_classes=2)}
optimizer = torch.optim.Adam(model.parameters(),
lr=1e-5)
keras_model = KerasModel(model,
loss_fn = loss_fn,
metrics_dict= metrics_dict,
optimizer = optimizer
)
features,labels = batch
loss_fn(model(features.cuda()),labels.cuda())
tensor(0.6743, device='cuda:0', grad_fn=<NllLossBackward0>)
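Before training, it is worth recomputing this loss by hand. Use the logits printed after reset_classifier and the labels tensor([0, 1]) from the validation batch. Cross-entropy is the batch mean of logsumexp(z) - z[label]. With a freshly initialised 2-class head, it should sit near ln 2 ≈ 0.693, the loss of a 50/50 guess. The numbers below are copied from the outputs above, which are rounded to four decimals.

```python
import math


def cross_entropy(logits, label):
    """-log(softmax(logits)[label]), computed as logsumexp(logits) - logits[label]."""
    return math.log(sum(math.exp(z) for z in logits)) - logits[label]


# Values copied from the printed outputs above (4 decimals)
logits = [[0.1698, -0.3366],
          [0.4805, 0.1415]]
labels = [0, 1]
loss = sum(cross_entropy(z, y) for z, y in zip(logits, labels)) / len(labels)
print(f"recomputed loss: {loss:.4f}, 50/50 baseline ln 2: {math.log(2):.4f}")
```

The result is about 0.6743, the same value loss_fn reported, and close to ln 2 as expected.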
dfhistory= keras_model.fit(train_data=dl_train,
val_data=dl_val,
epochs=100,
ckpt_path='checkpoint.pt',
patience=10,
monitor="val_acc",
mode="max",
mixed_precision='no',
plot = True,
quiet=True
)
The fitting ability of SwinTransformer is very strong. On this simple dataset, just two epochs of fine-tuning push the training accuracy to 100%, and the final validation accuracy reaches 99.8%, which is very impressive.
4. Evaluate the model
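The fit call above uses early stopping. With monitor="val_acc", mode="max" and patience=10, training continues only while validation accuracy keeps improving, so it can end long before epochs=100. The sketch below is a generic version of that rule in plain Python. It is an assumption about how torchkeras behaves, not its actual code, and the exact off-by-one conventions may differ.

```python
def early_stop(val_scores, patience, mode="max"):
    """Return (best_epoch, last_epoch_run), both 1-based, for a patience-based stopping rule."""
    best, best_epoch = None, 0
    for epoch, score in enumerate(val_scores, start=1):
        if best is None or (score > best if mode == "max" else score < best):
            best, best_epoch = score, epoch  # new best: a checkpoint would be saved here
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch         # `patience` epochs without improvement: stop
    return best_epoch, len(val_scores)       # ran out of epochs before the rule triggered


# Made-up validation accuracies, for illustration only
history = [0.990, 0.998, 0.997, 0.998, 0.996]
print(early_stop(history, patience=2))
```

A tie with the best score does not count as an improvement in this sketch, so the epoch-4 value of 0.998 does not reset the counter.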
keras_model.evaluate(dl_val)
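To double-check the validation accuracy without torchkeras, a plain evaluation loop is enough. The sketch below is minimal: evaluate_accuracy is a hypothetical helper, and the toy demo at the end uses nn.Identity and hand-made batches in place of the real model and dl_val.

```python
import torch
from torch import nn


@torch.no_grad()
def evaluate_accuracy(model: nn.Module, loader, device: str = "cpu") -> float:
    """Fraction of correctly classified samples over an iterable of (features, labels) batches."""
    model.eval()
    correct = total = 0
    for features, labels in loader:
        logits = model(features.to(device))
        correct += (logits.argmax(dim=1).cpu() == labels).sum().item()
        total += labels.numel()
    return correct / total


# Toy demo: identity "model", so each row of the identity matrix predicts its own index
toy_loader = [(torch.eye(2), torch.tensor([0, 1])),
              (torch.eye(2), torch.tensor([1, 1]))]
acc = evaluate_accuracy(nn.Identity(), toy_loader)
print(acc)
```

With the real model, the call would be evaluate_accuracy(model, dl_val, device="cuda").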
5. Use the model
from PIL import Image
img = Image.open('./datasets/cats_vs_dogs/val/dogs/dog.2005.jpg')
model.eval();
model(transforms(img)[None, ...].cuda()).softmax(dim=1)
tensor([[1.1537e-04, 9.9988e-01]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
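The two probabilities follow the label order of ds_train.classes, so index 1 is dogs, and the model puts about 99.99% of the probability on dogs. A small helper can turn a batch of probabilities into readable predictions. to_labels is a hypothetical name, and the probabilities below are copied from the output above.

```python
import torch


def to_labels(probs: torch.Tensor, class_names):
    """Map a (batch, num_classes) probability tensor to (class name, probability) pairs."""
    top_prob, top_idx = probs.max(dim=1)
    return [(class_names[i], p) for i, p in zip(top_idx.tolist(), top_prob.tolist())]


class_names = ["cats", "dogs"]                     # ds_train.classes from section 1
probs = torch.tensor([[1.1537e-04, 9.9988e-01]])   # softmax output printed above
preds = to_labels(probs, class_names)
print(preds)
```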
6. Save the model
torch.save(model.state_dict(),'swin_transformer.pt')
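A state_dict stores only the weights, so loading it requires first rebuilding the same architecture, including the 2-class head. For the real model, that means something like timm.create_model("swin_base_patch4_window7_224.ms_in22k_ft_in1k", pretrained=False, num_classes=2) followed by load_state_dict. The runnable sketch below demonstrates the round trip on a tiny hypothetical stand-in model. map_location="cpu" lets a checkpoint saved from the GPU be loaded on a CPU-only machine.

```python
import os
import tempfile

import torch
from torch import nn


def make_model() -> nn.Module:
    # Hypothetical stand-in; for the real model, rebuild the Swin architecture with num_classes=2
    return nn.Sequential(nn.Flatten(), nn.Linear(12, 2))


trained = make_model()
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.pt")
    torch.save(trained.state_dict(), path)        # save weights only
    restored = make_model()                       # same architecture, fresh random weights
    restored.load_state_dict(torch.load(path, map_location="cpu"))

restored.eval()
x = torch.randn(3, 3, 2, 2)
same = torch.equal(trained(x), restored(x))
print(same)
```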
For more interesting examples, send the keyword torchkeras to the WeChat public account Algorithm Gourmet House to get the example source code in the torchkeras repository.