TensorRT量化工具pytorch_quantization代码解析(一)

量化工具箱pytorch_quantization 通过提供一个方便的 PyTorch 库来补充 TensorRT ,该库有助于生成可优化的 QAT 模型。该工具包提供了一个 API 来自动或手动为 QAT 或 PTQ 准备模型。

API 的核心是 TensorQuantizer 模块,它可以量化、伪量化或收集张量的统计信息。它与 QuantDescriptor 一起使用,后者描述了如何量化张量。在 TensorQuantizer 之上的是量化模块,这些模块被设计为 PyTorch 全精度模块的替代品。这些是使用 TensorQuantizer 对模块的权重和输入进行伪量化或收集统计信息的方便模块。

API 支持将 PyTorch 模块自动转换为其量化版本。转换也可以使用 API 手动完成,这允许在不想量化所有模块的情况下进行部分量化。例如,一些层可能对量化更敏感,并且使其未量化可提高任务精度。

量化第一步是将量化器模块添加到神经网络图中。该包提供了许多量化层模块,其中包含用于输入和权重的量化器。例如quant_nn.QuantLinear,它可以用来代替nn.Linear。这些量化层可以通过猴子修补或手动修改模型定义来自动替换。自动层替换是使用quant_module完成的。这应该在创建模型之前调用。

首先看以下代码:

from pytorch_quantization import quant_modules
quant_modules.initialize()

initialize()会动态地修改 PyTorch 代码,适用于每个模块的所有实例,将 torch.nn.module 的一些子类替换为对应的量化版本。如果不希望所有模块都量化,则应手动替换量化模块。独立量化器也可以添加到带有quant_nn.TensorQuantizer的模型中。

initialize()位于:pytorch_quantization\quant_modules.py,作用使用使用monkey patching进行动态模块更换为量化版本

什么是猴子补丁

  • Python是一种典型的动态脚本语言。它不仅具有 动态类型(dynamic type) ,而且它的 对象模型(object model) 也是动态的。Python的类是可变的(mutable),方法(methods)只是类的属性(attributes);这允许我们在 运行时(run time) 修改其行为。这被称为猴子补丁(Monkey Patching), 它指的是偷偷地更改代码。
  • Monkey Patching只是在 运行时(run time) 动态替换属性(attributes)。而在Python中,术语monkey patch指的是对函数(function)、类(class)或模块(module)的动态(或运行时)修改。
def initialize(float_module_list=None, custom_quant_modules=None):
    """
    用量化版本动态地替换模块。在内部,状态由helper类对象维护,该对象有助于将原始模块替换回去。

		参数:
		float_module_list:列表,用户提供的列表,其中指明哪些模块不可执行替换
		custom_quant_modules:一个字典。用户提供的映射,用于指示除torch.nn及其相应量化版本之外的任何其他模块。
		Returns:空
	"""
    # 准备monkey patching中使用的内部变量quant_map和orginal_func_map
    _quant_module_helper_object.prepare_state(float_module_list, custom_quant_modules)
    #执行量化模块替换
    _quant_module_helper_object.apply_quant_modules()

def deactivate():
    """
    动态模块更换,可逆转monkey patching
    使用维护状态的helper类对象动态地替换回先前在initialize()函数调用中被monkey patching的原始模块。
    """
    _quant_module_helper_object.restore_float_modules()

# 维护被替换模块状态的全局对象。
_quant_module_helper_object = QuantModuleReplacementHelper()

自定义量化模块使用示例:

# torch.nn模块定义不可执行替换列表
float_module_list = ["Linear"]
# torch.nn以外的模块自定义映射
custom_quant_modules = [(torch.nn, "Linear", quant_nn.QuantLinear)]
# Monkey修补模块
pytorch_quantization.quant_modules.initialize(float_module_list, custom_modules)
# 使用量化模块
pytorch_quantization.quant_modules.deactivate()

继续看helperQuantModuleReplacementHelper

class QuantModuleReplacementHelper():
    """
    帮助量化版本替换torch.nn模块
    术语monkey patch指的是对函数(function)、类(class)或模块(module)的动态(或运行时)修改
    该模块用工具内部实现或任何其他用户提供的自定义模块提供的量化版 替换(通过monkey patching)torch.nn模块

    属性:
    orginal_func_map:一个dict.维护原始torch.nn模块字典
    quant_support_list:列表,包含工具提供的量化版本的模块名称
    quant_map:一个字典,包含模块名称及其量化版本的字典
    quant_switch_opt:一个字典,用于指示哪些模块不能替换其量化版本。该dict由用户提供的列表更新,该列表指示在monkey patching中要忽略的模块
    """
    def __init__(self):
        # 保留要更换的原始模块
        self.orginal_func_map = set()

        # 默认情况下,维护工具支持的量化模块列表
        self.default_quant_map = _DEFAULT_QUANT_MAP

        # 保存最终量化模块。
        self.quant_map = set()

_DEFAULT_QUANT_MAP是包含量化模块映射的文件的全局成员

_DEFAULT_QUANT_MAP = [_quant_entry(torch.nn, "Conv1d", quant_nn.QuantConv1d),
                      _quant_entry(torch.nn, "Conv2d", quant_nn.QuantConv2d),
                      _quant_entry(torch.nn, "Conv3d", quant_nn.QuantConv3d),
                      _quant_entry(torch.nn, "ConvTranspose1d", quant_nn.QuantConvTranspose1d),
                      _quant_entry(torch.nn, "ConvTranspose2d", quant_nn.QuantConvTranspose2d),
                      _quant_entry(torch.nn, "ConvTranspose3d", quant_nn.QuantConvTranspose3d),
                      _quant_entry(torch.nn, "Linear", quant_nn.QuantLinear),
                      _quant_entry(torch.nn, "LSTM", quant_nn.QuantLSTM),
                      _quant_entry(torch.nn, "LSTMCell", quant_nn.QuantLSTMCell),
                      _quant_entry(torch.nn, "AvgPool1d", quant_nn.QuantAvgPool1d),
                      _quant_entry(torch.nn, "AvgPool2d", quant_nn.QuantAvgPool2d),
                      _quant_entry(torch.nn, "AvgPool3d", quant_nn.QuantAvgPool3d),
                      _quant_entry(torch.nn, "AdaptiveAvgPool1d", quant_nn.QuantAdaptiveAvgPool1d),
                      _quant_entry(torch.nn, "AdaptiveAvgPool2d", quant_nn.QuantAdaptiveAvgPool2d),
                      _quant_entry(torch.nn, "AdaptiveAvgPool3d", quant_nn.QuantAdaptiveAvgPool3d),]

_quant_entry定义命名元组,用于存储量化模块映射,它拥有三个属性orig_mod mod_name replace_mod

_quant_entry = namedtuple('quant_entry', 'orig_mod mod_name replace_mod')

QuantModuleReplacementHelper类的属性方法:

  • prepare_state 准备稍后在monkey patching机制中使用的量化模块的命名字典quant_map和更换为原始模块orginal_func_map
    • 设置torch.nn工具支持的量化模块列表
    • 为torch.nn以外的模块设置自定义映射
    • 使用float_module_list关闭用户指示模块的monkey patching替换
    def prepare_state(self, float_module_list=None, custom_map=None):
        """

        """

        # 对于支持的默认量化模块,生成quant_map
        for item in self.default_quant_map:
            if float_module_list is not None and item.mod_name in float_module_list:
                # 如果float_module_list中存在此模块,则跳过此模块
                continue
            else:
                # 将模块追加到将在monkey patching中使用的变量中
                self.quant_map.add(item)
                # 存储要在反向monkey patching中使用的原始模块
                self.orginal_func_map.add(_quant_entry(item.orig_mod, item.mod_name,
                                                       getattr(item.orig_mod, item.mod_name)))

        # 将自定义模块添加到quant_map
        if custom_map is not None:
            for item in custom_map:
                # 将自定义模块附加到将在monkey补丁中使用的列表中
                # 将元组转换为命名元组
                self.quant_map.add(_quant_entry(item[0], item[1], item[2]))
                # 将原始模块存储在另一个列表中,该列表将用于反向monkey patching
                self.orginal_func_map.add(_quant_entry(item[0], item[1], getattr(item[0], item[1])))
  • apply_quant_modules:根据quant_map,执行替换为量化模块
    def apply_quant_modules(self):
        for entry in self.quant_map:
            # 用于设置属性值,该属性不一定是存在的,对应函数 getattr()
            setattr(entry.orig_mod, entry.mod_name, entry.replace_mod)
  • restore_float_modules:通过使用orginal_func_map替换回原始模块,反转monkey patch的效果
    def restore_float_modules(self):
        for entry in self.orginal_func_map:
            setattr(entry.orig_mod, entry.mod_name, entry.replace_mod)

猜你喜欢

转载自blog.csdn.net/weixin_42905141/article/details/129263529