[논문이 이해] 네트워크 슬리밍 통해 효율적인 길쌈 네트워크 학습

네트워크 슬리밍 통해 효율적인 길쌈 네트워크 학습

간략한 소개

내가 처음 논문 모델의 압축의 관점에서 무엇을보고 이것은 긴 모델 압축에 관심이 있기 때문에 추첨이 시간을 보았다 있도록 코드 자신을 달성, 더 유명한 한 고려되어야한다입니다 보면, 그것은 아주 쉽게 생각합니다. 이 문서 관련하여, 정리 방법 BN 층, 채널 임계 값에 의해 수행되는 점에서 낮은 점수를 BN 층 필터를 사용할 권리의 저자 재평가 점수 입력 채널을 치기위한 압축 문제의 모델을 제시 할 것이다 이들 뉴런 채널 스코어 연결에 참여하지 너무 작을 때, 압축 효과를 달성하기 위해, 레이어별로 치기.

저 개인적으로, 지금은 일반적주의 메커니즘에 사용을 위해 나는 소란을 만들 수 있습니다 채널의 점수를 평가하는 데 사용할 수 있습니다 생각하지만, 확실히 특정 작업의 관점에서, 나는주의 메커니즘 모델을 사용하여, 하나 개의 실험을 위해 자신을 다시 것 가지 치기.

방법

도 본원의 방법에 도시 된 바와 같이, 즉,

  1. 층의 소정 비율을 유지하기 위해, 오른쪽의 중량 비율보다 큰 BN 층 모두 적어
  2. 정리의 모델 BN 층, 즉, 파라미터의 중량비는 상기 가중치들을 폐기 미만인
  3. 모델 층의 컨벌루션 자두, 크기 수행 (층은 일반적으로 회선 + BN 때문에, 우리는 BN 층 컨벌루션 가중치 층 크기의 전후를 알 수있다) 알고 전 컨볼 루션 정합 층 폐기 BN 대응하는 채널 요소 후 가지 치기.
  4. 은 FC 층 치기

그는 느낌이 명확하지 않다,하지만 코드를 살펴 전체를 이해했다. .

코드

'''
这是对vgg的剪枝例子,文章中说了对其他网络的slimming例子
'''
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.models import vgg19
from models import *


# Prune settings
parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
parser.add_argument('--dataset', type=str, default='cifar100',
                    help='training dataset (default: cifar10)')
parser.add_argument('--test-batch-size', type=int, default=256, metavar='N',
                    help='input batch size for testing (default: 256)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--depth', type=int, default=19,
                    help='depth of the vgg')
parser.add_argument('--percent', type=float, default=0.5,
                    help='scale sparse rate (default: 0.5)')
parser.add_argument('--model', default='', type=str, metavar='PATH',
                    help='path to the model (default: none)')
parser.add_argument('--save', default='', type=str, metavar='PATH',
                    help='path to save pruned model (default: none)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

if not os.path.exists(args.save):
    os.makedirs(args.save)

model = vgg19(dataset=args.dataset, depth=args.depth)
if args.cuda:
    model.cuda()

if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.model, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

print(model)
total = 0
for m in model.modules():# 遍历vgg的每个module
    if isinstance(m, nn.BatchNorm2d): # 如果发现BN层
        total += m.weight.data.shape[0] # BN层的特征数目,total就是所有BN层的特征数目总和

bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index+size)] = m.weight.data.abs().clone()
        index += size # 把所有BN层的权重给CLONE下来

y, i = torch.sort(bn) # 这些权重排序
thre_index = int(total * args.percent) # 要保留的数量
thre = y[thre_index] # 最小的权重值

pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        mask = weight_copy.gt(thre).float().cuda()# 小于权重thre的为0,大于的为1
        pruned = pruned + mask.shape[0] - torch.sum(mask) # 被剪枝的权重的总数
        m.weight.data.mul_(mask) # 权重对应相乘
        m.bias.data.mul_(mask) # 偏置也对应相乘
        cfg.append(int(torch.sum(mask))) #第几个batchnorm保留多少。
        cfg_mask.append(mask.clone()) # 第几个batchnorm 保留的weight
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
            format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

pruned_ratio = pruned/total # 剪枝比例

print('Pre-processing Successful!')

# simple test model after Pre-processing prune (simple set BN scales to zeros)
def test(model):
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    if args.dataset == 'cifar10':
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=args.test_batch_size, shuffle=True, **kwargs)
    elif args.dataset == 'cifar100':
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=args.test_batch_size, shuffle=True, **kwargs)
    else:
        raise ValueError("No valid dataset is given.")
    model.eval()
    correct = 0
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))

acc = test(model)

# Make real prune
print(cfg)
newmodel = vgg(dataset=args.dataset, cfg=cfg)
if args.cuda:
    newmodel.cuda()
# torch.nelement() 可以统计张量的个数
num_parameters = sum([param.nelement() for param in newmodel.parameters()]) # 元素个数,比如对于张量shape为(20,3,3,3),那么他的元素个数就是四者乘积也就是20*27 = 540 
# 可以用来统计参数量 嘿嘿
savepath = os.path.join(args.save, "prune.txt")
with open(savepath, "w") as fp:
    fp.write("Configuration: \n"+str(cfg)+"\n")
    fp.write("Number of parameters: \n"+str(num_parameters)+"\n")
    fp.write("Test accuracy: \n"+str(acc))

layer_id_in_cfg = 0 # 第几层
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg] # 
for [m0, m1] in zip(model.modules(), newmodel.modules()):
    if isinstance(m0, nn.BatchNorm2d):
        # np.where 返回的是所有满足条件的数的索引,有多少个满足条件的数就有多少个索引,绝对的索引
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy()))) # 大于0的所有数据的索引,squeeze变成向量
        if idx1.size == 1: # 只有一个要变成数组的1个
            idx1 = np.resize(idx1,(1,))
        m1.weight.data = m0.weight.data[idx1.tolist()].clone() # 用经过剪枝的替换原来的
        m1.bias.data = m0.bias.data[idx1.tolist()].clone()
        m1.running_mean = m0.running_mean[idx1.tolist()].clone()
        m1.running_var = m0.running_var[idx1.tolist()].clone()
        layer_id_in_cfg += 1 # 下一层
        start_mask = end_mask.clone() # 当前在处理的层的mask
        if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
            end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d): # 对卷积层进行剪枝
        # 卷积后面会接bn
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        w1 = m0.weight.data[:, idx0.tolist(), :, :].clone() # 这个剪枝牛B了。。
        w1 = w1[idx1.tolist(), :, :, :].clone() # 最终的权重矩阵
        m1.weight.data = w1.clone()
    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()

torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save, 'pruned.pth.tar'))

print(newmodel)
model = newmodel
test(model)

추천

출처www.cnblogs.com/aoru45/p/11614636.html