The fusion of PyTorch convolution and BatchNorm (PyTorch official)

0 Principle

This is the link


1 Merge BN

Because PyTorch builds its graph dynamically, Conv and BatchNorm cannot be merged automatically; the merge has to be requested explicitly.

For the official implementation, you need to specify the name of Conv and BN.

2 Quantization

Currently includes qnnpack and fbgemm two backends

qnnpack only supports per-tensor quantization
fbgemm supports per-channel quantization

fbgemm is generally more accurate, but it only runs on x86 PCs. In addition, fbgemm is faster than qnnpack on a PC CPU — presumably because it uses the SSE/AVX instruction sets.

3 matters needing attention

The residual add in the torchvision implementation needs to be changed manually:

# In __init__:
    self.skip_add = nn.quantized.FloatFunctional()
# In forward:
    # out += identity
    out = self.skip_add.add(out, identity)

Remember to add QuantStub/DeQuantStub around the model's input and output
Do not fuse inplace ReLU
Remember to set qconfig on the model

4 results

Test results for torchvision ResNet-18 on the ImageNet val set; the calibration dataset is a 1k-image subset of the val data.

......................................................
Before Q:
Evaluation accuracy 69.22

After Q:
Size of model after quantization
Size (MB): 11.719858
........................................................
Evaluation accuracy 68.97

5 code

import os
import tempfile

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.quantization import QuantStub, DeQuantStub, QConfig

# Official utils
class AverageMeter(object):
    """Track the most recent value and a running average of a metric.

    Utility copied from the official PyTorch ImageNet example; ``name`` and
    ``fmt`` control how :meth:`__str__` renders the meter.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Record *val* observed *n* times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += n * val
        self.avg = self.sum / self.count

    def __str__(self):
        # Build e.g. '{name} {val:6.2f} ({avg:6.2f})' and fill it from
        # the instance attributes.
        template = ''.join(['{name} {val', self.fmt, '} ({avg', self.fmt, '})'])
        return template.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: logits of shape (batch, num_classes).
        target: ground-truth class indices of shape (batch,).
        topk: tuple of k values to report.

    Returns:
        List of 1-element tensors, one per k, each holding the top-k
        accuracy in percent.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # correct[:k] is a non-contiguous slice; .view(-1) raises a
            # RuntimeError on PyTorch >= 1.7, so use .reshape(-1) instead.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def evaluate(model, data_loader, cpu=False):
    """Evaluate *model* on *data_loader*, reporting top-1/top-5 accuracy.

    Args:
        model: network in eval-ready state (assumed already on GPU unless
            ``cpu=True`` — quantized models must run on CPU).
        data_loader: iterable of (image, target) batches.
        cpu: when True, keep all tensors on the CPU.

    Returns:
        (top1, top5) AverageMeter instances; read ``.avg`` for the result.
    """
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    with torch.no_grad():
        for image, target in data_loader:
            # Move both tensors once per batch instead of duplicating the
            # `cpu` conditional at every use site.
            if not cpu:
                image = image.cuda()
                target = target.cuda()
            output = model(image)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # One progress dot per batch; flush so it appears immediately.
            print('.', end='', flush=True)
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
    return top1, top5


def print_size_of_model(model):
    """Print and return the serialized size of *model*'s state_dict in MB.

    Serializes to a unique temporary file (the original used a fixed
    "temp.p" in the CWD, which collides between concurrent runs) and
    always removes it, even if saving fails.

    Returns:
        The size in megabytes as a float.
    """
    with tempfile.NamedTemporaryFile(suffix='.p', delete=False) as f:
        tmp_path = f.name
    try:
        torch.save(model.state_dict(), tmp_path)
        size_mb = os.path.getsize(tmp_path) / 1e6
        print('Size (MB):', size_mb)
    finally:
        os.remove(tmp_path)
    return size_mb


# Pick the quantization backend: "qnnpack" (per-tensor only, ARM-oriented)
# or "fbgemm" (per-channel, x86). The engine must match the qconfig.
q_backend = "qnnpack"
qconfig = torch.quantization.get_default_qconfig(q_backend)
torch.backends.quantized.engine = q_backend

# Pretrained torchvision ResNet-18; eval() is required before fusion so
# BatchNorm uses its running statistics.
r18_o = torchvision.models.resnet18(True)
r18_o.eval()

# Fuse Conv+BN (and the first Conv+BN+ReLU) pairs by module name.
# The per-block ReLU entries are commented out because torchvision's
# BasicBlock uses inplace ReLU, which must NOT be fused (see notes above).
# The 'downsample.0'/'downsample.1' pairs are the 1x1 Conv+BN on the
# residual shortcut of the first block in each stage.
r18 = torch.quantization.fuse_modules(
    r18_o,
    [['conv1', 'bn1', 'relu'],
     ['layer1.0.conv1', 'layer1.0.bn1'], # , 'layer1.0.relu'],
     ['layer1.0.conv2', 'layer1.0.bn2'],
     ['layer1.1.conv1', 'layer1.1.bn1'], #, 'layer1.1.relu'],
     ['layer1.1.conv2', 'layer1.1.bn2'],

     ['layer2.0.conv1', 'layer2.0.bn1'], #, 'layer2.0.relu'],
     ['layer2.0.conv2', 'layer2.0.bn2'],
     ['layer2.0.downsample.0', 'layer2.0.downsample.1'],
     ['layer2.1.conv1', 'layer2.1.bn1'], #, 'layer2.1.relu'],
     ['layer2.1.conv2', 'layer2.1.bn2'],

     ['layer3.0.conv1', 'layer3.0.bn1'], #, 'layer3.0.relu'],
     ['layer3.0.conv2', 'layer3.0.bn2'],
     ['layer3.0.downsample.0', 'layer3.0.downsample.1'],
     ['layer3.1.conv1', 'layer3.1.bn1'], #, 'layer3.1.relu'],
     ['layer3.1.conv2', 'layer3.1.bn2'],

     ['layer4.0.conv1', 'layer4.0.bn1'], #, 'layer4.0.relu'],
     ['layer4.0.conv2', 'layer4.0.bn2'],
     ['layer4.0.downsample.0', 'layer4.0.downsample.1'],
     ['layer4.1.conv1', 'layer4.1.bn1'], #, 'layer4.1.relu'],
     ['layer4.1.conv2', 'layer4.1.bn2'],
     ]
)


# Append input/output quant/dequant stub.
def replace_forward(module):
    """Wrap *module*'s forward so input passes through a QuantStub and
    output through a DeQuantStub — the required entry/exit markers for
    post-training static quantization.
    """
    # Attribute assignment registers the stubs as submodules, so
    # prepare()/convert() will find and transform them.
    module.quant = QuantStub()
    module.dequant = DeQuantStub()
    original_forward = module.forward

    def wrapped(x):
        return module.dequant(original_forward(module.quant(x)))

    module.forward = wrapped

# Attach quant/dequant stubs around the fused model's forward.
replace_forward(r18)


# 1k-image calibration set (a subset of ImageNet val, laid out for
# ImageFolder). NOTE(review): Resize(224)+CenterCrop(224) differs from the
# standard Resize(256)+CenterCrop(224) eval transform — presumably
# intentional here; confirm against the accuracy numbers above.
test_db = torchvision.datasets.ImageFolder(
    'imagenet_1k/val',
    transform=transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        # transforms.RandomResizedCrop(224),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
)

calibration_dataset = torch.utils.data.DataLoader(
    test_db,
    batch_size=256)

# Full 50k ImageNet val set for the accuracy evaluation.
image_net_db = torchvision.datasets.ImageFolder(
    './ILSVRC2012_img_val/',
    transform=transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        # transforms.RandomResizedCrop(224),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
)

imagenet_50k = torch.utils.data.DataLoader(
    image_net_db,
    batch_size=256)


# Float baseline runs on GPU; the fused copy stays on CPU for quantization.
r18_o = r18_o.cuda()
r18.eval()

# Baseline accuracy of the unmodified torchvision model.
top1, top5 = evaluate(r18_o.cuda(), imagenet_50k)
print('Evaluation accuracy %2.2f'%(top1.avg))


# WARNING: Do NOT forget setting qconfig — prepare() silently inserts no
# observers without it.
r18.qconfig = qconfig

# Insert observers, then run the calibration data through the model (on
# CPU) to collect activation statistics.
torch.quantization.prepare(r18, inplace=True)
evaluate(r18, calibration_dataset, cpu=True)
print('Post Training Quantization: Calibration done')

# Replace observed modules with their quantized counterparts (CPU only).
r18 = r18.cpu()
torch.quantization.convert(r18, inplace=True)
print('Post Training Quantization: Convert done')


# Report size and accuracy of the quantized model.
print("Size of model after quantization")
print_size_of_model(r18)
top1, top5 = evaluate(r18, imagenet_50k, cpu=True)
print('Evaluation accuracy %2.2f'%(top1.avg))

Guess you like

Origin blog.csdn.net/qq_36783816/article/details/112600674