Pruning Deep Neural Network Models
Neural Network Pruning
Below is my modest understanding of pruning. If I have gotten anything wrong, corrections are welcome.
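▶Pruning (in the magnitude-based form used here) simply sets to 0 the weights with small magnitude, i.e. those that contribute little to the output. Experiments show that in most cases zeroing out small weights does have little effect on the model's final output, although this also depends on the pruning ratio and on the model size.
▶Pruning does not reduce the number of parameters. The pruned weights are only set to 0 and are still stored. Any saving in computation is indirect, and it only materializes when the zeros are actually exploited, e.g. by sparse storage formats and kernels or by hardware that skips zero operands; an ordinary dense matrix multiplication still performs the multiplications by 0. The idea resembles short-circuit evaluation in programming languages: when the operand before a logical AND is False, the whole expression is False and the operand after it is not evaluated.
▶In some cases a small weight does not necessarily have a small effect on the output, which is why many other pruning criteria exist.
▶Here the pruning ratio is applied to each layer of the network separately, rather than to the network as a whole.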
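A minimal sketch of two of the points above on a toy model (a hypothetical nn.Sequential, not the MNIST network below). It contrasts a per-layer threshold with a single global threshold, and shows that zeroing weights leaves the parameter count unchanged while only the number of nonzero entries drops. torch.quantile is used here in place of np.percentile; the model, sizes and ratio are assumptions for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy model with two weight matrices of different scales
model = nn.Sequential(nn.Linear(20, 50), nn.ReLU(), nn.Linear(50, 10))
weights = [p for p in model.parameters() if p.dim() > 1]  # skip 1-D biases
ratio = 0.4

# Layer-wise: every weight matrix gets its own threshold
layer_masks = [
    (w.detach().abs() > torch.quantile(w.detach().abs().flatten(), ratio)).float()
    for w in weights
]

# Global: one threshold computed over all weights of the network
all_abs = torch.cat([w.detach().abs().flatten() for w in weights])
global_thr = torch.quantile(all_abs, ratio)
global_masks = [(w.detach().abs() > global_thr).float() for w in weights]

for i, (lm, gm) in enumerate(zip(layer_masks, global_masks)):
    print(f"layer {i}: layer-wise pruned {1 - lm.mean().item():.2f}, "
          f"global pruned {1 - gm.mean().item():.2f}")

# Apply the layer-wise masks: the weights become 0 but are still stored
n_params_before = sum(p.numel() for p in model.parameters())
with torch.no_grad():
    for w, m in zip(weights, layer_masks):
        w.mul_(m)
n_params_after = sum(p.numel() for p in model.parameters())
n_nonzero = sum(int(torch.count_nonzero(p)) for p in model.parameters())
print(n_params_before, n_params_after, n_nonzero)  # 1560 1560 960
```

With PyTorch's default initialization the second layer's weights are smaller in scale, so the single global threshold removes noticeably more than 40% from it and less from the first layer, while the layer-wise thresholds remove about 40% from each.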
I ran this on MNIST to see how pruning part of the weights affects model performance.
Network:
self.linear1 = MaskLinear(28*28, 512)
self.bn1 = nn.BatchNorm1d(512)
self.linear2 = MaskLinear(512, 256)
self.bn2 = nn.BatchNorm1d(256)
self.linear3 = MaskLinear(256, 10)
self.act = nn.LeakyReLU()
self.softmax = nn.Softmax(dim=1)
Trained for 8 epochs.
Pruning ratio   Test acc     Change vs. unpruned (percentage points)
0% (unpruned)   97.91324%    n/a
60%             97.20395%    -0.709
50%             97.61513%    -0.298
40%             97.91324%     0.000
30%             97.87212%    -0.041
20%             97.87212%    -0.041
10%             97.93380%    +0.021
Up to 40% pruning the accuracy barely moves (at 10% it is even slightly higher than before pruning); at 60% it drops by about 0.7 percentage points.
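The evaluation loop behind these numbers is not shown in the post. A minimal sketch of such a sweep, assuming synthetic data in place of MNIST and a small nn.Sequential in place of the MaskLinear network (both are stand-ins, so the accuracies it prints will not match the table). Each ratio is applied to a fresh copy of the trained model, and each weight matrix is pruned with its own threshold.

```python
import copy
import torch
from torch import nn

torch.manual_seed(0)

# Synthetic stand-in for MNIST (assumption): label = argmax of a fixed random projection
X = torch.randn(2000, 20)
y = (X @ torch.randn(20, 4)).argmax(dim=1)
X_train, y_train, X_test, y_test = X[:1600], y[:1600], X[1600:], y[1600:]

model = nn.Sequential(nn.Linear(20, 64), nn.LeakyReLU(), nn.Linear(64, 4))
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
for _ in range(200):  # full-batch training, kept short for the sketch
    opt.zero_grad()
    loss = nn.functional.cross_entropy(model(X_train), y_train)
    loss.backward()
    opt.step()


def evaluate(m):
    """Accuracy on the held-out split, in percent."""
    with torch.no_grad():
        return (m(X_test).argmax(dim=1) == y_test).float().mean().item() * 100


def prune_layerwise(m, perc):
    """Return a pruned copy of m: in each weight matrix, zero weights with |w| <= the perc-th percentile."""
    pruned = copy.deepcopy(m)  # leave the trained model untouched between ratios
    with torch.no_grad():
        for p in pruned.parameters():
            if p.dim() > 1:  # weight matrices only; biases are kept
                thr = torch.quantile(p.abs().flatten(), perc / 100)
                p.mul_((p.abs() > thr).float())
    return pruned


base = evaluate(model)
print(f"unpruned: {base:.3f}%")
results = {}
for perc in [60, 50, 40, 30, 20, 10]:
    results[perc] = evaluate(prune_layerwise(model, perc))
    print(f"pruned {perc}%: {results[perc]:.3f}% (change {results[perc] - base:+.3f} pp)")
```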
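Below is a code example that prunes the smallest-magnitude weights; feel free to run it if you are interested.
It defines a two-layer fully connected network (the first layer has 1 input and 5 outputs, the second has 5 inputs and 1 output), trains it, and then prunes 40% of the weights in each layer.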
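The per-layer threshold in weight_prune comes from np.percentile. For a layer with 5 weights, as in this network, a 40% cut works out as follows (the weight values are made up for illustration):

```python
import numpy as np

# Hypothetical weights of a 5-weight layer (illustrative values only)
w = np.array([0.9, -0.1, 0.5, -0.3, 0.05])

# Sorted |w| = [0.05, 0.1, 0.3, 0.5, 0.9]; the 40th percentile sits at position 0.4 * (5 - 1) = 1.6,
# so np.percentile interpolates linearly: 0.1 + 0.6 * (0.3 - 0.1) = 0.22
threshold = np.percentile(np.abs(w), 40)

mask = (np.abs(w) > threshold).astype(np.float32)  # 1 = keep, 0 = prune
print(threshold)  # approximately 0.22
print(mask)       # [1. 0. 1. 1. 0.]
print(w * mask)   # the two smallest-|w| weights (-0.1 and 0.05) become 0
```

Because the threshold falls between the 2nd and 3rd smallest magnitudes, exactly 2 of the 5 weights (40%) are pruned.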
Let's look at some results first.
Code:
import torch
from torch import nn
import numpy as np


class MaskLinear(nn.Linear):  # nn.Linear extended with a pruning mask
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        self.mask = None

    def set_mask(self, mask):  # prune the weights according to the mask
        # Keep the mask so get_mask() can return it (resetting it to None made get_mask() useless)
        self.mask = mask.detach().to(self.weight.device)
        self.weight.data = self.weight.data * self.mask

    def get_mask(self):  # return the current mask (None before pruning)
        return self.mask

    def forward(self, x):
        # Re-apply the mask so pruned weights stay at 0 if the model is trained further
        weight = self.weight if self.mask is None else self.weight * self.mask
        return nn.functional.linear(x, weight, self.bias)


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = MaskLinear(1, 5)
        self.linear2 = MaskLinear(5, 1)

    def forward(self, x):
        x = x.view((x.shape[0], -1))  # flatten to (batch, features)
        l1 = self.linear1(x)
        l2 = self.linear2(l1)
        return l2

    def set_masks(self, masks):
        self.linear1.set_mask(masks[0])
        self.linear2.set_mask(masks[1])


def weight_prune(model, pruning_perc):  # compute one mask per layer for pruning ratio pruning_perc (%)
    masks = []
    for param in model.parameters():
        if param.dim() > 1:  # weight matrices only; skip 1-D biases
            weight = param.detach().cpu().abs().numpy().flatten()
            # Per-layer threshold: the pruning_perc-th percentile of |w| in this layer
            threshold = np.percentile(weight, pruning_perc)
            masks.append((param.detach().abs() > threshold).float())  # 1 = keep, 0 = prune
    return masks


def train(model, data, target, epoch):
    model.train()
    opt = torch.optim.Adam(model.parameters())
    for ep in range(epoch):
        opt.zero_grad()
        out = model(data)
        loss = torch.mean((out - target.view_as(out)) ** 2)  # MSE loss
        loss.backward()
        opt.step()
        print('epoch: {1:2d} Train out: {0:2.4f}'.format(out.item(), ep + 1))
    return model


def main():
    model = MLP()
    data = torch.tensor([3.0])  # a single sample; the network learns to map 3.0 to 3.0
    model = train(model, data, data, 100)
    torch.save(model.state_dict(), 'net.pth')  # save the trained weights
    for name, param in model.named_parameters():
        print('{0:15s}: {1}'.format(name, param.data))
    print('\n############-▲▲ before pruning ▲▲-------▼▼ after pruning 40% ▼▼-##########\n')
    mask = weight_prune(model, 40)  # compute masks that prune 40% of each layer's weights
    model.set_masks(mask)  # apply the masks (prune)
    for name, param in model.named_parameters():
        print('{0:15s}: {1}'.format(name, param.data))


if __name__ == '__main__':
    main()
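PyTorch also ships magnitude pruning in torch.nn.utils.prune. As an alternative sketch (not the approach used above), l1_unstructured zeros a given fraction of the smallest-magnitude weights of a module and keeps the mask as a buffer, so pruned weights also stay at 0 during further training. The plain nn.Linear layers below only mirror the shapes of the MLP above.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)
# Same layer shapes as the MLP above (1 -> 5 -> 1), built from plain nn.Linear
model = nn.Sequential(nn.Linear(1, 5), nn.Linear(5, 1))

for layer in model:
    # Zero the smallest-|w| weights of this layer; amount=0.4 -> round(0.4 * 5) = 2 weights
    prune.l1_unstructured(layer, name="weight", amount=0.4)

# The mask lives in a buffer; the forward pass uses weight_orig * weight_mask
print(model[0].weight_mask)

for layer in model:
    # Make pruning permanent: fold the mask into .weight and drop weight_orig / weight_mask
    prune.remove(layer, "weight")

zeros_per_layer = [int((layer.weight == 0).sum()) for layer in model]
print(zeros_per_layer)  # [2, 2]
```

Unlike the percentile threshold compared with > in weight_prune, l1_unstructured prunes exactly round(amount * n) weights per module, so ties or interpolation around the threshold do not change the count.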