TorchScript para implantação de modelo

1. Introdução ao torchscript e jit

1. Sobre o script de tocha

TorchScript é uma representação intermediária do modelo Pytorch (herdado de nn.Module).O modelo torchscript salvo pode ser executado em um ambiente de alto desempenho como C++

TorchScript é uma forma de criar modelos serializáveis ​​e otimizáveis ​​a partir do código PyTorch. Qualquer programa TorchScript pode ser salvo de um processo Python e carregado em um processo que não possui dependências Python.

Simplificando, o TorchScript pode converter um gráfico dinâmico em um gráfico estático. Sob o recurso de gráfico dinâmico flexível do pytorch, o torchscript fornece uma ferramenta que ainda pode obter a estrutura do modelo (definição do modelo).

2. Sobre tocha.jit

O que é JIT?
Em primeiro lugar, devemos saber que JIT é um conceito. O nome completo é Just In Time Compilation, que é traduzido como "Just In Time Compilation" em chinês. É um método de otimização de programa. Um cenário de uso comum é "regular expressão". Por exemplo, usando expressões regulares em Python:

prog = re.compile(pattern)
result = prog.match(string)
#或
result = re.match(pattern, string)

Os dois exemplos acima foram extraídos diretamente da documentação oficial do Python e pode ser visto na documentação que os dois métodos de escrita são "equivalentes" em termos de resultados. Mas preste atenção na primeira forma de escrever, a expressão regular será compilada primeiro e depois usada. Se você continuar lendo a documentação do Python, poderá encontrar a seguinte passagem:


usar re.compile() e salvar o objeto de expressão regular resultante para reutilização é mais eficiente quando a expressão será usada várias vezes em um único programa. E esse processo de compilação pode ser entendido como JIT (compilação just-in-time).

PyTorch é conhecido por sua "facilidade de uso" desde o seu lançamento e é mais adequado para o desenvolvimento de Python nativo, graças à estrutura de "gráfico dinâmico" do PyTorch. Podemos adicionar qualquer instrução de controle de processo Python antes do modelo PyTorch, e não haverá problema em passar pelo ponto de interrupção, mas se for TensorFlow, você precisará usar o processo desenvolvido pelo TensorFlow, como o controle tf.cond. Os modelos de gráficos dinâmicos trocam alguns recursos avançados pela facilidade de uso.

Vantagens dos JITs:

1. Implantação do modelo
Os dois novos recursos principais lançados pelo PyTorch versão 1.0 são JIT e API C++. Não é absurdo lançar esses dois recursos juntos. JIT é uma ponte entre Python e C++. Podemos usar Python para treinar o modelo e então O modelo é convertido em um módulo independente de linguagem por meio do JIT, para que o C++ possa ser chamado de maneira muito conveniente.A partir de então, "usar Python para treinar o modelo e usar C++ para implantar o modelo no ambiente de produção" tornou-se fácil tarefa para PyTorch. E por causa do uso de C++, agora podemos implantar modelos PyTorch em praticamente qualquer plataforma e dispositivo: Raspberry Pi, iOS, Android, etc...

  1. melhoria de desempenho

Por se tratar de um recurso fornecido para implantação e produção, é inevitável que uma grande otimização tenha sido feita no desempenho. Se a cena inferida tiver requisitos de alto desempenho, você pode considerar a conversão do modelo (torch.nn.Module) para o Módulo TorchScript, e então prossiga para inferir.

  1. visualização do modelo

TensorFlow ou Keras é muito amigável para ferramentas de visualização de modelo (TensorBoard, etc.), porque é um modelo de programação de gráfico estático. Depois que o modelo é definido, a estrutura e a lógica direta de todo o modelo já estão claras, mas o próprio PyTorch não apoie-o, então os modelos PyTorch sempre foram ruins na visualização, mas o JIT melhora a situação. Agora você pode usar a função trace do JIT para obter a lógica direta do modelo PyTorch para uma determinada entrada e pode obter a estrutura aproximada do modelo por meio da lógica direta. (Mas se forwardhouver muitas instruções de controle condicional no método, este ainda não é um bom método)

3. Duas maneiras de gerar o Módulo TorchScript

1. Scripts

Você pode usar a linguagem TorchScript diretamente para definir um módulo PyTorch JIT e, em seguida, usar torch.jit.script para convertê-lo em um módulo TorchScript e salvá-lo como um arquivo. A própria linguagem TorchScript também é um código Python, portanto pode ser escrita diretamente em um arquivo Python.

Usar a linguagem TorchScript é como usar o TensorFlow, você precisa definir um gráfico completo de antemão. Para TensorFlow, sabemos que não podemos usar diretamente if e outras instruções em Python para controle condicional, mas precisamos usar tf.cond, mas para TorchScript ainda podemos usar diretamente instruções de controle condicional, como if e for, mesmo em gráficos estáticos , o PyTorch ainda adere ao recurso "fácil de usar". A linguagem TorchScript é um subconjunto de tipo estaticamente do Python, que também é implementado usando o módulo de digitação do Python 3, portanto, a experiência de escrever a linguagem TorchScript é exatamente a mesma do Python, exceto que alguns recursos do Python não podem ser usados ​​​​(porque é um subconjunto), que pode ser passado na Referência da linguagem TorchScript para visualizar as semelhanças e diferenças com o Python nativo.

Em teoria, o Módulo TorchScript definido pelo Scripting é muito amigável para modelar ferramentas de visualização, pois toda a estrutura do gráfico foi definida previamente.

  1. Rastreamento

Uma maneira mais fácil de usar o Módulo TorchScript é usar o Tracing, que pode converter diretamente o modelo PyTorch (torch.nn.Module) em um Módulo TorchScript. “Tracking”, como o nome sugere, é fornecer uma “entrada” para deixar o modelo avançar novamente, de forma a obter a estrutura do gráfico através do caminho de fluxo da entrada. Este método é muito prático para modelos com lógica direta simples, mas se o próprio encaminhamento contém muitas instruções de controle de fluxo, pode haver problemas, porque a mesma entrada não pode atravessar todos os ramos lógicos.

2. Gere um modelo de tocha para raciocínio

1. Carregue o modelo de checkpointer da tocha exportado

Carregar arquivos de configuração de modelo pré-treinados e estrutura de modelo reescrita

# 【multitask_classify_ner 多任务分类模型代码(包括classify任务和ner任务)】
class BertFourLevelArea(BertPreTrainedModel):
    """BERT model for four level area.
    """
    def __init__(self, config, num_labels_cls, num_labels_ner, inner_dim, RoPE):
        super(BertFourLevelArea, self).__init__(config, num_labels_cls, num_labels_ner, inner_dim, RoPE)
        self.bert = BertModel(config)
        self.num_labels_cls = num_labels_cls
        self.num_labels_ner = num_labels_ner
        self.inner_dim = inner_dim
        self.hidden_size = config.hidden_size
        self.dense_ner = nn.Linear(self.hidden_size, self.num_labels_ner * self.inner_dim * 2)
        self.dense_cls = nn.Linear(self.hidden_size, num_labels_cls)
        self.RoPE = RoPE
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.apply(self.init_bert_weights)

    def sinusoidal_position_embedding(self, batch_size, seq_len, output_dim):
        position_ids = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(-1)

        indices = torch.arange(0, output_dim // 2, dtype=torch.float)
        indices = torch.pow(10000, -2 * indices / output_dim)
        embeddings = position_ids * indices
        embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        embeddings = embeddings.repeat((batch_size, *([1]*len(embeddings.shape))))
        embeddings = torch.reshape(embeddings, (batch_size, seq_len, output_dim))
        embeddings = embeddings.to(self.device)
        return embeddings

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        # sequence_output: Last Encoder Layer.shape: (batch_size, seq_len, hidden_size)
        encoded_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
        sequence_output = encoded_layers[-1]

        batch_size = sequence_output.size()[0]
        seq_len = sequence_output.size()[1]

        # 【Bert Ner GlobalPointer】:
        # outputs: (batch_size, seq_len, num_labels_ner*inner_dim*2)
        outputs = self.dense_ner(sequence_output)
        # outputs: (batch_size, seq_len, num_labels_ner, inner_dim*2)
        outputs = torch.split(outputs, self.inner_dim * 2, dim=-1)      # TODO:1
        outputs = torch.stack(outputs, dim=-2)              # TODO:2

        # qw,kw: (batch_size, seq_len, num_labels_ner, inner_dim)
        qw, kw = outputs[...,:self.inner_dim], outputs[...,self.inner_dim:] # TODO:3

        if self.RoPE:
            # pos_emb:(batch_size, seq_len, inner_dim)
            pos_emb = self.sinusoidal_position_embedding(batch_size, seq_len, self.inner_dim)
            # cos_pos,sin_pos: (batch_size, seq_len, 1, inner_dim)
            cos_pos = pos_emb[..., None, 1::2].repeat_interleave(2, dim=-1)
            sin_pos = pos_emb[..., None,::2].repeat_interleave(2, dim=-1)
            qw2 = torch.stack([-qw[..., 1::2], qw[...,::2]], -1)
            qw2 = qw2.reshape(qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[...,::2]], -1)
            kw2 = kw2.reshape(kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos

        # logits_ner:(batch_size, num_labels_ner, seq_len, seq_len)
        logits_ner = torch.einsum('bmhd,bnhd->bhmn', qw, kw)    # TODO:4

        # padding mask
        pad_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, self.num_labels_ner, seq_len, seq_len)   # TODO:5
        # pad_mask_h = attention_mask.unsqueeze(1).unsqueeze(-1).expand(batch_size, self.num_labels_ner, seq_len, seq_len)
        # pad_mask = pad_mask_v&pad_mask_h
        logits_ner = logits_ner*pad_mask - (1-pad_mask)*1e12    # TODO:6

        # 排除下三角
        mask = torch.tril(torch.ones_like(logits_ner), -1)  # TODO:7
        logits_ner = logits_ner - mask * 1e12   # TODO:8

        # 【Bert Classify】:
        pooled_output = self.dropout(pooled_output)
        logits_cls = self.dense_cls(pooled_output)

        return logits_cls, logits_ner



#【加载预训练模型参数】
config = modeling.BertConfig.from_json_file('/root/ljh/space-based/Deep_Learning/Pytorch/multitask_classify_ner/pretrain_model/bert-base-chinese/config.json')

#【加载我们训练的模型  】
#【num_labels_cls 和 num_labels_ner为我们训练的label_counts 这次训练的分类任务标签数为1524 ,NER任务的分类数为13】
num_labels_cls = 1524
num_labels_ner = 13
model = modeling.BertFourLevelArea(
    config,
    num_labels_cls=num_labels_cls,
    num_labels_ner=num_labels_ner,
    inner_dim=64,
    RoPE=False
)

2. Carregar parâmetros do modelo do modelo

Carregue os parâmetros do modelo treinado

#【训练完成的tourch 模型地址】
init_checkpoint='/root/ljh/space-based/Deep_Learning/Pytorch/multitask_classify_ner/outputs.bak/multitask_classify_ner/pytorch_model.bin'
#【载入模型】
checkpoint = torch.load(init_checkpoint, map_location=torch.device("cuda"))
checkpoint = checkpoint["model"] if "model" in checkpoint.keys() else checkpoint
model.load_state_dict(checkpoint)
device = torch.device("cuda")

#【将模型导入GPU】
model = model.to(device)
#【模型初始化】
model.eval()

3. Construir parâmetros de rastreamento torch.jit

Na lógica de propagação direta de nosso modelo multitarefa de limpeza de endereço, não há estruturas de condições de julgamento múltiplas, então escolhemos a forma de rastreamento (Tracing) para registrar o processo de propagação direta e a estrutura do modelo.

#【定义tokenizer 】
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained('/root/ljh/space-based/Deep_Learning/Pytorch/multitask_classify_ner/pretrain_model/bert-base-chinese', add_special_tokens=True, do_lower_case=False)


input_str='上海上海市青浦区华隆路E通世界华新园'
max_seq_length=64


#【生成bert模型输出】
def input2feature(input_str, max_seq_length=48):
    # 预处理字符
    tokens_a = tokenizer.tokenize(input_str)
    # 如果超过长度限制,则进行截断
    if len(input_str) > max_seq_length - 2:
        tokens_a = tokens_a[0:(max_seq_length - 2)]
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_length = len(input_ids)
    input_mask = [1] * input_length
    segment_ids = [0] * input_length
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
    return input_ids, input_mask, segment_ids

#【输入地址token化     input_ids --> list()】
input_ids, input_mask, segment_ids = input2feature(input_str, max_seq_length)


# 【list -> tensor】
input_ids = torch.tensor(input_ids, dtype=torch.long)
input_mask = torch.tensor(input_mask, dtype=torch.long)
segment_ids = torch.tensor(segment_ids, dtype=torch.long)

#【这里stack 是因为模型内部定义的输出参数需要stack 】
input_ids = torch.stack([input_ids], dim=0)
input_mask = torch.stack([input_mask], dim=0)
segment_ids = torch.stack([segment_ids], dim=0)


#【将参数推送至cuda设备中】
device = torch.device("cuda")
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)

#【input_ids.shape --> torch.Size([1, 64])】

4. Use torch.jit para exportar o modelo do módulo TorchScript

jit usa a forma de rastreamento (Tracing) para registrar o processo de propagação direta e a estrutura do modelo

#【根据输出的input_ids, input_mask, segment_id记录前向传播过程】
script_model = torch.jit.trace(model,[input_ids, input_mask, segment_ids],strict=True)
#【保存】
torch.jit.save(script_model, "./multitask_test/multitask_model/1/model.pt")

5. Verifique se o módulo TorchScript está correto

#【查看torch模型结果】
cls_res, ner_res = model(input_ids, input_mask, segment_ids)

import numpy as np
np.argmax(cls_res.detach().cpu().numpy()) 
#【result:673】


#【load torchscript model】
jit_model = torch.jit.load('./multitask_test/multitask_model/1/model.pt')
example_outputs_cls,example_outputs_ner = jit_model(input_ids, input_mask, segment_ids)
np.argmax(example_outputs_cls.detach().cpu().numpy()) 
#【result:673】

3. Use o servidor triton para iniciar o modelo do modelo torchscript

1. Modifique o arquivo de configuração config.ptxx

name: "multitask_model"
platform: "pytorch_libtorch"
max_batch_size: 8
input [
  {
    
    
    name: "input_ids"
    data_type: TYPE_INT64
    dims:  64
  },
  {
    
    
    name: "segment_ids"
    data_type: TYPE_INT64
    dims:  64 
  },
  {
    
    
    name: "input_mask"
    data_type: TYPE_INT64
    dims:  64
  }
]
output [
  {
    
    
    name: "cls_logits"
    data_type: TYPE_FP32
    dims: [1, 1524]
  },
  {
    
    
    name: "ner_logits"
    data_type: TYPE_FP32
    dims: [ -1, 13, 64, 64 ]
   }
]


dynamic_batching {
    
    
    preferred_batch_size: [ 1, 2, 4, 8 ]
    max_queue_delay_microseconds: 50
  }

instance_group [
{
    
    
    count: 1
    kind: KIND_GPU
    gpus: [0]
}
]

2. Modelo de estrutura de diretório

multitask_test
└── multitask_model
├── 1
│ └── model.pt
├── config.pbtxt
└── label_cls.csv

3. Inicie o tritão

tritonserver --model-store=/root/ljh/space-based/Deep_Learning/Pytorch/multitask_classify_ner/multitask_test  --strict-model-config=false --exit-on-error=false

4. Implantação do modelo Torchscript a ser resolvida

No processo de tentar usar a implantação de múltiplas placas do torchscript, o modelo do modelo está vinculado ao cuda. ​​Depois de iniciar o modelo com triton, apenas o dispositivo cuda vinculado ao modelo pode ser executado normalmente.

Referência de problema semelhante: https://github.com/triton-inference-server/server/issues/2626

Guess you like

Origin blog.csdn.net/TFATS/article/details/129706241