WS-DAN 复现 WSDAN(Weakly Supervised Data Augmentation Network)

一, WS-DAN介绍

论文原文:《See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification》

网上很多介绍,选1篇大家自己去看:
【细粒度】WS-DAN

b站视频:
【论文讲解+复现】WS-DAN WSDAN(Weakly Supervised Data Augmentation Network

二,准备

2.1 平台

极链AI云
整个复现过程将会在这个平台上进行,原因是,有现成的pytorch和cuda环境,直接线上敲代码,不需要本地配置。

2.2 源码

码云上的源码(速度快)
github上的源码

三,开始复现

3.1,创建实例

进入极链AI云平台,选择一个最便宜的机子。
在这里插入图片描述

镜像如下:
Pytorch选择1.8,python3.8,CUDA11.1.1
在这里插入图片描述
创建好之后的样子:
在这里插入图片描述
我们使用第三方工具的jupyter lab
在这里插入图片描述

进入后,界面如下:
在这里插入图片描述

3.2 在jupyter lab中创建文件

在/home下创建wsdan文件
在这里插入图片描述
打开terminal
在这里插入图片描述

在这里插入图片描述

我们输入ls,查看一下当前目录情况:
在这里插入图片描述

在terminal当中输入:

cd /home/wsdan/

进入到wadn目录中,然后我们开始下载ws-dan的项目

3.3 下载ws-dan项目

输入命令:

git clone https://gitee.com/YFwinston/WS-DAN.PyTorch.git

在这里插入图片描述

下载好后,ws-dan的目录如下:
在这里插入图片描述

3.4 数据集准备

我这里以鸟类数据集为例子
官方下载地址:CUB-200-2011 (Bird) 但是下载这个要翻墙

在极链云平台,这个数据集我已经叫官方加进去了,就在:
/datasets/CUB-200-2011-Bird/
在这里插入图片描述

3.5 修改bird_dataset.py文件

修改bird_dataset.py中的数据集的指定路径
在这里插入图片描述
在这里插入图片描述

修改如下(修改为我们下载的路径):

#DATAPATH = '/home/guyuchong/DATA/FGVC/CUB-200-2011'
DATAPATH = '/home/dataset/CUB_200_2011'

3.6训练

没错,就是这么快,就可以开始训练了,往往最高端的论文,往往只需要最简单的操作。

在终端,进入下面的文件路径
在这里插入图片描述
在终端中输入:

python3 train.py 

在这里插入图片描述

3.8 训练结果

下面是训练日志,所有信息都在里面:

2021-08-10 20:38:58,851: INFO: [inception.py:177]: Inception3: All params loaded
2021-08-10 20:38:59,323: INFO: [wsdan.py:97]: WSDAN: using inception_mixed_6e as feature extractor, num_classes: 200, num_attentions: 32
2021-08-10 20:39:02,305: INFO: [train.py:93]: Network weights save to ./FGVC/CUB-200-2011/ckpt/
2021-08-10 20:39:02,353: INFO: [train.py:126]: Start training: Total epochs: 160, Batch size: 8, Training size: 5994, Validation size: 5794
2021-08-10 20:39:02,353: INFO: [train.py:128]: 
2021-08-10 20:39:02,353: INFO: [train.py:136]: Epoch 001, Learning Rate 0.001
2021-08-10 20:42:35,529: INFO: [train.py:247]: Train: Loss 5.4726, Raw Acc (28.40, 49.73), Crop Acc (19.95, 38.02), Drop Acc (18.92, 38.67), Time 213.17
2021-08-10 20:43:10,918: INFO: [train.py:302]: Valid: Val Loss 2.1645, Val Acc (54.06, 83.26), Time 35.38
2021-08-10 20:43:10,918: INFO: [train.py:303]: 
2021-08-10 20:43:11,094: INFO: [train.py:136]: Epoch 002, Learning Rate 0.001
2021-08-10 20:46:44,574: INFO: [train.py:247]: Train: Loss 2.8880, Raw Acc (64.60, 88.47), Crop Acc (50.30, 76.13), Drop Acc (48.68, 76.54), Time 213.48
2021-08-10 20:47:18,815: INFO: [train.py:302]: Valid: Val Loss 1.5154, Val Acc (69.14, 89.30), Time 34.24
2021-08-10 20:47:18,815: INFO: [train.py:303]: 
2021-08-10 20:47:19,097: INFO: [train.py:136]: Epoch 003, Learning Rate 0.0009
2021-08-10 20:50:52,811: INFO: [train.py:247]: Train: Loss 1.8485, Raw Acc (79.46, 95.43), Crop Acc (67.42, 87.20), Drop Acc (66.00, 88.99), Time 213.71
2021-08-10 20:51:27,392: INFO: [train.py:302]: Valid: Val Loss 1.1276, Val Acc (75.58, 93.23), Time 34.58
2021-08-10 20:51:27,393: INFO: [train.py:303]: 
2021-08-10 20:51:27,672: INFO: [train.py:136]: Epoch 004, Learning Rate 0.0009
2021-08-10 20:55:02,738: INFO: [train.py:247]: Train: Loss 1.3692, Raw Acc (84.45, 97.23), Crop Acc (75.79, 92.71), Drop Acc (72.82, 92.68), Time 215.06
2021-08-10 20:55:37,500: INFO: [train.py:302]: Valid: Val Loss 1.0774, Val Acc (76.53, 94.56), Time 34.75
2021-08-10 20:55:37,501: INFO: [train.py:303]: 
2021-08-10 20:55:37,770: INFO: [train.py:136]: Epoch 005, Learning Rate 0.00081
2021-08-10 20:59:13,200: INFO: [train.py:247]: Train: Loss 1.0293, Raw Acc (88.47, 98.53), Crop Acc (82.12, 95.58), Drop Acc (79.63, 95.93), Time 215.43
2021-08-10 20:59:48,232: INFO: [train.py:302]: Valid: Val Loss 0.8782, Val Acc (80.96, 95.60), Time 35.03
2021-08-10 20:59:48,232: INFO: [train.py:303]: 
2021-08-10 20:59:48,496: INFO: [train.py:136]: Epoch 006, Learning Rate 0.00081
2021-08-10 21:03:24,037: INFO: [train.py:247]: Train: Loss 0.8415, Raw Acc (91.02, 99.27), Crop Acc (85.30, 96.71), Drop Acc (82.83, 96.98), Time 215.54
2021-08-10 21:03:59,091: INFO: [train.py:302]: Valid: Val Loss 0.9179, Val Acc (80.82, 95.91), Time 35.05
2021-08-10 21:03:59,091: INFO: [train.py:303]: 
2021-08-10 21:03:59,092: INFO: [train.py:136]: Epoch 007, Learning Rate 0.000729
2021-08-10 21:07:36,231: INFO: [train.py:247]: Train: Loss 0.6588, Raw Acc (93.91, 99.70), Crop Acc (90.12, 98.31), Drop Acc (86.60, 97.58), Time 217.14
2021-08-10 21:08:11,372: INFO: [train.py:302]: Valid: Val Loss 0.7636, Val Acc (83.29, 96.43), Time 35.13
2021-08-10 21:08:11,373: INFO: [train.py:303]: 
2021-08-10 21:08:11,642: INFO: [train.py:136]: Epoch 008, Learning Rate 0.000729
2021-08-10 21:11:52,144: INFO: [train.py:247]: Train: Loss 0.5629, Raw Acc (95.30, 99.68), Crop Acc (92.46, 98.75), Drop Acc (88.30, 98.03), Time 220.50
2021-08-10 21:12:26,958: INFO: [train.py:302]: Valid: Val Loss 0.7243, Val Acc (83.50, 96.48), Time 34.81
2021-08-10 21:12:26,959: INFO: [train.py:303]: 
2021-08-10 21:12:27,248: INFO: [train.py:136]: Epoch 009, Learning Rate 0.0006561
2021-08-10 21:16:07,619: INFO: [train.py:247]: Train: Loss 0.4597, Raw Acc (97.26, 99.93), Crop Acc (94.29, 98.88), Drop Acc (91.29, 98.58), Time 220.37
2021-08-10 21:16:42,933: INFO: [train.py:302]: Valid: Val Loss 0.6472, Val Acc (85.88, 96.96), Time 35.31
2021-08-10 21:16:42,934: INFO: [train.py:303]: 
2021-08-10 21:16:43,229: INFO: [train.py:136]: Epoch 010, Learning Rate 0.0006561
2021-08-10 21:20:23,629: INFO: [train.py:247]: Train: Loss 0.4084, Raw Acc (97.86, 99.98), Crop Acc (95.68, 99.15), Drop Acc (92.43, 98.92), Time 220.40
2021-08-10 21:20:58,538: INFO: [train.py:302]: Valid: Val Loss 0.6557, Val Acc (84.90, 96.65), Time 34.90
2021-08-10 21:20:58,538: INFO: [train.py:303]: 
2021-08-10 21:20:58,540: INFO: [train.py:136]: Epoch 011, Learning Rate 0.00059049
2021-08-10 21:24:39,002: INFO: [train.py:247]: Train: Loss 0.3567, Raw Acc (98.63, 100.00), Crop Acc (96.06, 98.87), Drop Acc (94.63, 99.33), Time 220.46
2021-08-10 21:25:13,825: INFO: [train.py:302]: Valid: Val Loss 0.6422, Val Acc (85.33, 97.24), Time 34.82
2021-08-10 21:25:13,826: INFO: [train.py:303]: 
2021-08-10 21:25:13,827: INFO: [train.py:136]: Epoch 012, Learning Rate 0.00059049
2021-08-10 21:28:54,322: INFO: [train.py:247]: Train: Loss 0.3445, Raw Acc (98.97, 100.00), Crop Acc (95.86, 98.26), Drop Acc (95.41, 99.58), Time 220.49
2021-08-10 21:29:29,168: INFO: [train.py:302]: Valid: Val Loss 0.6214, Val Acc (85.17, 97.08), Time 34.84
2021-08-10 21:29:29,169: INFO: [train.py:303]: 
2021-08-10 21:29:29,170: INFO: [train.py:136]: Epoch 013, Learning Rate 0.000531441
2021-08-10 21:33:09,177: INFO: [train.py:247]: Train: Loss 0.3383, Raw Acc (99.27, 100.00), Crop Acc (94.99, 97.26), Drop Acc (96.50, 99.42), Time 220.00
2021-08-10 21:33:43,572: INFO: [train.py:302]: Valid: Val Loss 0.6169, Val Acc (85.88, 97.20), Time 34.39
2021-08-10 21:33:43,573: INFO: [train.py:303]: 
2021-08-10 21:33:43,574: INFO: [train.py:136]: Epoch 014, Learning Rate 0.000531441
2021-08-10 21:37:16,862: INFO: [train.py:247]: Train: Loss 0.3324, Raw Acc (99.45, 100.00), Crop Acc (94.48, 96.53), Drop Acc (97.23, 99.57), Time 213.29
2021-08-10 21:37:50,836: INFO: [train.py:302]: Valid: Val Loss 0.6173, Val Acc (85.99, 96.93), Time 33.97
2021-08-10 21:37:50,836: INFO: [train.py:303]: 
2021-08-10 21:37:51,098: INFO: [train.py:136]: Epoch 015, Learning Rate 0.000478297
2021-08-10 21:41:22,974: INFO: [train.py:247]: Train: Loss 0.3441, Raw Acc (99.62, 100.00), Crop Acc (93.14, 94.99), Drop Acc (97.51, 99.68), Time 211.87
2021-08-10 21:41:57,228: INFO: [train.py:302]: Valid: Val Loss 0.6051, Val Acc (86.28, 96.89), Time 34.25
2021-08-10 21:41:57,229: INFO: [train.py:303]: 
2021-08-10 21:41:57,534: INFO: [train.py:136]: Epoch 016, Learning Rate 0.000478297
2021-08-10 21:45:36,916: INFO: [train.py:247]: Train: Loss 0.3433, Raw Acc (99.43, 99.97), Crop Acc (92.79, 95.31), Drop Acc (97.61, 99.77), Time 219.38
2021-08-10 21:46:11,597: INFO: [train.py:302]: Valid: Val Loss 0.5868, Val Acc (86.52, 97.08), Time 34.68
2021-08-10 21:46:11,598: INFO: [train.py:303]: 
2021-08-10 21:46:11,858: INFO: [train.py:136]: Epoch 017, Learning Rate 0.000430467
2021-08-10 21:49:49,921: INFO: [train.py:247]: Train: Loss 0.3544, Raw Acc (99.80, 100.00), Crop Acc (91.11, 93.19), Drop Acc (98.13, 99.80), Time 218.06
2021-08-10 21:50:24,441: INFO: [train.py:302]: Valid: Val Loss 0.5687, Val Acc (86.54, 97.15), Time 34.52
2021-08-10 21:50:24,442: INFO: [train.py:303]: 
2021-08-10 21:50:24,693: INFO: [train.py:136]: Epoch 018, Learning Rate 0.000430467
2021-08-10 21:54:04,750: INFO: [train.py:247]: Train: Loss 0.3822, Raw Acc (99.80, 100.00), Crop Acc (89.47, 91.59), Drop Acc (97.98, 99.68), Time 220.06
2021-08-10 21:54:39,899: INFO: [train.py:302]: Valid: Val Loss 0.5953, Val Acc (86.71, 97.03), Time 35.14
2021-08-10 21:54:39,900: INFO: [train.py:303]: 
2021-08-10 21:54:40,163: INFO: [train.py:136]: Epoch 019, Learning Rate 0.00038742
2021-08-10 21:58:18,569: INFO: [train.py:247]: Train: Loss 0.4369, Raw Acc (99.73, 100.00), Crop Acc (85.60, 88.41), Drop Acc (98.60, 99.83), Time 218.40
2021-08-10 21:58:53,240: INFO: [train.py:302]: Valid: Val Loss 0.5616, Val Acc (86.35, 97.17), Time 34.66
2021-08-10 21:58:53,240: INFO: [train.py:303]: 
2021-08-10 21:58:53,241: INFO: [train.py:136]: Epoch 020, Learning Rate 0.00038742
2021-08-10 22:02:32,319: INFO: [train.py:247]: Train: Loss 0.4348, Raw Acc (99.80, 100.00), Crop Acc (85.34, 88.25), Drop Acc (98.33, 99.87), Time 219.08
2021-08-10 22:03:06,665: INFO: [train.py:302]: Valid: Val Loss 0.6059, Val Acc (86.61, 97.03), Time 34.34
2021-08-10 22:03:06,666: INFO: [train.py:303]: 
2021-08-10 22:03:06,667: INFO: [train.py:136]: Epoch 021, Learning Rate 0.000348678
2021-08-10 22:06:44,322: INFO: [train.py:247]: Train: Loss 0.5032, Raw Acc (99.85, 100.00), Crop Acc (80.63, 84.22), Drop Acc (98.78, 99.88), Time 217.65
2021-08-10 22:07:19,545: INFO: [train.py:302]: Valid: Val Loss 0.5865, Val Acc (86.69, 97.03), Time 35.22
2021-08-10 22:07:19,546: INFO: [train.py:303]: 
2021-08-10 22:07:19,547: INFO: [train.py:136]: Epoch 022, Learning Rate 0.000348678
2021-08-10 22:10:58,826: INFO: [train.py:247]: Train: Loss 0.5322, Raw Acc (99.92, 100.00), Crop Acc (78.78, 82.50), Drop Acc (99.13, 99.92), Time 219.28
2021-08-10 22:11:33,326: INFO: [train.py:302]: Valid: Val Loss 0.5980, Val Acc (86.90, 97.08), Time 34.49
2021-08-10 22:11:33,327: INFO: [train.py:303]: 
2021-08-10 22:11:33,611: INFO: [train.py:136]: Epoch 023, Learning Rate 0.000313811
2021-08-10 22:15:11,382: INFO: [train.py:247]: Train: Loss 0.5506, Raw Acc (99.85, 100.00), Crop Acc (76.96, 81.43), Drop Acc (99.07, 99.92), Time 217.77
2021-08-10 22:15:46,396: INFO: [train.py:302]: Valid: Val Loss 0.5904, Val Acc (86.87, 97.08), Time 35.01
2021-08-10 22:15:46,397: INFO: [train.py:303]: 
2021-08-10 22:15:46,398: INFO: [train.py:136]: Epoch 024, Learning Rate 0.000313811
2021-08-10 22:19:26,133: INFO: [train.py:247]: Train: Loss 0.5793, Raw Acc (99.93, 100.00), Crop Acc (75.84, 79.96), Drop Acc (99.02, 99.88), Time 219.73
2021-08-10 22:20:00,900: INFO: [train.py:302]: Valid: Val Loss 0.6006, Val Acc (86.90, 97.08), Time 34.76
2021-08-10 22:20:00,901: INFO: [train.py:303]: 
2021-08-10 22:20:00,904: INFO: [train.py:136]: Epoch 025, Learning Rate 0.00028243
2021-08-10 22:23:38,719: INFO: [train.py:247]: Train: Loss 0.5839, Raw Acc (99.93, 100.00), Crop Acc (74.46, 79.41), Drop Acc (99.40, 99.93), Time 217.81
2021-08-10 22:24:13,417: INFO: [train.py:302]: Valid: Val Loss 0.5971, Val Acc (87.28, 97.07), Time 34.69
2021-08-10 22:24:13,418: INFO: [train.py:303]: 
2021-08-10 22:24:13,715: INFO: [train.py:136]: Epoch 026, Learning Rate 0.00028243
2021-08-10 22:27:52,023: INFO: [train.py:247]: Train: Loss 0.5866, Raw Acc (99.90, 100.00), Crop Acc (74.19, 78.60), Drop Acc (99.28, 99.93), Time 218.31
2021-08-10 22:28:26,852: INFO: [train.py:302]: Valid: Val Loss 0.5772, Val Acc (87.12, 97.27), Time 34.82
2021-08-10 22:28:26,853: INFO: [train.py:303]: 
2021-08-10 22:28:26,854: INFO: [train.py:136]: Epoch 027, Learning Rate 0.000254187
2021-08-10 22:32:05,582: INFO: [train.py:247]: Train: Loss 0.5930, Raw Acc (99.87, 100.00), Crop Acc (73.49, 78.46), Drop Acc (99.33, 99.97), Time 218.73
2021-08-10 22:32:40,671: INFO: [train.py:302]: Valid: Val Loss 0.6140, Val Acc (86.81, 97.19), Time 35.08
2021-08-10 22:32:40,671: INFO: [train.py:303]: 
2021-08-10 22:32:40,673: INFO: [train.py:136]: Epoch 028, Learning Rate 0.000254187
2021-08-10 22:36:19,499: INFO: [train.py:247]: Train: Loss 0.6089, Raw Acc (99.93, 100.00), Crop Acc (72.69, 77.48), Drop Acc (99.20, 99.92), Time 218.82
2021-08-10 22:36:54,453: INFO: [train.py:302]: Valid: Val Loss 0.5868, Val Acc (87.09, 97.32), Time 34.94
2021-08-10 22:36:54,454: INFO: [train.py:303]: 
2021-08-10 22:36:54,455: INFO: [train.py:136]: Epoch 029, Learning Rate 0.000228768
2021-08-10 22:40:28,755: INFO: [train.py:247]: Train: Loss 0.5744, Raw Acc (99.92, 100.00), Crop Acc (74.59, 79.43), Drop Acc (99.23, 99.88), Time 214.30
2021-08-10 22:41:02,921: INFO: [train.py:302]: Valid: Val Loss 0.5789, Val Acc (87.11, 97.31), Time 34.16
2021-08-10 22:41:02,922: INFO: [train.py:303]: 
2021-08-10 22:41:02,922: INFO: [train.py:136]: Epoch 030, Learning Rate 0.000228768


3.7 运行过程中出现的错误

root@6110b71032f84836768746a8:/home/wsdan/WS-DAN.PyTorch# python3 train.py 
Epoch 1/160:   0%|                                                      | 0/750 [00:00<?, ? batches/s]Traceback (most recent call last):
  File "train.py", line 307, in <module>
    main()
  File "train.py", line 141, in main
    train(logs=logs,
  File "train.py", line 227, in train
    epoch_raw_acc = raw_metric(y_pred_raw, y)
  File "/home/wsdan/WS-DAN.PyTorch/utils.py", line 66, in __call__
    correct_k = correct[:k].view(-1).float().sum(0)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
Epoch 1/160:   0%|                                                      | 0/750 [00:03<?, ? batches/s]
root@6110b71032f84836768746a8:/home/wsdan/WS-DAN.PyTorch# 

解决方案:
https://blog.csdn.net/tiao_god/article/details/108189879
在utils.py中:

       for i, k in enumerate(self.topk):
            #correct_k = correct[:k].view(-1).float().sum(0)
            correct_k = correct[:k].contiguous().view(-1).float().sum(0)
   

在这里插入图片描述
这是因为view()需要Tensor中的元素地址是连续的,但可能出现Tensor不连续的情况,所以先用 .contiguous() 将其在内存中变成连续分布。

3.7 测试

在终端输入:

python3 eval.py 

在这里插入图片描述
训练结果:

Validation: 100%|█████████████| 725/725 [14:33<00:00,  1.21s/ batches, Val Acc: Raw (87.76, 97.34), Refine (88.23, 97.48)]

测试结束后,在目录:/home/wsdan/WS-DAN.PyTorch/FGVC/CUB-200-2011/visualize 中有可视化的结果
在这里插入图片描述

我就随机选择2组出来:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

四,代码讲解

【论文讲解+复现】WS-DAN WSDAN(Weakly Supervised Data Augmentation Network

猜你喜欢

转载自blog.csdn.net/WhiffeYF/article/details/119534760