[AI] How to stream LLM output: a Hugging Face Transformers implementation

Hugging Face Transformers is a very convenient library that integrates many state-of-the-art (SOTA) models, including LLaMA, GPT, ChatGLM, MOSS, etc. Currently, most mainstream solutions are built on the Hugging Face Transformers framework. In the past, if you wanted to stream output, you had to modify the model's underlying inference logic yourself.

For example, ChatGLM implements its own streaming output as follows:

#chatglm-6bmodel/modeling_chatglm.py
@torch.no_grad()
    def stream_chat(self, tokenizer, query: str, history: Optional[List[Tuple[str, str]]] = None, max_length: int = 2048,
                    do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
        """Stream a chat reply, yielding the partial response after each generation step.

        Args:
            tokenizer: tokenizer used to encode the prompt and decode generated ids.
            query: the new user message.
            history: previous ``(query, response)`` turns; ``None`` means no history.
            max_length, do_sample, top_p, temperature: generation settings forwarded
                to ``self.stream_generate``.
            logits_processor: optional extra logits processors. The list is copied
                before ``InvalidScoreLogitsProcessor`` is appended, so the caller's
                list is not modified.
            **kwargs: further generation arguments forwarded to ``self.stream_generate``.

        Yields:
            ``(response, new_history)``. ``response`` is the post-processed text
            generated so far. ``new_history`` is ``history + [(query, response)]``.
        """
        if history is None:
            history = []
        # Copy instead of appending to the caller's list. An in-place append would add
        # another InvalidScoreLogitsProcessor on every call that reuses the same list.
        logits_processor = LogitsProcessorList(logits_processor or [])
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        # Multi-turn prompt: one "[Round i]" section per past turn, then the open question.
        if not history:
            prompt = query
        else:
            prompt = ""
            for i, (old_query, response) in enumerate(history):
                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
        inputs = tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.device)
        prompt_length = len(inputs["input_ids"][0])
        for outputs in self.stream_generate(**inputs, **gen_kwargs):
            # Keep only the newly generated ids (first and only batch row), then decode.
            outputs = outputs.tolist()[0][prompt_length:]
            response = tokenizer.decode(outputs)
            response = self.process_response(response)
            new_history = history + [(query, response)]
            yield response, new_history

    @torch.no_grad()
    def stream_generate(
            self,
            input_ids,
            generation_config: Optional[GenerationConfig] = None,
            logits_processor: Optional[LogitsProcessorList] = None,
            stopping_criteria: Optional[StoppingCriteriaList] = None,
            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
            **kwargs,
    ):
        """Generate tokens one step at a time, yielding the growing ``input_ids``.

        Args:
            input_ids: prompt token ids. ``shape[0]`` is the batch and ``shape[-1]``
                is the prompt length.
            generation_config: base generation config. It is deep-copied and updated
                with ``kwargs``. Defaults to ``self.generation_config``.
            logits_processor: extra processors merged via ``self._get_logits_processor``.
            stopping_criteria: extra criteria merged via ``self._get_stopping_criteria``.
            prefix_allowed_tokens_fn: forwarded to ``self._get_logits_processor``.
            **kwargs: generation-config overrides. Whatever
                ``generation_config.update`` returns is passed on to
                ``self.prepare_inputs_for_generation`` as model kwargs.

        Yields:
            ``input_ids`` (prompt plus all tokens generated so far) after each step.
            This includes the step on which a stopping criterion (e.g. max length)
            fires. It excludes a step on which every row has just emitted an EOS token.
        """
        input_ids_seq_length = input_ids.shape[-1]

        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)
        eos_token_id = generation_config.eos_token_id

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]

        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            if not has_default_max_length:
                # Warn before max_length is overwritten so the message shows the caller's
                # value. logger.warning treats extra positional args as %-format args,
                # so no warning category is passed here.
                logger.warning(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                )
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        # 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)

        # One flag per batch row: 1 while the row is still generating, 0 once it hit EOS.
        # NOTE(review): rows that finish early keep receiving sampled tokens (no padding);
        # only the all-finished check below uses these flags.
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        scores = None
        while True:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            if generation_config.do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(probs, dim=-1)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if eos_token_id is not None:
                # A row is finished if its new token equals ANY EOS id, so take the
                # product of the "not this id" masks (a sum would only finish a row
                # when the token matched every EOS id at once).
                for eos_id in eos_token_id:
                    unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_id).long())

            # Every sentence just emitted EOS: stop without yielding the EOS step.
            if unfinished_sequences.max() == 0:
                break
            yield input_ids
            # Check the remaining criteria (e.g. maximum length) after yielding, so the
            # token produced on the final step still reaches the caller.
            if stopping_criteria(input_ids, scores):
                break

HuggingFace Transformers implementation

Hugging Face also recognized this need and, as of v4.30.1, provides two streaming output interfaces:

  • TextStreamer: streams the results to stdout
  • TextIteratorStreamer: exposes the generated text as an iterator that you can consume in your own loop

The details are as follows

TextStreamer

Reference: Text generation strategies — https://huggingface.co/docs/transformers/main/generation_strategies

generate() supports streaming through its streamer input. The streamer input is compatible with any instance of a class that has the following methods: put() and end(). Internally, put() is used to push new tokens and end() is used to flag the end of text generation.

The API for the streamer classes is still under development and may change in the future.

In practice, you can craft your own streaming class for all sorts of purposes! We also have basic streaming classes ready for you to use. For example, you can use the TextStreamer class to stream the output of generate() into your screen, one word at a time:

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

# Load GPT-2 and its tokenizer, then encode one prompt as a PyTorch batch.
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok(
    ["An increasing sequence: one,"],
    return_tensors="pt",
)

# TextStreamer writes decoded text to stdout while generate() runs;
# generate() still returns its normal output, which is discarded here.
streamer = TextStreamer(tok)
_ = model.generate(
    **inputs,
    streamer=streamer,
    max_new_tokens=20,
)

 TextIteratorStreamer

Reference: Utilities for Generation — https://huggingface.co/docs/transformers/main/en/internal/generation_utils#transformers.TextStreamer

Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is useful for applications that benefit from accessing the generated text in a non-blocking way (e.g. in an interactive Gradio demo).

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
streamer = TextIteratorStreamer(tok)

# Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
generated_text = ""
# Each iteration yields the next chunk of decoded text; handle it here as it arrives.
for new_text in streamer:
    generated_text += new_text
# Iteration ends once generation has signalled end(); wait for the worker thread.
thread.join()
# A bare `generated_text` expression only displays in a REPL/notebook, so print it.
print(generated_text)

ChatGLM Streaming Reply Demo 

The following is a simple CLI demo that uses ChatGLM-6B with TextIteratorStreamer (a TextStreamer variant is shown commented out at the end):

import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, AutoModel
from transformers import TextIteratorStreamer
from threading import Thread

# Load the ChatGLM-6B tokenizer and model from the Hub. trust_remote_code=True lets
# Transformers run the model code shipped in that repository. The model is cast to
# fp16, moved to the GPU and put in eval mode.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = (
    AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    .half()
    .cuda()
    .eval()
)

# Build the conversation transcript shown on screen.
def build_prompt(history):
    """Return the welcome banner followed by every turn in ``history``.

    Each entry of ``history`` is unpacked as ``(query, response)``. It is rendered
    as a "用户:" line followed by a "ChatGLM-6B:" line, each preceded by a blank line.
    """
    pieces = ["欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"]
    for user_text, bot_text in history:
        pieces.append(f"\n\n用户:{user_text}")
        pieces.append(f"\n\nChatGLM-6B:{bot_text}")
    return "".join(pieces)

# Maintain the multi-turn history shown on screen.
def build_history(history, query, response, index):
    """Overwrite ``history[index]`` with ``[query, response]`` in place and return ``history``.

    For a list ``history``, an out-of-range ``index`` raises IndexError.
    """
    turn = [query, response]
    history[index] = turn
    return history

if __name__ == "__main__":
    # Interactive CLI built on TextIteratorStreamer: model.generate runs in a worker
    # thread and feeds decoded text into the streamer. This loop consumes the
    # streamer and redraws the conversation as the reply grows.
    # Terminal clear command for the host OS ("cls" on Windows, "clear" elsewhere).
    # It is used by the "clear" command and by every redraw.
    clear_command = "cls" if os.name == "nt" else "clear"
    history = []
    turn_count = 0
    while True:
        query = input("\n用户:")
        if query.strip() == "stop":
            break
        if query.strip() == "clear":
            history = []
            turn_count = 0
            os.system(clear_command)
            print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
            continue

        # Send earlier turns to the model using the same "[Round i]\n问:...\n答:..."
        # layout as stream_chat, so the reply can take the conversation so far into
        # account.
        if not history:
            prompt = query
        else:
            prompt = "".join(
                f"[Round {i}]\n问:{old_query}\n答:{old_response}\n"
                for i, (old_query, old_response) in enumerate(history)
            )
            prompt += f"[Round {len(history)}]\n问:{query}\n答:"

        history.append([query, ""])

        # Use a fresh streamer per turn. skip_prompt=True stops the streamer from
        # echoing the input prompt into the displayed reply.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
        inputs = tokenizer([prompt], return_tensors="pt").to('cuda')
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        generated_text = ""
        count = 0
        # Streamed output: accumulate chunks and redraw the screen only every 8 chunks
        # (presumably to limit flicker from repeated clears).
        for new_text in streamer:
            generated_text += new_text
            history = build_history(history, query, generated_text, turn_count)
            count += 1
            if count % 8 == 0:
                os.system(clear_command)
                print(build_prompt(history), flush=True)
        # The streamer is exhausted once generation ends; reap the worker thread.
        thread.join()
        os.system(clear_command)
        print(build_prompt(history), flush=True)
        turn_count += 1

    # Alternative using TextStreamer: it prints to stdout directly, and generate()
    # blocks until generation finishes (no thread needed):
    # streamer = TextStreamer(tokenizer)
    # _ = model.generate(**inputs, streamer=streamer, max_new_tokens=512)

Guess you like

Origin blog.csdn.net/weixin_44491772/article/details/131205174