Setting up SwinT, an ICCV 2021 object detection algorithm (Swin Transformer: Hierarchical Vision Transformer using Shifted Windows)

1. Paper:

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. [paper]

2. Code:

SwinT can be used for classification, detection, segmentation and other tasks.

Original repository (this code is for image classification):

https://github.com/microsoft/Swin-Transformer

Download the object detection code:

https://github.com/SwinTransformer/Swin-Transformer-Object-Detection

3. Create and activate a new Python virtual environment

conda create -n SwinTrans python=3.7
conda activate SwinTrans

4. Install PyTorch and torchvision

pip3 install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html

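Note: the authors saved their models with a newer version of torch, so the installed torch version must be 1.6.0 or later. PyTorch 1.6 made a new zip-based file format the default for torch.save, and older releases cannot read it. When I first installed torch 1.4.0, loading the authors' pretrained model failed with the following error: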

RuntimeError: version_ <= kMaxSupportedFileFormatVersion INTERNAL ASSERT FAILED at /pytorch/caffe2/serialize/inline_container.cc:132, please report a bug to PyTorch. Attempted to read a PyTorch file with version 3, but the maximum supported version for reading is 2. Your PyTorch installation may be too old. (init at /pytorch/caffe2/serialize/inline_container.cc:132)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x33 (0x7f7cdd46a193 in /home1/users/huangbo/anaconda3/envs/SwinTrans/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: caffe2::serialize::PyTorchStreamReader::init() + 0x1f5b (0x7f7c548399eb in /home1/users/huangbo/anaconda3/envs/SwinTrans/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #2: caffe2::serialize::PyTorchStreamReader::PyTorchStreamReader(std::string const&) + 0x64 (0x7f7c5483ac04 in /home1/users/huangbo/anaconda3/envs/SwinTrans/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #3: <unknown function> + 0x6c6536 (0x7f7c9c76b536 in /home1/users/huangbo/anaconda3/envs/SwinTrans/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #4: <unknown function> + 0x295a74 (0x7f7c9c33aa74 in /home1/users/huangbo/anaconda3/envs/SwinTrans/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #46: __libc_start_main + 0xf0 (0x7f7cee1fb840 in /lib/x86_64-linux-gnu/libc.so.6)
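To catch this before loading any weights, a quick version check can be run inside the new environment. This is a minimal sketch of my own, not part of the Swin repo; `parse_version`, `torch_is_new_enough` and `report` are hypothetical helper names. Versions are compared as integer tuples, because a plain string comparison would wrongly rank "1.10.0" below "1.6.0".

```python
import re


def parse_version(version):
    """Turn '1.6.0+cu101' into (1, 6, 0); local suffixes like '+cu101' are ignored."""
    core = str(version).split("+", 1)[0]
    numbers = []
    for piece in core.split(".")[:3]:
        match = re.match(r"\d+", piece)  # '0a0' -> 0, a non-numeric piece -> 0
        numbers.append(int(match.group()) if match else 0)
    numbers += [0] * (3 - len(numbers))  # pad '2.0' to (2, 0, 0)
    return tuple(numbers)


def torch_is_new_enough(version, minimum="1.6.0"):
    """True if version >= minimum, compared as integer tuples, not strings."""
    return parse_version(version) >= parse_version(minimum)


def report():
    """Print the torch/CUDA setup of the current environment."""
    try:
        import torch
    except ImportError:
        print("torch is not installed in this environment")
        return
    print("torch:", torch.__version__, "| built with CUDA:", torch.version.cuda)
    print("CUDA available:", torch.cuda.is_available())
    if not torch_is_new_enough(torch.__version__):
        print("torch < 1.6.0: loading the released .pth files will fail")


if __name__ == "__main__":
    report()
```

With the environment from step 4 this should report torch 1.6.0 built with CUDA 10.1; if "CUDA available" is False, the cu101 wheel and your driver probably do not match.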

5. Install mmcv-full

pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html


# mmcv-full==1.2.4 is recommended

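My CUDA version is 10.1 and my torch version is 1.6.0; change cu101 and torch1.6.0 in the URL to match your own setup. Installing (and compiling) mmcv-full can take a long time, so be patient.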
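Since the -f URL encodes both versions, it can be derived from your own setup. The sketch below assumes the index keeps the pattern shown above (`cu` plus the CUDA version without the dot, then `torch` plus the torch version); `mmcv_find_links_url` is a hypothetical helper, the `cpu` tag is my assumption for CPU-only builds, and not every CUDA/torch combination has a published index, so open the URL in a browser before relying on it.

```python
def mmcv_find_links_url(cuda_version, torch_version):
    """Build the mmcv-full -f index URL, e.g. ('10.1', '1.6.0') -> .../cu101/torch1.6.0/index.html."""
    # '10.1' -> 'cu101'; 'cpu' is passed through for CPU-only builds (assumption)
    cuda_tag = "cpu" if cuda_version == "cpu" else "cu" + cuda_version.replace(".", "")
    # Drop a local suffix such as '+cu101' from the torch version
    torch_tag = "torch" + torch_version.split("+", 1)[0]
    return "https://download.openmmlab.com/mmcv/dist/{}/{}/index.html".format(cuda_tag, torch_tag)


if __name__ == "__main__":
    # The author's setup; replace with your own CUDA and torch versions
    print("pip install mmcv-full -f " + mmcv_find_links_url("10.1", "1.6.0"))
```

To pin the recommended version, put it in front of the same URL: `pip install mmcv-full==1.2.4 -f <url>`.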

6. Install MMDetection

From the project root, run:

pip install -r requirements/build.txt

python setup.py develop
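After `python setup.py develop`, it is worth confirming that everything imports from the new environment. A minimal sketch using only the standard library; `installed_versions` is a hypothetical helper. It reports a module as missing (None) only when that module itself is not found. Any other import error is reported as a failure message instead; for example, as far as I know MMDetection 2.x raises an AssertionError on import when the installed mmcv-full version is outside its supported range.

```python
import importlib


def installed_versions(module_names):
    """Map each module name to its __version__, 'unknown', None (missing) or an error message."""
    versions = {}
    for name in module_names:
        try:
            module = importlib.import_module(name)
        except Exception as exc:  # broad on purpose: a version mismatch can surface as a non-ImportError
            missing = isinstance(exc, ModuleNotFoundError) and exc.name == name
            versions[name] = None if missing else "import failed: {!r}".format(exc)
            continue
        versions[name] = getattr(module, "__version__", "unknown")
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(["torch", "torchvision", "mmcv", "mmdet"]).items():
        print("{:12s} {}".format(name, "NOT INSTALLED" if version is None else version))
```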

7. Download the pretrained model

Baidu Netdisk (extraction code required)

Extraction code: swin

Create a checkpoints folder in the project root and put the downloaded model in it.
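The same step as a small script: it creates `checkpoints/` and, if the weights are already there, tries loading them on the CPU. Loading is exactly where the torch < 1.6.0 error from step 4 shows up, so this separates a bad download or an old torch from problems in the detector itself. A sketch under assumptions: the file name is the one demo.py in step 8 uses, and `prepare_checkpoint_dir` is a hypothetical helper.

```python
from pathlib import Path

# Path expected by demo.py in step 8
CHECKPOINT = Path("checkpoints/cascade_mask_rcnn_swin_small_patch4_window7.pth")


def prepare_checkpoint_dir(checkpoint):
    """Create the checkpoint's parent folder; return True if the file is already there."""
    checkpoint = Path(checkpoint)
    checkpoint.parent.mkdir(parents=True, exist_ok=True)
    return checkpoint.is_file()


if __name__ == "__main__":
    if not prepare_checkpoint_dir(CHECKPOINT):
        print("Copy the downloaded weights to {}".format(CHECKPOINT))
    else:
        import torch

        # With torch < 1.6.0 this is the call that raises "Attempted to read a PyTorch file with version 3"
        ckpt = torch.load(str(CHECKPOINT), map_location="cpu")
        print("top-level keys:", list(ckpt.keys()) if isinstance(ckpt, dict) else type(ckpt))
```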

8. Create demo.py with the following code:

from mmdet.apis import init_detector, inference_detector, show_result_pyplot
import cv2  # only needed for the commented-out OpenCV display at the end

config_file = 'configs/swin/cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py'

# download the checkpoint from model zoo and put it in `checkpoints/`
checkpoint_file = 'checkpoints/cascade_mask_rcnn_swin_small_patch4_window7.pth'
device = 'cuda:0'
# init a detector
model = init_detector(config_file, checkpoint_file, device=device)
# run inference on the demo image

image = 'demo/demo.jpg'

result = inference_detector(model, image)

show_result_pyplot(model, image, result, score_thr=0.3)

# Alternative: draw the results and show them with OpenCV instead of matplotlib
# image = model.show_result(image, result, score_thr=0.3)
#
# cv2.imshow('demo', image)
# cv2.waitKey()
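To print what was detected instead of only drawing it, the `result` returned by `inference_detector` can be summarized. This assumes the MMDetection 2.x result layout that this repo builds on: for a mask model such as Cascade Mask R-CNN, `result` is a `(bbox_result, segm_result)` tuple, and `bbox_result` holds one array per class whose rows are `[x1, y1, x2, y2, score]`. `summarize_detections` is a hypothetical helper; its strict `score > score_thr` filter is meant to match the one used when drawing.

```python
def summarize_detections(result, class_names, score_thr=0.3):
    """Count, per class name, the boxes whose score is above score_thr.

    Assumes the MMDetection 2.x layout: mask models return (bbox_result, segm_result),
    box-only models return bbox_result directly; bbox_result has one array per class
    with rows [x1, y1, x2, y2, score].
    """
    bbox_result = result[0] if isinstance(result, tuple) else result
    counts = {}
    for class_name, dets in zip(class_names, bbox_result):
        # Strict '>' is meant to match the score_thr filter used when drawing
        kept = sum(1 for det in dets if float(det[-1]) > score_thr)
        if kept:
            counts[class_name] = kept
    return counts
```

For example, add `print(summarize_detections(result, model.CLASSES))` after the `inference_detector` call in demo.py.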

9. Run python demo.py to get the prediction results
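Once the single-image demo works, the same two calls can be looped over a folder. In this sketch `model.show_result(..., out_file=...)` writes the drawn boxes and masks to disk instead of opening a window. `plan_outputs` and `run_folder` are hypothetical helpers, the set of image suffixes is my own choice, and the folder names in the usage line are placeholders.

```python
from pathlib import Path

# Suffixes treated as images (my own choice; extend as needed)
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def plan_outputs(image_dir, out_dir):
    """Pair each image in image_dir with a same-named output path in out_dir."""
    image_dir, out_dir = Path(image_dir), Path(out_dir)
    images = sorted(
        p for p in image_dir.iterdir() if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )
    return [(p, out_dir / p.name) for p in images]


def run_folder(config_file, checkpoint_file, image_dir, out_dir,
               score_thr=0.3, device="cuda:0"):
    """Run the detector on every image in image_dir and save visualizations to out_dir."""
    # Imported here so plan_outputs can be used without MMDetection installed
    from mmdet.apis import inference_detector, init_detector

    model = init_detector(config_file, checkpoint_file, device=device)
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    for image_path, out_path in plan_outputs(image_dir, out_dir):
        result = inference_detector(model, str(image_path))
        # With out_file set, show_result draws the detections and writes the image to disk
        model.show_result(str(image_path), result, score_thr=score_thr,
                          out_file=str(out_path))
```

Usage, with `config_file` and `checkpoint_file` as in demo.py: `run_folder(config_file, checkpoint_file, "demo/images", "demo/results")`.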


Reposted from blog.csdn.net/qq_17783559/article/details/119381672