Image classification using SwinTransformer


SwinTransformer is a Transformer-based backbone for computer vision, proposed by Microsoft Research Asia in 2021.


The name is short for Shifted Window Transformer, and its main innovations are as follows.

  • 1. Compute self-attention within local windows (W-MSA), which cuts the cost of self-attention from quadratic in the input size down to linear.

  • 2. Use shifted windows (SW-MSA), i.e. offsetting the window partition between layers, so that information flows across neighboring windows.

  • 3. After the window shift, use a jigsaw-style cyclic shift together with attention masks, so that the unequal-sized windows produced by the shift can still be computed as one batch of equal-sized windows, improving efficiency (a short sketch follows this list).

  • 4. Add a Relative Position Bias term to the classic QKV attention formula, a very natural way to express positional information (see the formula after this list).

  • 5. Use Patch Merging to downsample the feature map; it plays a role similar to pooling but loses less information.

  • 6. Use windows at different scales to extract features at different levels and fuse them.
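For item 4, the windowed attention in the paper adds the bias directly inside the softmax, where M is the window side length, so each window holds M² tokens:

$$\mathrm{Attention}(Q, K, V) = \mathrm{SoftMax}\!\left(\frac{QK^{\top}}{\sqrt{d}} + B\right)V, \qquad Q, K, V \in \mathbb{R}^{M^2 \times d}$$

Here B is read from a smaller learned table $\hat{B} \in \mathbb{R}^{(2M-1)\times(2M-1)}$, since the relative offset along each axis only ranges over $[-M+1, M-1]$. For items 2 and 3, the shift itself is just a cyclic roll of the feature map; a minimal sketch of the idea (the shapes are illustrative, not the paper's exact code):

import torch

# a 56x56 feature map with 128 channels, in (batch, H, W, C) layout
x = torch.randn(1, 56, 56, 128)

# SW-MSA offsets the window grid by M//2 = 3 pixels via a cyclic roll;
# attention masks then stop pixels that wrapped around the border from
# attending across the seam
shifted = torch.roll(x, shifts=(-3, -3), dims=(1, 2))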


Although SwinTransformer is implemented as a Transformer, its overall design borrows many ideas from convolutional networks:

locality, translation invariance, feature maps that gradually shrink while channel counts gradually grow, multi-scale feature fusion, and so on.

At the same time, it applies many tricks to compensate for the weaknesses of Transformers, such as efficiency problems and weak expression of positional information.

A content creator on Bilibili quipped that SwinTransformer is a CNN wearing Transformer skin. But since its core computation is still a Transformer, I feel it is better described as a Transformer with a convolution buff stacked on top.

The SwinTransformer backbone has very strong expressive power and wide applicability: it can be applied to image classification, segmentation, detection, and other tasks. Its structural design and experimental work are quite solid, and it was awarded the Best Paper of ICCV 2021.

In the following example, we fine-tune a SwinTransformer model from the timm library on a cat vs. dog image classification task.

Reply with the keyword torchkeras in the backend of the WeChat public account Algorithm Gourmet House to get the notebook source code for this article and the download link for the dataset.

#!pip install -U timm torchkeras

0. The pretrained model

import timm
import torch
from urllib.request import urlopen
from PIL import Image

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))
img

(output image: a cup of espresso with beignets)
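Before creating the model, you can check which Swin variants timm provides; timm.list_models accepts a wildcard filter:

# list the Swin checkpoints that ship with pretrained weights
print(timm.list_models('swin*', pretrained=True)[:10])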


# Swin-Base, patch size 4, window 7, 224x224 input;
# pretrained on ImageNet-22k, then fine-tuned on ImageNet-1k
model = timm.create_model("swin_base_patch4_window7_224.ms_in22k_ft_in1k", pretrained=True)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1

# take the five highest-probability ImageNet classes
top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1), k=5)

# look up readable names; this reaches into ImageNetInfo internals
# (_synsets lists the WordNet ids, _lemmas maps each id to its lemma)
info = timm.data.ImageNetInfo()
class_codes = info.__dict__['_synsets']
class_names = [info.__dict__['_lemmas'][x] for x in class_codes]
{class_names[i]: v for i, v in zip(top5_class_indices.tolist()[0],
                                   top5_probabilities.tolist()[0])}
{'espresso': 0.1655443161725998,
 'cup': 0.12100766599178314,
 'chocolate sauce, chocolate syrup': 0.11809349805116653,
 'eggnog': 0.06144588068127632,
 'tray': 0.03965265676379204}
The top predictions are espresso and cup, which match the picture, so no problems here.
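As an aside, instead of poking at ImageNetInfo internals as above, newer timm versions also expose lookup helpers; a minimal alternative sketch, assuming your timm provides index_to_description:

# same top-5 lookup via the public helper
# (assumes index_to_description exists in your timm version)
{info.index_to_description(i): v
 for i, v in zip(top5_class_indices.tolist()[0],
                 top5_probabilities.tolist()[0])}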

1. Prepare the data

import torch
import os
data_path = './datasets/cats_vs_dogs'

train_cats = os.listdir(os.path.join(data_path, "train", "cats"))
# peek at the first cat image as a sanity check
img = Image.open(os.path.join(data_path, "train", "cats", train_cats[0]))
img

(output image: a sample cat picture)

train_dogs = os.listdir(os.path.join(data_path, "train", "dogs"))
img = Image.open(os.path.join(data_path, "train", "dogs", train_dogs[0]))
img

(output image: a sample dog picture)
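ImageFolder, used in the next cell, derives the labels from the subdirectory names, so the dataset is assumed to follow this layout:

./datasets/cats_vs_dogs/
    train/
        cats/   # cat training images
        dogs/   # dog training images
    val/
        cats/
        dogs/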

from torchvision.datasets import ImageFolder


# reuse the model's own eval transforms for both splits
ds_train = ImageFolder(os.path.join(data_path, "train"), transforms)
ds_val = ImageFolder(os.path.join(data_path, "val"), transforms)

# small batches keep GPU memory usage modest during fine-tuning
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=4, shuffle=True)
dl_val = torch.utils.data.DataLoader(ds_val, batch_size=2, shuffle=True)

# ImageFolder sorts class folders alphabetically: ['cats', 'dogs']
class_names = ds_train.classes

print(len(ds_train))
print(len(ds_val))
2000
995
for batch in dl_val:
    break
batch[1]
tensor([0, 1])

2. Define the model

# replace the 1000-class ImageNet head with a fresh 2-class head
model.reset_classifier(num_classes=2)
model(batch[0])
tensor([[ 0.1698, -0.3366],
        [ 0.4805,  0.1415]], grad_fn=<AddmmBackward0>)
model.cuda();
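The training run below fine-tunes all of the weights. If you prefer to train only the new head and keep the pretrained backbone frozen, a minimal optional sketch using timm's get_classifier accessor:

# optional: freeze the backbone and train only the new 2-class head
for p in model.parameters():
    p.requires_grad = False
for p in model.get_classifier().parameters():
    p.requires_grad = True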

3. Train the model

from torchkeras import KerasModel 
from torchmetrics import Accuracy

loss_fn = torch.nn.CrossEntropyLoss()
metrics_dict = {"acc": Accuracy(task='multiclass', num_classes=2)}

# a small learning rate suits fine-tuning a pretrained backbone
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

keras_model = KerasModel(model,
                   loss_fn = loss_fn,
                   metrics_dict= metrics_dict,
                   optimizer = optimizer
                  )
features, labels = batch
loss_fn(model(features.cuda()), labels.cuda())
tensor(0.6743, device='cuda:0', grad_fn=<NllLossBackward0>)

The loss is close to ln(2) ≈ 0.693, exactly what an untrained two-class head should produce, so the pipeline is wired up correctly.
dfhistory= keras_model.fit(train_data=dl_train, 
                    val_data=dl_val, 
                    epochs=100, 
                    ckpt_path='checkpoint.pt',
                    patience=10, 
                    monitor="val_acc",
                    mode="max",
                    mixed_precision='no',
                    plot = True,
                    quiet=True
                   )

(training curve plot produced during fit)

As you can see, SwinTransformer's fitting ability is very strong. On this simple dataset, just two epochs of fine-tuning push the training accuracy to 100%, and the final validation accuracy reaches 99.8%, which is very powerful~
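fit monitors val_acc and checkpoints the best weights to checkpoint.pt (the ckpt_path set above). Assuming, as torchkeras does by default, that the file holds a plain model state_dict, you can restore the best weights explicitly:

# reload the best checkpoint saved during training
model.load_state_dict(torch.load('checkpoint.pt'))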

4. Evaluate the model

keras_model.evaluate(dl_val)

5. Use the model

from PIL import Image

img = Image.open('./datasets/cats_vs_dogs/val/dogs/dog.2005.jpg')
model.eval();
# [None, ...] adds a batch dimension before moving the tensor to the GPU
model(transforms(img)[None, ...].cuda()).softmax(dim=1)
tensor([[1.1537e-04, 9.9988e-01]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
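The two probabilities follow ImageFolder's alphabetical class order, so they can be mapped back to labels with the class_names list built earlier:

# map the output probabilities back to the ImageFolder class names
probs = model(transforms(img)[None, ...].cuda()).softmax(dim=1)[0]
dict(zip(class_names, probs.tolist()))
# -> roughly {'cats': 1.15e-04, 'dogs': 0.9999}, a confident 'dogs'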

6. Save the model

torch.save(model.state_dict(),'swin_transformer.pt')
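To reload the fine-tuned weights later, recreate the same architecture with a 2-class head and load the saved state dict; a minimal sketch:

# rebuild the same Swin model with a 2-class head and restore our weights
model2 = timm.create_model("swin_base_patch4_window7_224.ms_in22k_ft_in1k",
                           pretrained=False, num_classes=2)
model2.load_state_dict(torch.load('swin_transformer.pt', map_location='cpu'))
model2.eval()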

For more interesting examples, reply with the keyword torchkeras in the backend of the WeChat public account Algorithm Gourmet House to get the example source code from the torchkeras repository.



Source: blog.csdn.net/Python_Ai_Road/article/details/131199056