PyTorch Acceleration and Optimization: Hyperparameter Tuning, Quantization, and Pruning

Preface

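  • These are my personal notes on using PyTorch for hyperparameter tuning, quantization, and pruning. My knowledge is limited, so errors and omissions are inevitable; corrections are welcome.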

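Prerequisites

Background

  • Python is a cross-platform programming language. It is a high-level scripting language that combines interpreted, compiled, interactive, and object-oriented features. It was originally designed for writing automation (shell) scripts; as new versions and language features have been added, it has increasingly been used for independent, large-scale projects.
  • PyTorch is an open-source Python machine learning library based on Torch, used for applications such as natural language processing. It was released by Facebook AI Research (FAIR) in January 2017. It is a Python-based scientific computing package that provides two high-level features: 1. tensor computation (like NumPy) with strong GPU acceleration; 2. deep neural networks built on an automatic differentiation system.

Environment

  • Python 3.x (a high-level, object-oriented language)
  • PyTorch (a third-party Python library)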

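Hyperparameter tuning

  • Hyperparameters: parameters of a deep learning model that have to be set by hand rather than learned, such as the learning rate lr and the batch size batch_size.
  • In Python, the Ray Tune package can manage hyperparameter tuning (the Ray documentation installs Tune together with its extra dependencies via pip install "ray[tune]").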
pip install tensorboardX
pip install ray
Requirement already satisfied: tensorboardX in /opt/conda/lib/python3.7/site-packages (2.5.1)
Requirement already satisfied: ray in /opt/conda/lib/python3.7/site-packages (2.2.0)
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, nodes_1=120, nodes_2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, nodes_1) # configurable number of nodes in fc1
        self.fc2 = nn.Linear(nodes_1, nodes_2) # configurable number of nodes in fc2
        self.fc3 = nn.Linear(nodes_2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Define the hyperparameter configuration (search space)

from ray import tune
import numpy as np

config = {
    # tune.sample_from() with a lambda defines a custom search space:
    # 2 ** randint(2, 9) gives powers of two from 4 to 256
    "nodes_1": tune.sample_from(
        lambda _: 2 ** np.random.randint(2, 9)),
    "nodes_2": tune.sample_from(
        lambda _: 2 ** np.random.randint(2, 9)),
    "lr": tune.loguniform(1e-4, 1e-1),  # log-uniform learning rate
    "batch_size": tune.choice([2, 4, 8, 16])
}

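To see what values this search space actually produces, the following NumPy-only sketch draws configurations the same way without Ray. It is an approximation under assumptions: sample_config is a hypothetical helper, and tune.loguniform(1e-4, 1e-1) is mimicked by sampling the base-10 exponent uniformly, which is what a log-uniform distribution means.

```python
import numpy as np

def sample_config(rng):
    """Hypothetical helper: draw one config mimicking the Ray Tune search space above."""
    return {
        # 2 ** integers(2, 9) -> 2**2 .. 2**8, i.e. 4, 8, ..., 256 (upper bound exclusive)
        "nodes_1": int(2 ** rng.integers(2, 9)),
        "nodes_2": int(2 ** rng.integers(2, 9)),
        # log-uniform: exponent uniform between log10(1e-4) = -4 and log10(1e-1) = -1
        "lr": float(10 ** rng.uniform(-4, -1)),
        # categorical choice, like tune.choice
        "batch_size": int(rng.choice([2, 4, 8, 16])),
    }

rng = np.random.default_rng(0)
samples = [sample_config(rng) for _ in range(1000)]
print(samples[0])
```

Sampling the exponent uniformly spreads trials evenly across orders of magnitude (the range 1e-4 to 1e-3 gets as many samples as 1e-2 to 1e-1), which is why log-uniform sampling is a common choice for learning rates.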
import torch
import torchvision
from torchvision import transforms

def load_data(data_dir="./data"):
    train_transforms = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(
                      (0.4914, 0.4822, 0.4465),
                      (0.2023, 0.1994, 0.2010))])
    
    test_transforms = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize(
                        (0.4914, 0.4822, 0.4465),
                        (0.2023, 0.1994, 0.2010))])

    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, 
        download=True, transform=train_transforms)

    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, 
        download=True, transform=test_transforms)

    return trainset, testset
from torch import optim
from torch import nn
from torch.utils.data import random_split

def train_model(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Net(config['nodes_1'],config['nodes_2']).to(device=device) # configurable layer sizes

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),
                        lr=config['lr'],
                        momentum=0.9) # configurable learning rate

    trainset, testset = load_data()

    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
                                trainset, 
                                [test_abs, len(trainset) - test_abs])

    trainloader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True) # configurable batch size

    valloader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True) # configurable batch size

    for epoch in range(10):
        model.train()
        train_loss = 0.0
        for inputs, labels in trainloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # validation: gradient tracking is disabled for the whole loop
        model.eval()
        val_loss = 0.0
        total = 0
        correct = 0
        with torch.no_grad():
            for inputs, labels in valloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                loss = criterion(outputs, labels)
                val_loss += loss.item()

        print(f'epoch: {epoch} ',
              'train_loss: ',
              f'{train_loss / len(trainloader)}',
              'val_loss: ',
              f'{val_loss / len(valloader)}',
              f'val_acc: {correct / total}')
        # report metrics to Ray Tune; each report counts as one training iteration
        tune.report(loss=(val_loss / len(valloader)),
                    accuracy=correct / total)

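Before running Ray Tune, set up a scheduler and a reporter. The scheduler (trial scheduler) decides how trials are run, for example by stopping poorly performing trials early, while the search space in config determines which hyperparameters are sampled. The reporter is used to view the results.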

from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

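# Scheduler: the asynchronous successive halving algorithm (ASHA)
scheduler = ASHAScheduler(
    metric="loss", # metric to optimize
    mode="min", # minimize the loss
    max_t=1, # maximum training iterations (tune.report calls) per trial;
             # with max_t=1 every trial stops after its first epoch
    grace_period=1, # minimum iterations before a trial can be stopped
    reduction_factor=2) # keep roughly the best 1/2 of trials at each rung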

# Reporter: a CLI reporter that prints the loss, accuracy, training iteration, and the hyperparameters chosen for each trial
reporter = CLIReporter(
    metric_columns=["loss", 
                    "accuracy", 
                    "training_iteration"])
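ASHA builds on successive halving. The pure-Python sketch below shows plain synchronous successive halving, not Ray's actual ASHA implementation: every trial first gets a small budget (epochs), only the best 1/eta are promoted to a budget eta times larger, and so on. The objective toy_loss and the list of learning rates are made up for illustration.

```python
import math

def successive_halving(trials, evaluate, min_budget=1, max_budget=8, eta=2):
    """Synchronous successive halving (the idea that ASHA runs asynchronously).

    trials:   hyperparameter values to compare (here: learning rates)
    evaluate: function (trial, budget) -> loss after `budget` epochs
    eta:      reduction factor; the best 1/eta trials are promoted at each rung
    """
    budget = min_budget
    survivors = list(trials)
    history = []  # (budget, number of trials evaluated at that budget)
    while True:
        losses = {t: evaluate(t, budget) for t in survivors}
        history.append((budget, len(survivors)))
        if budget * eta > max_budget or len(survivors) == 1:
            break
        keep = max(1, len(survivors) // eta)
        survivors = sorted(survivors, key=losses.get)[:keep]  # promote the best
        budget *= eta
    best = min(survivors, key=losses.get)
    return best, history

def toy_loss(lr, budget):
    # hypothetical objective: best near lr = 10**-2.5, more epochs lower the loss
    return abs(math.log10(lr) + 2.5) + 1.0 / budget

lrs = [1e-4, 1e-3, 3e-3, 1e-2, 3e-2, 1e-1, 3e-1, 3e-4]
best, history = successive_halving(lrs, toy_loss, min_budget=1, max_budget=8, eta=2)
print(best, history)
```

With the settings used in this section (grace_period=1, max_t=1), there is only a single rung, so nothing is promoted or stopped early: every trial runs for one iteration, which matches training_iteration = 1 in the results table.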

Run Ray Tune with the tune.run() method

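from functools import partial

result = tune.run(
    partial(train_model),
    # resources_per_trial={"cpu": 2, "gpu": 1}, # resources for each trial
    # train_model uses one GPU; requesting both GPUs means only one trial runs at a time
    resources_per_trial={"cpu": 1, "gpu": 2},
    config=config,
    num_samples=10, # number of sampled configurations (trials)
    scheduler=scheduler,
    progress_reporter=reporter)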
2023-02-02 02:05:24,885	INFO worker.py:1538 -- Started a local Ray instance.


== Status ==
Current time: 2023-02-02 02:05:27 (running for 00:00:00.63)
Memory usage on this node: 1.7/15.6 GiB 
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 1.000: None
Resources requested: 1.0/2 CPUs, 2.0/2 GPUs, 0.0/7.15 GiB heap, 0.0/3.58 GiB objects (0.0/1.0 accelerator_type:T4)
Result logdir: /root/ray_results/train_model_2023-02-02_02-05-26
Number of trials: 10/10 (9 PENDING, 1 RUNNING)
+-------------------------+----------+----------------+--------------+-------------+-----------+-----------+
| Trial name              | status   | loc            |   batch_size |          lr |   nodes_1 |   nodes_2 |
|-------------------------+----------+----------------+--------------+-------------+-----------+-----------|
| train_model_0f495_00000 | RUNNING  | 172.19.2.2:732 |            8 | 0.00208745  |        64 |       256 |
| train_model_0f495_00001 | PENDING  |                |            8 | 0.011537    |        64 |        32 |
| train_model_0f495_00002 | PENDING  |                |            2 | 0.000202415 |        64 |       128 |
| train_model_0f495_00003 | PENDING  |                |            8 | 0.000397489 |       128 |         8 |
| train_model_0f495_00004 | PENDING  |                |           16 | 0.000670083 |        16 |       128 |
| train_model_0f495_00005 | PENDING  |                |           16 | 0.00385978  |         8 |        32 |
| train_model_0f495_00006 | PENDING  |                |           16 | 0.0461144   |        64 |        64 |
| train_model_0f495_00007 | PENDING  |                |            2 | 0.0169714   |         4 |       256 |
| train_model_0f495_00008 | PENDING  |                |            2 | 0.00162063  |        64 |         8 |
| train_model_0f495_00009 | PENDING  |                |           16 | 0.00110084  |        64 |       128 |
+-------------------------+----------+----------------+--------------+-------------+-----------+-----------+



2023-02-02 02:16:05,160	INFO tune.py:763 -- Total run time: 638.90 seconds (638.46 seconds for the tuning loop).


(func pid=1840) epoch: 0  train_loss:  1.9899510860443115 val_loss:  1.7621070608139038 val_acc: 0.3433
== Status ==
Current time: 2023-02-02 02:16:05 (running for 00:10:38.48)
Memory usage on this node: 4.4/15.6 GiB 
Using AsyncHyperBand: num_stopped=10
Bracket: Iter 1.000: -1.812330790552497
Resources requested: 0/2 CPUs, 0/2 GPUs, 0.0/7.15 GiB heap, 0.0/3.58 GiB objects (0.0/1.0 accelerator_type:T4)
Result logdir: /root/ray_results/train_model_2023-02-02_02-05-26
Number of trials: 10/10 (10 TERMINATED)
+-------------------------+------------+-----------------+--------------+-------------+-----------+-----------+---------+------------+----------------------+
| Trial name              | status     | loc             |   batch_size |          lr |   nodes_1 |   nodes_2 |    loss |   accuracy |   training_iteration |
|-------------------------+------------+-----------------+--------------+-------------+-----------+-----------+---------+------------+----------------------|
| train_model_0f495_00000 | TERMINATED | 172.19.2.2:732  |            8 | 0.00208745  |        64 |       256 | 1.60968 |     0.4072 |                    1 |
| train_model_0f495_00001 | TERMINATED | 172.19.2.2:850  |            8 | 0.011537    |        64 |        32 | 2.11739 |     0.182  |                    1 |
| train_model_0f495_00002 | TERMINATED | 172.19.2.2:968  |            2 | 0.000202415 |        64 |       128 | 1.6399  |     0.3874 |                    1 |
| train_model_0f495_00003 | TERMINATED | 172.19.2.2:1114 |            8 | 0.000397489 |       128 |         8 | 1.84503 |     0.2984 |                    1 |
| train_model_0f495_00004 | TERMINATED | 172.19.2.2:1231 |           16 | 0.000670083 |        16 |       128 | 1.83333 |     0.3142 |                    1 |
| train_model_0f495_00005 | TERMINATED | 172.19.2.2:1333 |           16 | 0.00385978  |         8 |        32 | 1.69983 |     0.3735 |                    1 |
| train_model_0f495_00006 | TERMINATED | 172.19.2.2:1441 |           16 | 0.0461144   |        64 |        64 | 2.31054 |     0.1009 |                    1 |
| train_model_0f495_00007 | TERMINATED | 172.19.2.2:1550 |            2 | 0.0169714   |         4 |       256 | 2.31998 |     0.099  |                    1 |
| train_model_0f495_00008 | TERMINATED | 172.19.2.2:1691 |            2 | 0.00162063  |        64 |         8 | 1.79133 |     0.3197 |                    1 |
| train_model_0f495_00009 | TERMINATED | 172.19.2.2:1840 |           16 | 0.00110084  |        64 |       128 | 1.76211 |     0.3433 |                    1 |
+-------------------------+------------+-----------------+--------------+-------------+-----------+-----------+---------+------------+----------------------+


best_trial = result.get_best_trial(
    "loss", "min", "last")
print("Best trial config: {}".format(
    best_trial.config))
print("Best trial final validation loss:",
      "{}".format(
          best_trial.last_result["loss"]))
print("Best trial final validation accuracy:",
      "{}".format(
          best_trial.last_result["accuracy"]))
Best trial config: {'nodes_1': 64, 'nodes_2': 256, 'lr': 0.0020874538538687972, 'batch_size': 8}
Best trial final validation loss: 1.6096843579292297
Best trial final validation accuracy: 0.4072

Quantization

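  • Quantization refers to techniques that perform computation and memory access with lower-precision data, for example 8-bit integers instead of 32-bit floats.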
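PyTorch's int8 quantization uses an affine mapping: a float x becomes q = round(x / scale) + zero_point, clamped to the integer range, and is recovered approximately as (q - zero_point) * scale. The sketch below checks this mapping against torch.quantize_per_tensor; the scale and zero_point values are arbitrary choices for illustration.

```python
import torch

x = torch.tensor([-1.0, -0.5, 0.0, 0.25, 0.333, 1.0])
scale, zero_point = 0.01, 128  # hypothetical qparams covering about [-1.28, 1.27]

# PyTorch's built-in per-tensor affine quantization to unsigned 8-bit
q = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)

# the same mapping by hand: q = clamp(round(x / scale) + zero_point, 0, 255)
manual = torch.clamp(torch.round(x / scale) + zero_point, 0, 255).to(torch.uint8)

# dequantize: x_hat = (q - zero_point) * scale
x_hat = q.dequantize()
print(q.int_repr())  # the stored 8-bit integers
print(x_hat)         # approximate float values; 0.333 comes back as about 0.33
```

The rounding error per value is at most scale / 2, so choosing scale and zero_point to fit the actual value range is what the observers in the static and QAT workflows are for.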
import torch 
from torch import nn
import torch.nn.functional as F

class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(
            F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(
            F.relu(self.conv2(x)), 2)
        x = x.view(-1, 
                   int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet5()
for n, p in model.named_parameters():
    print(n, ": ", p.dtype)
conv1.weight :  torch.float32
conv1.bias :  torch.float32
conv2.weight :  torch.float32
conv2.bias :  torch.float32
fc1.weight :  torch.float32
fc1.bias :  torch.float32
fc2.weight :  torch.float32
fc2.bias :  torch.float32
fc3.weight :  torch.float32
fc3.bias :  torch.float32
  • The quickest form of quantization is to halve the precision of all computations (float32 to float16):
model = model.half() # halve the model's precision (float32 -> float16)

for n, p in model.named_parameters():
    print(n, ": ", p.dtype)
conv1.weight :  torch.float16
conv1.bias :  torch.float16
conv2.weight :  torch.float16
conv2.bias :  torch.float16
fc1.weight :  torch.float16
fc1.bias :  torch.float16
fc2.weight :  torch.float16
fc2.bias :  torch.float16
fc3.weight :  torch.float16
fc3.bias :  torch.float16
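The main gain from .half() is memory: every parameter drops from 4 bytes to 2. The sketch below measures that on a small stand-in model (two Linear layers shaped like fc1 and fc2, an assumption for brevity). It does not run a forward pass, because a half model also needs half inputs (x.half()), and float16 kernels are mainly a GPU feature; CPU support for float16 ops varies between PyTorch versions.

```python
import torch
from torch import nn

def param_bytes(model):
    # total bytes used by all parameters
    return sum(p.numel() * p.element_size() for p in model.parameters())

# stand-in for the fully connected part of LeNet5: 400*120+120 + 120*84+84 = 58284 parameters
model = nn.Sequential(nn.Linear(400, 120), nn.ReLU(), nn.Linear(120, 84))
fp32_bytes = param_bytes(model)
model = model.half()  # converts floating-point parameters and buffers to float16
fp16_bytes = param_bytes(model)
print(fp32_bytes, fp16_bytes)  # 233136 116568
```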
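  • In practice, we usually do not quantize every computation in the same way, and float16 may not be enough: you may need to quantize to an even lower precision such as int8.
  • PyTorch provides three other quantization modes: dynamic quantization, post-training static quantization, and quantization-aware training (QAT).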

Dynamic quantization

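  • Dynamic quantization is the simplest kind of quantization. Weights are converted to int8 ahead of time, while activations are converted to int8 dynamically at run time: computations use int8 values, but activations are read from and written to memory in floating-point format.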
import torch.quantization

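# note: `model` was converted to float16 above, while dynamic quantization normally starts from a float32 model
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8)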
for n, p in quantized_model.named_parameters():
    print(n, ": ", p.dtype)
conv1.weight :  torch.float16
conv1.bias :  torch.float16
conv2.weight :  torch.float16
conv2.bias :  torch.float16
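In the listing above, fc1, fc2 and fc3 no longer appear: the dynamically quantized Linear modules keep their int8 weights in a packed form rather than as nn.Parameter objects, so named_parameters() skips them. The sketch below checks this on a small float32 stand-in model (not LeNet5) and compares float and dynamically quantized outputs on random inputs; the threshold in the comment is only a rough expectation.

```python
import torch
from torch import nn

torch.manual_seed(0)
# small float32 stand-in for the fully connected part of LeNet5
float_model = nn.Sequential(nn.Linear(400, 120), nn.ReLU(), nn.Linear(120, 10)).eval()

# weights -> int8 now; activations are quantized on the fly at inference time
dq_model = torch.quantization.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)

x = torch.randn(4, 400)
with torch.no_grad():
    ref = float_model(x)
    out = dq_model(x)

max_diff = (ref - out).abs().max().item()
n_params = len(list(dq_model.parameters()))
print(dq_model)   # Linear layers replaced by dynamic quantized versions
print(max_diff)   # small quantization error (well below 0.1 here)
print(n_params)   # packed int8 weights are not nn.Parameters
```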

Post-training static quantization

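  • Post-training static quantization can reduce latency further. After training, it observes the distributions of the different activations on representative (calibration) data and decides in advance how those activations should be quantized at inference time. This lets quantized values be passed between operations without converting between float and int in memory.
  • Note: quantization depends on the backend that runs the quantized model. At the time of writing, quantized operators for CPU inference were supported only on x86 (fbgemm) and ARM (qnnpack). Quantization-aware training (QAT), described next, trains with full floating-point computation, so it runs on both GPU and CPU.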
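static_quant_model = LeNet5()
static_quant_model.qconfig = torch.quantization.get_default_qconfig('fbgemm') # x86 backend

torch.quantization.prepare(static_quant_model, inplace=True) # insert observers
# Calibration (running representative data through the model) belongs here. It is skipped,
# so the observers keep their defaults: note scale=1.0, zero_point=0 in the output below.
# LeNet5 also has no QuantStub/DeQuantStub, so this converted model cannot take float inputs as-is.
torch.quantization.convert(static_quant_model, inplace=True)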
LeNet5(
  (conv1): QuantizedConv2d(3, 6, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0)
  (conv2): QuantizedConv2d(6, 16, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0)
  (fc1): QuantizedLinear(in_features=400, out_features=120, scale=1.0, zero_point=0, qscheme=torch.per_channel_affine)
  (fc2): QuantizedLinear(in_features=120, out_features=84, scale=1.0, zero_point=0, qscheme=torch.per_channel_affine)
  (fc3): QuantizedLinear(in_features=84, out_features=10, scale=1.0, zero_point=0, qscheme=torch.per_channel_affine)
)
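For a static quantization run that produces calibrated scales, the eager-mode workflow needs what the snippet above leaves out: QuantStub/DeQuantStub modules marking where the int8 region begins and ends, a calibration pass between prepare() and convert(), and the model in eval mode. The following is a sketch under assumptions: QuantLeNet5 is a hypothetical variant of LeNet5, and random tensors stand in for real calibration images.

```python
import torch
from torch import nn
import torch.nn.functional as F
from torch.ao.quantization import (QuantStub, DeQuantStub,
                                   get_default_qconfig, prepare, convert)

class QuantLeNet5(nn.Module):
    """Hypothetical LeNet5 variant with stubs marking the int8 region."""
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # float -> int8 at the input
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.dequant = DeQuantStub()  # int8 -> float at the output

    def forward(self, x):
        x = self.quant(x)
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.dequant(self.fc3(x))

# pick a quantized backend available on this machine (fbgemm: x86, qnnpack: ARM)
engine = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
torch.backends.quantized.engine = engine

torch.manual_seed(0)
model = QuantLeNet5().eval()
model.qconfig = get_default_qconfig(engine)
prepared = prepare(model)          # copy of the model with observers inserted

# calibration: random tensors stand in for a few batches of real CIFAR-10 images
with torch.no_grad():
    for _ in range(8):
        prepared(torch.randn(16, 3, 32, 32))

quantized = convert(prepared)      # observer statistics -> real scale/zero_point
with torch.no_grad():
    out = quantized(torch.randn(2, 3, 32, 32))
print(quantized.conv1)             # scale/zero_point now come from calibration
print(out.shape)
```

With real data, the calibration batches should come from the training or validation set so that the observed ranges match what the model sees at inference time.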

Quantization-aware training (QAT)

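  • Quantization-aware training (QAT) usually gives the best accuracy. All weights and activations are "fake quantized" during the forward and backward passes of training: float values are rounded to mimic int8 values, but the computation itself is still done in floating point. As a result, the weight updates are "aware" that the model will be quantized.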
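qat_model = LeNet5() # a new module is in training mode, which prepare_qat() requires
qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

torch.quantization.prepare_qat(qat_model, inplace=True) # insert fake-quantization modules
# The QAT training loop should run here. It is omitted, so the observers never see data
# (hence the warning and the default scale=1.0, zero_point=0 below).
qat_model.eval()
torch.quantization.convert(qat_model, inplace=True)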
/opt/conda/lib/python3.7/site-packages/torch/ao/quantization/utils.py:211: UserWarning: must run observer before calling calculate_qparams. Returning default values.
  "must run observer before calling calculate_qparams. " +





LeNet5(
  (conv1): QuantizedConv2d(3, 6, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0)
  (conv2): QuantizedConv2d(6, 16, kernel_size=(5, 5), stride=(1, 1), scale=1.0, zero_point=0)
  (fc1): QuantizedLinear(in_features=400, out_features=120, scale=1.0, zero_point=0, qscheme=torch.per_channel_affine)
  (fc2): QuantizedLinear(in_features=120, out_features=84, scale=1.0, zero_point=0, qscheme=torch.per_channel_affine)
  (fc3): QuantizedLinear(in_features=84, out_features=10, scale=1.0, zero_point=0, qscheme=torch.per_channel_affine)
)
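Likewise, the QAT snippet above converts the model without training it. A minimal end-to-end sketch, assuming a tiny hypothetical nn.Sequential model and synthetic data instead of LeNet5 on CIFAR-10: prepare_qat() in training mode, a few optimizer steps with fake quantization active, then eval() and convert().

```python
import torch
from torch import nn
import torch.nn.functional as F
from torch.ao.quantization import (QuantStub, DeQuantStub,
                                   get_default_qat_qconfig, prepare_qat, convert)

engine = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
torch.backends.quantized.engine = engine
torch.manual_seed(0)

# tiny stand-in model; the stubs mark the quantized region
model = nn.Sequential(QuantStub(), nn.Linear(20, 16), nn.ReLU(),
                      nn.Linear(16, 4), DeQuantStub())
model.train()                                  # prepare_qat expects training mode
model.qconfig = get_default_qat_qconfig(engine)
qat_model = prepare_qat(model)                 # copy with fake-quantize modules inserted

# synthetic data: the label is the index of the largest of the first 4 features
x = torch.randn(256, 20)
y = x[:, :4].argmax(dim=1)

optimizer = torch.optim.SGD(qat_model.parameters(), lr=0.1)
for step in range(100):
    optimizer.zero_grad()
    loss = F.cross_entropy(qat_model(x), y)   # forward pass uses fake-quantized values
    loss.backward()                           # gradients pass through fake quantization (straight-through)
    optimizer.step()

qat_model.eval()
int8_model = convert(qat_model)                # swap in real int8 modules
with torch.no_grad():
    out = int8_model(x)
acc = (out.argmax(dim=1) == y).float().mean().item()
print(int8_model)
print("int8 training accuracy:", acc)
```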
  • Note: PyTorch's quantization features are still under active development and were in beta at the time of writing.

Pruning

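  • Pruning is a technique that reduces the number of model parameters with minimal impact on performance, which makes it possible to deploy models with less memory, smaller processors, and fewer hardware resources. Note that torch.nn.utils.prune only zeroes weights through masks; the dense tensors keep their shape, so real memory or speed gains also require sparse storage or structured pruning.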
import torch 
from torch import nn
import torch.nn.functional as F

class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(
            F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(
            F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
  • LeNet5 has five submodules: conv1, conv2, fc1, fc2, and fc3. The model parameters are the weights and biases, which can be inspected with the named_parameters() method.
device = torch.device("cuda" if 
    torch.cuda.is_available() else "cpu")
model = LeNet5().to(device)

print(list(model.conv1.named_parameters()))
[('weight', Parameter containing:
tensor([[[[ 0.0807, -0.0330, -0.0133,  0.0424,  0.0620],
          [ 0.0338,  0.1058, -0.1049, -0.0152, -0.0697],
          [-0.0215, -0.1002,  0.0803,  0.0423, -0.0491],
          [ 0.0769,  0.0831, -0.0053,  0.0519,  0.0787],
          [ 0.0449,  0.0963,  0.1036, -0.0119,  0.0780]],

         [[ 0.0206,  0.0409, -0.0407, -0.0231, -0.0977],
          [-0.0069,  0.0188, -0.0466, -0.0172,  0.0372],
          [-0.0804,  0.0902,  0.1082, -0.0192,  0.0477],
          [ 0.0057,  0.0447, -0.0272, -0.1057,  0.1135],
          [ 0.1046,  0.0197, -0.0288, -0.0803,  0.0797]],

         [[ 0.0961, -0.0309, -0.0433,  0.0510, -0.0408],
          [ 0.0218, -0.0093,  0.0297, -0.0055,  0.0561],
          [ 0.0161, -0.0166,  0.0739, -0.0938,  0.0317],
          [-0.0573,  0.0727, -0.0758, -0.0565,  0.0878],
          [-0.0913, -0.0770, -0.0225,  0.0828,  0.1036]]],

        [[[-0.0178,  0.1112,  0.0027,  0.0701, -0.0215],
          [ 0.0193, -0.1126, -0.0067, -0.0459, -0.0953],
          [-0.0825, -0.0526,  0.0168,  0.0145, -0.0125],
          [-0.0877,  0.0207,  0.0051, -0.0489,  0.0720],
          [ 0.0074,  0.0232, -0.0267, -0.0912, -0.0016]],

         [[ 0.1091,  0.0140,  0.0271, -0.0390, -0.0958],
          [ 0.0068,  0.0734, -0.0895,  0.0667, -0.0704],
          [ 0.0640,  0.0240, -0.0811, -0.1071, -0.0046],
          [-0.0286, -0.0557,  0.0219, -0.0797,  0.0399],
          [ 0.0951, -0.0194,  0.0160, -0.1102,  0.0037]],

         [[ 0.0625,  0.0565,  0.1011, -0.0599, -0.0048],
          [-0.0233, -0.0210,  0.0191,  0.0663, -0.0904],
          [ 0.1000, -0.0677, -0.0137,  0.0629,  0.1139],
          [-0.0315,  0.0504, -0.1096, -0.0365, -0.0279],
          [-0.0512,  0.0821, -0.0359,  0.0349, -0.0828]]],

        [[[-0.0085,  0.0708,  0.0927, -0.0134,  0.1040],
          [ 0.1011,  0.0380, -0.0932,  0.0248, -0.0573],
          [ 0.0597,  0.0865, -0.0899,  0.0878,  0.1042],
          [-0.0423,  0.0050, -0.0296, -0.0998,  0.0412],
          [ 0.0276,  0.0230,  0.0052,  0.0527,  0.0328]],

         [[-0.0116,  0.0606,  0.0782,  0.1016, -0.0558],
          [-0.0879, -0.0913,  0.0039, -0.0486,  0.0302],
          [-0.1125,  0.0397,  0.1011,  0.1051, -0.0013],
          [ 0.0604,  0.0398, -0.0025, -0.0450,  0.0254],
          [-0.0317, -0.0395,  0.0556, -0.0077, -0.0087]],

         [[-0.0811,  0.1145, -0.0649, -0.0265,  0.1032],
          [ 0.0794, -0.0024, -0.0237,  0.0598, -0.0944],
          [ 0.1095, -0.0970, -0.0178, -0.0926,  0.0684],
          [ 0.0907,  0.0652, -0.0588,  0.0637,  0.0302],
          [ 0.1132, -0.0547,  0.0659,  0.0479,  0.1095]]],

        [[[ 0.0822, -0.0710,  0.0067,  0.0500,  0.0274],
          [-0.0423, -0.0655,  0.0858,  0.0685,  0.1024],
          [-0.0693, -0.0567,  0.0308,  0.0589,  0.0455],
          [ 0.0904, -0.0133, -0.0870, -0.0671,  0.1025],
          [-0.0686, -0.0085,  0.0624,  0.1017, -0.0239]],

         [[ 0.0907, -0.0579,  0.0706,  0.0307, -0.1153],
          [-0.0122, -0.0377, -0.0445, -0.0538,  0.0338],
          [-0.0725, -0.1115,  0.0604, -0.0136,  0.0975],
          [ 0.0648,  0.0492, -0.0770,  0.0845,  0.0173],
          [-0.0533,  0.0212,  0.0801, -0.1113,  0.0864]],

         [[-0.0126, -0.0099,  0.0226, -0.1111,  0.0698],
          [ 0.0987, -0.0507,  0.0460,  0.0509, -0.1049],
          [ 0.0899,  0.0256, -0.0954, -0.0310,  0.1025],
          [-0.0658, -0.0842, -0.0705, -0.0690, -0.0596],
          [ 0.0873,  0.0355, -0.0280,  0.0308,  0.0801]]],

        [[[ 0.0558,  0.0660, -0.0859,  0.0719, -0.0570],
          [-0.0832,  0.1147,  0.0418, -0.0291, -0.0384],
          [-0.1143,  0.0522,  0.0428,  0.0614, -0.0119],
          [ 0.0641,  0.0930,  0.0407, -0.0353, -0.0657],
          [ 0.0042,  0.0132, -0.0557, -0.0803, -0.0464]],

         [[-0.0611, -0.0598, -0.0383,  0.0453,  0.0462],
          [ 0.1045, -0.0514, -0.0189, -0.0014, -0.0054],
          [-0.0372,  0.0966,  0.0741,  0.0870,  0.1023],
          [-0.0117, -0.0157, -0.1145,  0.0599, -0.0392],
          [-0.0648,  0.0903,  0.0471, -0.0930, -0.1113]],

         [[-0.0528,  0.0461, -0.0693, -0.0424,  0.0825],
          [-0.0244, -0.0363,  0.0469,  0.0252, -0.0127],
          [ 0.0590,  0.0485,  0.0280, -0.0457,  0.0224],
          [-0.0290, -0.0319,  0.0266, -0.1103,  0.0002],
          [-0.1103, -0.0315,  0.0587, -0.0035, -0.0100]]],

        [[[ 0.0358, -0.0845, -0.1016,  0.1149,  0.0869],
          [ 0.0829, -0.0099,  0.0339, -0.1071,  0.0679],
          [ 0.0901, -0.0212,  0.0468,  0.0042, -0.0929],
          [-0.0648, -0.0580, -0.0112, -0.0113, -0.0682],
          [ 0.0406, -0.0807,  0.0634,  0.0170, -0.1031]],

         [[-0.0955, -0.0185, -0.0148,  0.0005, -0.0372],
          [-0.0207,  0.1041, -0.0922,  0.0103,  0.0424],
          [-0.0581,  0.1128,  0.0292,  0.0042, -0.0814],
          [ 0.0882, -0.0714, -0.0918, -0.1019, -0.0829],
          [ 0.0179,  0.0246, -0.0940,  0.0159,  0.0944]],

         [[ 0.0258,  0.0743,  0.0390, -0.1051, -0.0090],
          [-0.0187, -0.0850, -0.0034, -0.0107, -0.0168],
          [-0.0350,  0.0346,  0.0705, -0.0884,  0.0876],
          [-0.0850, -0.0734, -0.1152,  0.0609, -0.1100],
          [ 0.0363, -0.0489, -0.0183, -0.0161,  0.0226]]]], requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.0973,  0.0784,  0.0344, -0.0536,  0.0964, -0.0110],
       requires_grad=True))]
  • Local pruning prunes only a specified part of the model.

Randomly prune 25% of the entries of conv1's weight (unstructured pruning).

import torch.nn.utils.prune as prune

prune.random_unstructured(model.conv1, 
                          name="weight", 
                          amount=0.25)
Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))

Randomly prune 25% of the entries of conv1's bias (unstructured pruning).

prune.random_unstructured(model.conv1, 
                          name="bias", 
                          amount=0.25)
Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
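What prune.random_unstructured() does to a module is easy to inspect. In this sketch a single nn.Linear(10, 8) stands in for conv1 (an assumption for brevity): the original values move to a parameter named weight_orig, the 0/1 mask is stored as a buffer named weight_mask, and weight is recomputed as weight_orig * weight_mask; prune.remove() then makes the pruning permanent.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(10, 8)  # stand-in for one layer such as model.conv1

prune.random_unstructured(layer, name="weight", amount=0.25)  # round(0.25 * 80) = 20 entries

# the original values become the parameter `weight_orig`; the 0/1 mask is the buffer `weight_mask`
params_after_prune = [n for n, _ in layer.named_parameters()]
buffers_after_prune = [n for n, _ in layer.named_buffers()]
n_pruned = int((layer.weight_mask == 0).sum())
print(params_after_prune, buffers_after_prune)  # ['bias', 'weight_orig'] ['weight_mask']
print(n_pruned, "of", layer.weight.numel(), "weights pruned")  # 20 of 80

# `weight` is recomputed as weight_orig * weight_mask by a forward pre-hook;
# prune.remove() keeps the pruned values as a normal parameter and drops the mask and hook
prune.remove(layer, "weight")
print([n for n, _ in layer.named_parameters()])  # ['bias', 'weight']
```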

Prune convolutional and linear layers by different amounts.

model = LeNet5().to(device)

for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d): 
        prune.random_unstructured(module, 
                              name='weight', 
                              amount=0.3) # prune 30% of the weights of every Conv2d layer
    elif isinstance(module, torch.nn.Linear):
        prune.random_unstructured(module, 
                              name='weight', 
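                              amount=0.5) # prune 50% of the weights of every Linear layer
  • Global pruning treats the model as a whole. For example, the following prunes 25% of all the listed weight parameters by L1 magnitude: the smallest 25% of weights across all five layers are removed together, so the fraction pruned in each individual layer can differ.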
model = LeNet5().to(device)

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.25)
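Because global pruning ranks all listed weights together, the sparsity of each layer can differ even though the total is 25%. A sketch for measuring this, assuming a small two-layer stand-in model instead of LeNet5:

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 10))
parameters_to_prune = [(model[0], "weight"), (model[2], "weight")]

prune.global_unstructured(parameters_to_prune,
                          pruning_method=prune.L1Unstructured,
                          amount=0.25)

total, zeros = 0, 0
for module, name in parameters_to_prune:
    w = getattr(module, name)  # pruned weight = weight_orig * weight_mask
    n_zero = int((w == 0).sum())
    print(f"{module}: {100.0 * n_zero / w.numel():.1f}% pruned")  # differs per layer
    total += w.numel()
    zeros += n_zero
print(f"global sparsity: {100.0 * zeros / total:.1f}%")  # 592 of 2368 weights = 25.0%
```

Since L1Unstructured keeps the largest-magnitude weights, a layer whose weights are small overall ends up pruned more heavily than the global 25%.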

Custom pruning method

class MyPruningMethod(prune.BasePruningMethod):
    '''
    Prune every other entry (flattened indices 0, 2, 4, ...)
    '''
    PRUNING_TYPE = 'unstructured' # pruning type

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone() 
        mask.view(-1)[::2] = 0
        return mask

def my_unstructured(module, name):
    MyPruningMethod.apply(module, name)
    return module
model = LeNet5().to(device)
my_unstructured(model.fc1, name='bias')
Linear(in_features=400, out_features=120, bias=True)
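A quick way to check a custom method is to inspect the mask it produces. This sketch redefines MyPruningMethod so it runs on its own, applies it to the bias of a fresh nn.Linear(400, 120) (the same shape as fc1), and then applies it a second time: when pruning is repeated, PyTorch combines the methods so that the new mask is computed only over the entries that are still unpruned.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

class MyPruningMethod(prune.BasePruningMethod):
    """Prune every other entry (flattened indices 0, 2, 4, ...), as defined above."""
    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

layer = nn.Linear(400, 120)
MyPruningMethod.apply(layer, "bias")   # same as my_unstructured(layer, name='bias')
kept_once = int(layer.bias_mask.sum())

# a second application only sees the 60 entries that are still unpruned
MyPruningMethod.apply(layer, "bias")
kept_twice = int(layer.bias_mask.sum())
print(kept_once, kept_twice)  # 60 30
```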




Reposted from blog.csdn.net/FriendshipTang/article/details/128939447