Documentación API AIMET (2)


1.1.3 API del preparador de modelos

La API AIMET PyTorch ModelPreparer utiliza las nuevas capacidades de transformación de gráficos disponibles en PyTorch 1.9+ y realiza automáticamente los cambios de definición del modelo requeridos por el usuario. Por ejemplo, cambia las funciones definidas en el pase directo a módulos tipo torch.nn.Module para activación y funciones de elementos. Además, cuando se reutiliza un módulo de tipo torch.nn.Module, se expande a un módulo independiente.

Se recomienda encarecidamente a los usuarios que utilicen primero la API AIMET PyTorch ModelPreparer y luego utilicen el modelo devuelto como entrada para todas las funciones de cuantificación de AIMET.

La API AIMET PyTorch ModelPreparer requiere al menos la versión 1.9 de PyTorch.

1.1.3.1 API de nivel superior

aimet_torch.model_preparer.prepare_model (modelo, module_to_exclude=Ninguno, module_classes_to_exclude=Ninguno, concrete_args=Ninguno)[fuente]

Prepare y modifique modelos de pytorch para funciones AIMET utilizando la API de seguimiento de símbolos torch.FX.

  1. Reemplace torch.nn.function con un módulo de tipo torch.nn.Module
  2. Cree una nueva instancia independiente de torch.nn.Module para reutilizar/duplicar el módulo

parámetro:

  • modelo (Módulo) : el modelo de pytorch que se va a modificar.
  • module_to_exclude (Opcional[List[Module]]) : lista de módulos que se excluirán al realizar el seguimiento.
  • module_classes_to_exclude (Opcional[List[Callable]]) : lista de clases de módulos que se excluirán durante el seguimiento.
  • concrete_args (Opcional[Dict[str, Any]]) : le permite especializar parcialmente sus funciones, ya sea eliminando el flujo de control o las estructuras de datos. torch.fx no podrá rastrear el modelo si tiene flujo de control. Consulte la API torch.fx.symbolic_trace en detalle.

Tipo de devolución : GraphModule
Devoluciones : modelo de pytorch modificado

1.1.3.2 Ejemplos de código

Importaciones requeridas

import torch
import torch.nn.functional as F
from aimet_torch.model_preparer import prepare_model

Ejemplo 1: modelo con reLU funcional

Comenzamos con el modelo siguiente, que contiene dos funciones relus y métodos relu en el método forward.

class ModelWithFunctionalReLU(torch.nn.Module):
    """ Model that uses functional ReLU instead of nn.Modules. Expects input of shape (1, 3, 32, 32) """
    def __init__(self):
        super(ModelWithFunctionalReLU, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x).relu()
        return x

Ejecute la API del preparador de modelos en un modelo pasando el modelo.

def model_preparer_functional_example():

    # Load the model and keep in eval() mode
    model = ModelWithFunctionalReLU().eval()
    input_shape = (1, 3, 32, 32)
    input_tensor = torch.randn(*input_shape)

    # Call to prepare_model API
    prepared_model = prepare_model(model)
    print(prepared_model)

    # Compare the outputs of original and transformed model
    assert torch.allclose(model(input_tensor), prepared_model(input_tensor))

Después de eso, obtenemos el modelo_preparado, que es funcionalmente igual que el modelo original. Los usuarios pueden verificar esto comparando el resultado de los dos modelos.

prepare_model ahora debería convertir las tres funciones relus en módulos torch.nn.ReLU que satisfagan las pautas del modelo descritas en Pautas del modelo .

Ejemplo 2: Modelo con módulo torch.nn.ReLU reutilizado

Comenzamos con el siguiente modelo, que contiene el módulo torch.nn.ReLU, que se utiliza en múltiples instancias dentro de la función de avance del modelo.

class ModelWithReusedReLU(torch.nn.Module):
    """ Model that uses single ReLU instances multiple times in the forward. Expects input of shape (1, 3, 32, 32) """
    def __init__(self):
        super(ModelWithReusedReLU, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.relu = torch.nn.ReLU()
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        return x

Ejecute la API del preparador de modelos en un modelo pasando el modelo.

def model_preparer_reused_example():

    # Load the model and keep in eval() mode
    model = ModelWithReusedReLU().eval()
    input_shape = (1, 3, 32, 32)
    input_tensor = torch.randn(*input_shape)

    # Call to prepare_model API
    prepared_model = prepare_model(model)
    print(prepared_model)

    # Compare the outputs of original and transformed model
    assert torch.allclose(model(input_tensor), prepared_model(input_tensor))

Después de eso, obtenemos el modelo_preparado, que es funcionalmente igual que el modelo original. Los usuarios pueden verificar esto comparando el resultado de los dos modelos.

prepare_model debe tener una única instancia torch.nn.Module independiente que cumpla con las pautas del modelo descritas en Pautas del modelo.

Ejemplo 3: modelo con suma de elementos

Comenzamos con el siguiente modelo, que contiene operaciones de suma de elementos dentro de la función directa del modelo.

class ModelWithElementwiseAddOp(torch.nn.Module):
    def __init__(self):
        super(ModelWithElementwiseAddOp, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5, bias=False)
        self.conv2 = torch.nn.Conv2d(3, 6, 5)

    def forward(self, *inputs):
        x1 = self.conv1(inputs[0])
        x2 = self.conv2(inputs[1])
        x = x1 + x2
        return x

Ejecute la API del preparador de modelos en un modelo pasando el modelo.

def model_preparer_elementwise_add_example():

    # Load the model and keep in eval() mode
    model = ModelWithElementwiseAddOp().eval()
    input_shape = (1, 3, 32, 32)
    input_tensor = [torch.randn(*input_shape), torch.randn(*input_shape)]

    # Call to prepare_model API
    prepared_model = prepare_model(model)
    print(prepared_model)

    # Compare the outputs of original and transformed model
    assert torch.allclose(model(*input_tensor), prepared_model(*input_tensor))

Después de eso, obtenemos el modelo_preparado, que es funcionalmente igual que el modelo original. Los usuarios pueden verificar esto comparando el resultado de los dos modelos.

1.1.3.3 Limitaciones de la API de seguimiento de símbolos de torch.fx

Limitaciones del rastreo simbólico de torch.fx: https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing

  1. Los bucles torch.fx o las declaraciones if-else no admiten el flujo de control dinámico , donde las condiciones pueden depender de ciertos valores de entrada. Solo puede rastrear una ruta de ejecución; todas las demás ramas sin seguimiento se ignorarán. Por ejemplo, el seguimiento de la siguiente función simple fallará con un TraceError que indique "Las variables de seguimiento de símbolos no se pueden utilizar como entradas para controlar el flujo":
def f(x, flag):
    if flag:
        return x
    else:
        return x*2

torch.fx.symbolic_trace(f) # Fails!
fx.symbolic_trace(f, concrete_args={
    
    'flag': True})

Solución alternativa para este problema:

  • El flujo de control dinámico en muchas situaciones puede convertirse simplemente en un flujo de control estático, respaldado por el seguimiento de símbolos de torch.fx. El flujo de control estático es un bucle donde o una declaración if-else cuyo valor no puede cambiar entre diferentes modelos en el paso hacia adelante. Al especializar su función de reenvío pasando valores concretos a "concrete_args", puede realizar un seguimiento de dichos casos eliminando la dependencia de los datos de los valores de entrada.

  • En un flujo de control verdaderamente dinámico, el usuario debe envolver este código en un alcance a nivel de modelo usando la API torch.fx.wrap, que lo mantendrá como un nodo en lugar de rastrearlo mediante:

    @torch.fx.wrap
    def custom_function_not_to_be_traced(x, y):
        """ Function which we do not want to be traced, when traced using torch FX API, call to this function will
        be inserted as call_function, and won't be traced through """
        for i in range(2):
            x += x
            y += y
        return x * x + y * y
    
    
  1. El seguimiento de símbolos no admite funciones que no sean de Torch y que no utilicen el mecanismo torch_function de forma predeterminada.

Solución alternativa para este problema:

  • Si no queremos capturarlos en el seguimiento de símbolos, entonces los usuarios deben usar la API torch.fx.wrap() en el alcance del nivel del módulo:

    import torch
    import torch.fx
    torch.fx.wrap('len')  # call the API at module-level scope.
    torch.fx.wrap('sqrt') # call the API at module-level scope.
    
    class ModelWithNonTorchFunction(torch.nn.Module):
        def __init__(self):
            super(ModelWithNonTorchFunction, self).__init__()
            self.conv = torch.nn.Conv2d(3, 4, kernel_size=2, stride=2, padding=2, bias=False)
    
        def forward(self, *inputs):
            x = self.conv(inputs[0])
            return x / sqrt(len(x))
    
    model = ModelWithNonTorchFunction().eval()
    model_transformed = prepare_model(model)
    
    
  1. Personalice el comportamiento de seguimiento anulando la API Tracer.is_leaf_module()

En el rastreo de símbolos, los módulos de hoja aparecen como nodos en lugar de ser rastreados, y todos los módulos estándar de torch.nn son el conjunto predeterminado de módulos de hoja. Sin embargo, este comportamiento se puede cambiar anulando la API Tracer.is_leaf_module().

La API del preparador de modelos AIMET expone el parámetro "module_to_exclude", que se puede utilizar para evitar el seguimiento de ciertos módulos. Por ejemplo, revisemos el siguiente fragmento de código; no queremos rastrear más el CustomModule:

class CustomModule(torch.nn.Module):
    @staticmethod
    def forward(x):
        return x * torch.nn.functional.softplus(x).sigmoid()

class CustomModel(torch.nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=2)
        self.custom = CustomModule()

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.custom(x)
        return x

model = CustomModel().eval()
prepared_model = prepare_model(model, modules_to_exclude=[model.custom])
print(prepared_model)

En este ejemplo, "self.custom" se conserva como nodo y no se realiza un seguimiento.

  1. El constructor tensorial no es rastreable

Por ejemplo, examinemos el siguiente fragmento de código:

def f(x):
    return torch.arange(x.shape[0], device=x.device)

torch.fx.symbolic_trace(f)

Error traceback:
    return torch.arange(x.shape[0], device=x.device)
    TypeError: arange() received an invalid combination of arguments - got (Proxy, device=Attribute), but expected one of:
    * (Number end, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
    * (Number start, Number end, Number step, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)

El fragmento de código anterior es problemático porque los parámetros de torch.arange() dependen de la entrada. Solución alternativa para este problema:

  • Utilice constructores deterministas (codificados) para que los valores que producen se incrusten en el gráfico como constantes:

    def f(x):
        return torch.arange(10, device=torch.device('cpu'))
    
  • O use la API torch.fx.wrap para envolver torch.arange() y llamarlo:

@torch.fx.wrap
def do_not_trace_me(x):
    return torch.arange(x.shape[0], device=x.device)

def f(x):
    return do_not_trace_me(x)

torch.fx.symbolic_trace(f)

Supongo que te gusta

Origin blog.csdn.net/weixin_38498942/article/details/133066832
Recomendado
Clasificación