[Small experiment 1] Comparing the inductive bias of ResNet, ViT and SwinTransformer (the expected result was not achieved)

Preface: this experiment did not achieve the expected result, so this post is mainly a record of the experiment.

1. Idea

1.1 Experiment idea

The idea of this experiment is as follows: take untrained ResNet, ViT and SwinTransformer models whose weights are randomly initialized (drawn from normal distributions), use them to predict the validation set of ImageNet-1k (2012) (val, 50,000 images, 1,000 classes), and compare their accuracy with that of random guessing (1‰, i.e., 0.1%).
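To make the random-guess baseline concrete, the following pure-Python sketch simulates a guesser that ignores the image and picks one of the 1,000 classes uniformly at random, using the val set's layout of 50 images per class. The seed and the uniform guessing strategy are illustrative assumptions, not part of the original experiment.

```python
import random

NUM_CLASSES = 1000
IMAGES_PER_CLASS = 50  # ImageNet-1k val: 50 images per class, 50,000 in total

rng = random.Random(2022)  # fixed seed, chosen arbitrarily for reproducibility
labels = [c for c in range(NUM_CLASSES) for _ in range(IMAGES_PER_CLASS)]

# A "random guesser" ignores the image and draws a class uniformly at random
guesses = [rng.randrange(NUM_CLASSES) for _ in labels]

correct = sum(g == y for g, y in zip(guesses, labels))
accuracy = correct / len(labels)
print(f"correct={correct}, accuracy={accuracy:.4%}")
```

Over repeated runs with different seeds, the count fluctuates around 50 correct predictions, i.e., around 1‰.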

1.2 Source of inspiration

This was inspired by the paper Deep Clustering for Unsupervised Learning of Visual Features, which contains the following passage:
This passage says: "When the parameters θ of a model (e.g., a CNN) are randomly initialized from a Gaussian distribution, the model predicts poorly without training. However, it still does much better than random guessing (for ImageNet-1k, random guessing is correct 1 time in 1,000)."

The authors' explanation is that models such as CNNs build in prior knowledge (i.e., the inductive biases of CNNs, such as translation invariance and locality). Therefore, even after random initialization and without any training on data, their results are still better than random guessing.
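The translation bias can be made concrete: a convolution layer is translation-equivariant (shifting the input shifts the output), and a global pooling on top makes the result translation-invariant. The sketch below checks this in pure Python with a circular 1D convolution; the helper names, input and kernel values are hypothetical. The point is that the property holds for any kernel weights, including untrained random ones.

```python
def circular_conv1d(x, kernel):
    """Circular 1D cross-correlation (what deep-learning 'conv' layers compute)."""
    n, k = len(x), len(kernel)
    return [sum(kernel[j] * x[(i + j) % n] for j in range(k)) for i in range(n)]


def shift(x, s):
    """Circularly shift a sequence to the right by s positions."""
    s %= len(x)
    return x[-s:] + x[:-s] if s else list(x)


x = [0.0, 1.0, 3.0, 2.0, 0.0, -1.0, 0.0, 0.0]
kernel = [0.5, -1.0, 0.25]  # arbitrary weights, standing in for a random init

# Equivariance: convolving a shifted input equals shifting the convolved output
lhs = circular_conv1d(shift(x, 3), kernel)
rhs = shift(circular_conv1d(x, kernel), 3)
print(lhs == rhs)

# Global max pooling on top turns equivariance into invariance
print(max(circular_conv1d(x, kernel)) == max(circular_conv1d(shift(x, 3), kernel)))
```

Both checks print True; no training is involved, which is the sense in which the architecture itself carries the bias.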

So my idea was: use untrained, randomly initialized ResNet, ViT and SwinTransformer models to predict ImageNet-1k (val). To some extent, the higher a model's accuracy, the stronger the inductive bias its architecture introduces.

2. Experimental setup

Lab environment: PyTorch 1.12, 2× RTX 3090 (24 GB)
Models: ResNet, ViT, SwinTransformer
Test dataset: ImageNet-1k (2012, val, 50,000 images, 1,000 classes)

| Model | Parameters |
|---|---|
| ResNet50 | 25.56M |
| ResNet101 | 44.55M |
| ViT_B_16 | 86.57M |
| ViT_L_16 | 304.33M |
| swin_t | 28.29M |
| swin_s | 49.61M |

Initialization mostly uses variants of the normal distribution, for example:

nn.init.trunc_normal_
nn.init.normal_
nn.init.kaiming_normal_

For details on their usage, see the official PyTorch documentation.
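As an illustration of how these initializers are typically applied, here is a minimal sketch that re-initializes a small stand-in network with `model.apply`. The toy model, the `init_weights` helper and the per-layer choice of initializer (Kaiming normal for convolutions, truncated normal with std 0.02 for linear layers) are assumptions for illustration; each TorchVision model defines its own initialization in its source code.

```python
import torch
import torch.nn as nn


def init_weights(module: nn.Module) -> None:
    """Hypothetical helper: re-initialize layers with normal-family initializers."""
    if isinstance(module, nn.Conv2d):
        # Kaiming (He) normal init, designed for ReLU networks
        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Linear):
        # Truncated normal: values outside [a, b] (default [-2, 2]) are redrawn
        nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


torch.manual_seed(2022)
model = nn.Sequential(  # toy stand-in for ResNet/ViT/Swin
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
model.apply(init_weights)  # apply() visits every submodule recursively

# Plain normal init works the same way, e.g. overwriting the classifier with N(0, 0.01^2)
nn.init.normal_(model[4].weight, mean=0.0, std=0.01)
```

All three functions modify the tensor in place (hence the trailing underscore), so they are usually called on parameters of an already constructed model.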

3. Experimental results

3.1 Results

Result: for untrained, randomly initialized ResNet, ViT and SwinTransformer, the predictions are about as accurate as random guessing (1‰).

| Model | Correct predictions | Test images | Accuracy | Inference time | batch_size |
|---|---|---|---|---|---|
| resnet50 | 50 | 50000 | 1‰ | 35.485 s | 256 |
| resnet101 | 50 | 50000 | 1‰ | 45.556 s | 256 |
| vit_b_16 | 50 | 50000 | 1‰ | 83.587 s | 256 |
| vit_l_16 | 50 | 50000 | 1‰ | 370.091 s | 64 |
| swin_t | 43 | 50000 | ~1‰ | 44.616 s | 256 |
| swin_s | 49 | 50000 | ~1‰ | 67.754 s | 256 |

3.2 Analysis of results

3.2.1 A strange phenomenon

There is an interesting phenomenon: every model in the ResNet and ViT series gets exactly 50 samples right. This is because an untrained, randomly initialized ResNet or ViT predicts the same label no matter which sample is input.

Note: this does not necessarily mean the outputs themselves are identical, only that the index of the maximum value in the 1000-dimensional output is the same for every input (the argmax is the same whether or not softmax is applied, because softmax is monotonic). Since the ImageNet-1k val set has 50,000 images in 1,000 classes, i.e., 50 images per class, a constant prediction is correct for exactly 50 images.
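The arithmetic behind "exactly 50 correct" can be checked directly. The sketch below (pure Python; class 7 and the logits are arbitrary illustrative values) builds the val label layout of 50 images per class, scores a model that always predicts one class, and confirms that applying softmax does not change which index is the maximum.

```python
import math
from collections import Counter

NUM_CLASSES, IMAGES_PER_CLASS = 1000, 50
labels = [c for c in range(NUM_CLASSES) for _ in range(IMAGES_PER_CLASS)]

# A collapsed model predicts the same class for every image (class 7, arbitrary)
predictions = [7] * len(labels)
correct = sum(p == y for p, y in zip(predictions, labels))
print(correct)                    # 50: exactly the images whose true class is 7
print(len(Counter(predictions)))  # 1 distinct predicted label


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    # Index of the first maximal element
    return max(range(len(values)), key=values.__getitem__)


logits = [0.3, -1.2, 2.5, 2.4, 0.0]
print(argmax(logits), argmax(softmax(logits)))  # 2 2
```

Counting distinct predicted labels in the same way is a quick check for whether a model's predictions have collapsed to a single class.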

For SwinTransformer, the predicted labels do appear to be spread across classes, but the accuracy is the same as random guessing (about 1‰), not "much better than guessing" as the Deep Cluster authors describe.
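Whether 43 or 49 correct predictions differ meaningfully from chance can be judged with the binomial distribution: with n = 50,000 images and a success probability p = 1/1000 per guess, the number of correct guesses has mean np and standard deviation sqrt(np(1 - p)). A short sketch using the Swin counts from the table in 3.1:

```python
import math

n, p = 50_000, 1 / 1000           # test images, chance of a correct uniform guess
mean = n * p                      # expected number of correct guesses
std = math.sqrt(n * p * (1 - p))  # binomial standard deviation
print(f"mean={mean:.1f}, std={std:.2f}")

# Observed correct counts from the results table
observed = {"swin_t": 43, "swin_s": 49}
for name, k in observed.items():
    z = (k - mean) / std          # distance from chance in standard deviations
    print(f"{name}: {k} correct, z = {z:+.2f}")
```

This gives a mean of 50 and a standard deviation of about 7.07, so both Swin counts lie within about one standard deviation of 50 and are indistinguishable from random guessing.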

3.2.2 Analysis

I believe the Deep Cluster authors' claim is correct (after all, it is a paper by well-known researchers published at a top conference), so the problem probably lies in the details of my implementation, such as the initialization method. However, I use the default initialization from the TorchVision source code, and I don't see anything wrong with it.

I don't plan to dig further into this phenomenon, since I was mainly reading the Deep Cluster paper. The phenomenon the authors mention is interesting, but my experiment did not reproduce it. Readers who want to find the cause can look at the references cited in Deep Cluster.

4. Code

```python
import os
import time

import torch
import torch.nn as nn
from torchvision.models import resnet50, resnet101, vit_b_16, vit_l_16, swin_t, swin_s
from torchvision import datasets, transforms
from tqdm import tqdm


def get_dataloader(data_dir=None, batch_size=64):
    '''
    :param data_dir: path to the ImageNet val directory (ImageFolder layout)
    :param batch_size: batch size
    :return: DataLoader over ImageNet val
    '''
    assert data_dir is not None

    transform = transforms.Compose([
        # Note: a random crop makes evaluation non-deterministic; standard ImageNet
        # evaluation uses Resize(256) + CenterCrop(224) instead
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.ToTensor(),
        # Whether normalization is needed here is open; a comparison experiment could test it
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    val_dataset = datasets.ImageFolder(
        data_dir,
        transform,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        num_workers=8,
        shuffle=True,
    )

    return val_loader


def get_models():
    '''
    :return: (model_list, model_names) for res50, res101, ViT-B/16, ViT-L/16, Swin-T, Swin-S
    '''
    # Build models with random initialization (no pretrained weights)
    res50 = resnet50()
    res101 = resnet101()
    vit_B_16 = vit_b_16()
    vit_L_16 = vit_l_16()
    swin_T = swin_t()  # parameter count close to res50
    swin_S = swin_s()  # parameter count close to res101
    model_list = [res50, res101, vit_B_16, vit_L_16, swin_T, swin_S]
    model_names = ['res50', 'res101', 'vit_B_16', 'vit_L_16', 'swin_T', 'swin_S']
    for name, model in zip(model_names, model_list):
        print(f'{name:10} parameter size is {compute_params(model):.2f}M')

    return model_list, model_names


def compute_params(model):
    '''
    :param model: nn.Module, model
    :return: float, number of parameters in millions
    '''
    total = sum(p.numel() for p in model.parameters())
    return total / 1e6


@torch.no_grad()  # inference only: no autograd graph, lower GPU memory use
def model_evaluate(model, data_loader):
    '''
    :param model: model to test
    :param data_loader: val_loader
    :return: float, top-1 accuracy
    '''
    # Adjust to your machine: primary device and the GPUs used by DataParallel
    device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        model = nn.DataParallel(model, device_ids=[1, 2])  # use 2 GPUs
    model.to(device)
    model.eval()
    total = 0
    correct = 0
    loop = tqdm(data_loader, total=len(data_loader))
    for imgs, labels in loop:
        imgs = imgs.to(device)  # .to() is not in-place, so reassign
        labels = labels.to(device)
        preds = model(imgs).argmax(dim=1)  # predicted class index per image

        total += len(labels)
        correct += (preds == labels).sum().item()
        loop.set_description('inference test')
        loop.set_postfix(total=total, correct=correct, acc=f'{correct / total:.4f}')

    return correct / total


if __name__ == '__main__':
    seed = 2022
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    data_dir = os.path.join('..', '..', '..', 'data', 'ImageNet2012', 'imagenet', 'val')
    val_loader = get_dataloader(data_dir, batch_size=256)

    # Print model sizes. I originally wanted a loop that evaluates every model
    # automatically, but it ran out of GPU memory, so each model is tested separately
    get_models()
    # Replace with the model under test: resnet50, resnet101, vit_b_16, vit_l_16, swin_t, swin_s
    net = swin_s()
    t1 = time.time()
    acc = model_evaluate(net, val_loader)
    print(f'accuracy: {acc:.4%}, total time: {time.time() - t1:.3f}s')

    # nets, net_names = get_models()
    # for net, name in zip(nets, net_names):
    #     if name == 'vit_L_16':
    #         val_loader = get_dataloader(data_dir, batch_size=16)
    #     else:
    #         val_loader = get_dataloader(data_dir, batch_size=128)
    #
    #     t1 = time.time()
    #     model_evaluate(net, val_loader)
    #     print(f'{name:.10} total time: {time.time() - t1:.3f}s')
```

References:
1. Deep Clustering for Unsupervised Learning of Visual Features
2. PyTorch official documentation


Source: blog.csdn.net/qq_44166630/article/details/127757741