From 743e7fad2f03ab0082d40c5da960aca29daed76e Mon Sep 17 00:00:00 2001
From: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
Date: Thu, 7 Mar 2024 14:58:56 +0800
Subject: [PATCH] [colossal-llama2] add stream chat example for chat version model (#5428)

* add stream chat for chat version

* remove os.system clear

* modify function name
---
 .../utils/stream_chat_patch.py              | 247 ++++++++++++++++++
 .../Colossal-LLaMA-2/stream_chat_example.py |  55 ++++
 2 files changed, 302 insertions(+)
 create mode 100644 applications/Colossal-LLaMA-2/colossal_llama2/utils/stream_chat_patch.py
 create mode 100644 applications/Colossal-LLaMA-2/stream_chat_example.py

diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/stream_chat_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/stream_chat_patch.py
new file mode 100644
index 000000000..8f8eecb18
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/stream_chat_patch.py
@@ -0,0 +1,247 @@
+from copy import deepcopy
+from typing import Optional, List, Dict, Tuple, Callable, Any
+
+import torch
+from torch import nn
+
+from transformers import PreTrainedTokenizer
+from transformers.utils import logging
+from transformers.generation.utils import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+
+logger = logging.get_logger(__name__)
+
+
+def get_prompt_template(
+    input_query: str,
+    history: List[Dict] = None,
+    roles: list = ["", "Human", "Assistant"],
+) -> str:
+    """
+    Generates a prompt template for chat models based on input and history.
+
+    Args:
+        input_query (str): User's current input query.
+        history (List[Dict], optional): List of past conversation turns, each a dict with 'role' and 'message'.
+        roles (list): Specifies the roles in the conversation, defaults to ["", "Human", "Assistant"].
+
+    Returns:
+        str: A formatted prompt including the input query and history.
+    """
+    prompt = ""
+    if history is None:
+        new_history = []
+    else:
+        new_history = deepcopy(history)
+
+    new_history.append({"role": roles[1], "message": input_query.strip()})
+    new_history.append({"role": roles[2], "message": None})
+
+    for item in new_history:
+        role = item.get("role")
+        message = item.get("message")
+        if role == roles[0]:
+            prompt += f"{message}\n\n"
+        else:
+            if message:
+                prompt += f"{role}: {message}"
+            else:
+                prompt += f"{role}: "
+    return prompt
+
+@torch.inference_mode()
+def streaming_chat(
+    model: Any,
+    tokenizer: PreTrainedTokenizer,
+    input_query: str,
+    history: List[Dict] = None,
+    roles: list = ["", "Human", "Assistant"],
+    past_key_values: Tuple[Tuple[torch.FloatTensor, Any], Any] = None,
+    temperature: float = 0.8,
+    top_p: float = 0.95,
+    top_k: int = 50,
+    do_sample: bool = True,
+    length_penalty: float = 1.2,
+    max_new_tokens: int = 512,
+    logits_processor: LogitsProcessorList = None,
+    return_past_key_values: bool = False,
+    **kwargs,
+):
+    """
+    Generates streaming chat responses with a given model and tokenizer.
+
+    Args:
+        model (Any): The language model to generate responses.
+        tokenizer (PreTrainedTokenizer): Tokenizer compatible with the model, used for encoding inputs and decoding responses.
+        input_query (str): The current user input to respond to.
+        history (List[Dict], optional): A list of past conversation turns, where each turn is a dictionary with keys 'role' and 'message'.
+        roles (list): Roles involved in the conversation, defaults to ["", "Human", "Assistant"].
+        past_key_values (Tuple[Tuple[torch.FloatTensor, Any], Any], optional): Past key values for incremental decoding.
+        temperature (float): The temperature value for token sampling, defaults to 0.8.
+        top_p (float): Nucleus sampling probability threshold, defaults to 0.95.
+        top_k (int): Top-K filtering threshold, defaults to 50.
+        do_sample (bool): Whether to sample responses, defaults to True.
+        length_penalty (float): Penalty for response length, defaults to 1.2.
+        max_new_tokens (int): Maximum number of new tokens to generate, defaults to 512.
+        logits_processor (LogitsProcessorList, optional): Custom logits processors, defaults to None.
+        return_past_key_values (bool): Whether to return past key values for further incremental decoding, defaults to False.
+        **kwargs: Additional keyword arguments for generation.
+
+    Yields:
+        Tuple[str, List[Dict], Optional[Tuple[Tuple[torch.FloatTensor, Any], Any]]]: A tuple containing the generated response, updated history, and
+            optionally the updated past key values if `return_past_key_values` is True.
+
+    Note: the tokenizer must be configured with left padding (`tokenizer.padding_side == "left"`).
+    """
+    assert tokenizer.padding_side == "left", "Current generation only supports left padding."
+    if history is None:
+        history = []
+    if logits_processor is None:
+        logits_processor = LogitsProcessorList()
+
+    generation_kwargs = {
+        "temperature": temperature,
+        "top_p": top_p,
+        "top_k": top_k,
+        "do_sample": do_sample,
+        "max_new_tokens": max_new_tokens,
+        "length_penalty": length_penalty,
+        "use_cache": True,
+        **kwargs
+    }
+
+    prompt_str = get_prompt_template(input_query, history=history, roles=roles)
+
+    eos_token_id = [tokenizer.eos_token_id]
+    inputs = tokenizer(prompt_str, return_tensors="pt").to(model.device)
+    history.append({"role": roles[1], "message": input_query.strip()})
+    history.append({"role": roles[2], "message": None})
+
+    for outputs in stream_generate(model, **inputs, past_key_values=past_key_values,
+                                   eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
+                                   logits_processor=logits_processor, **generation_kwargs):
+        if return_past_key_values:
+            outputs, past_key_values = outputs
+
+        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
+        response = tokenizer.decode(outputs)
+
+        history[-1]["message"] = response.strip()
+        if return_past_key_values:
+            yield response, history, past_key_values
+        else:
+            yield response, history
+
+
+@torch.inference_mode()
+def stream_generate(
+    model: Any,
+    input_ids: torch.Tensor,
+    generation_config: Optional[GenerationConfig] = None,
+    logits_processor: Optional[LogitsProcessorList] = None,
+    stopping_criteria: Optional[StoppingCriteriaList] = None,
+    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+    return_past_key_values: bool = False,
+    **kwargs,
+):
+    """
+    Generates sequences of token ids using the specified model and generation parameters.
+    Adapted from https://huggingface.co/THUDM/chatglm3-6b/blob/main/modeling_chatglm.py
+
+    Args:
+        model (Any): The model used for generating sequences of token ids.
+        input_ids (torch.Tensor): The sequence used as a prompt for the generation or as model inputs to the encoder.
+        generation_config (Optional[GenerationConfig]): The generation configuration to be used as base parametrization for the generation call.
+        logits_processor (Optional[LogitsProcessorList]): Custom logits processors that complement the default logits processors built from arguments
+            and generation config.
+        stopping_criteria (Optional[StoppingCriteriaList]): Custom stopping criteria that complement the default stopping criteria built from arguments
+            and a generation config.
+        prefix_allowed_tokens_fn (Optional[Callable[[int, torch.Tensor], List[int]]]): Function to constrain token generation.
+        return_past_key_values (bool): Whether to return past key values for further incremental decoding, defaults to False.
+        **kwargs: Additional parameters for model generation.
+
+    Yields:
+        torch.Tensor: The generated token IDs, updated after each generation step.
+        Optional[Tuple[Tuple[torch.FloatTensor, Any], Any]]: The past key values, returned if `return_past_key_values` is True.
+    """
+    input_ids_len = input_ids.size(1)
+
+    if generation_config is None:
+        generation_config = model.generation_config
+    generation_config = deepcopy(generation_config)
+    model_kwargs = generation_config.update(**kwargs)
+
+    eos_token_id = generation_config.eos_token_id
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
+    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
+
+    if generation_config.max_new_tokens is not None:
+        generation_config.max_length = generation_config.max_new_tokens + input_ids_len
+
+    if input_ids_len >= generation_config.max_length:
+        input_ids_string = "decoder_input_ids" if model.config.is_encoder_decoder else "input_ids"
+        logger.warning(
+            f"Input length of {input_ids_string} is {input_ids_len}, but `max_length` is set to"
+            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+            " increasing `max_new_tokens`."
+        )
+    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+    # prepare distribution pre_processing samplers
+    logits_processor = model._get_logits_processor(
+        generation_config=generation_config,
+        input_ids_seq_length=input_ids_len,
+        encoder_input_ids=input_ids,
+        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+        logits_processor=logits_processor,
+    )
+
+    # prepare stopping criteria
+    stopping_criteria = model._get_stopping_criteria(
+        generation_config=generation_config, stopping_criteria=stopping_criteria
+    )
+
+    logits_warper = model._get_logits_warper(generation_config)
+    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+    scores = None
+
+    while True:
+        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
+        # forward pass to get next token
+        outputs = model(
+            **model_inputs,
+            return_dict=True,
+            output_attentions=False,
+            output_hidden_states=False,
+        )
+
+        # NOTE: this is correct only in left padding mode
+        # pre-process distribution
+        next_token_logits = outputs.logits[:, -1, :]
+        next_token_scores = logits_processor(input_ids, next_token_logits)
+        next_token_scores = logits_warper(input_ids, next_token_scores)
+
+        # sample
+        probs = nn.functional.softmax(next_token_scores, dim=-1)
+        if generation_config.do_sample:
+            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+        else:
+            next_tokens = torch.argmax(probs, dim=-1)
+
+        # update generated ids, model inputs, and length for next step
+        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+        model_kwargs = model._update_model_kwargs_for_generation(
+            outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
+        )
+        unfinished_sequences = unfinished_sequences.mul(
+            next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+        )
+
+        if return_past_key_values:
+            yield input_ids, outputs.past_key_values
+        else:
+            yield input_ids
+        # stop when every sequence is finished or the maximum length is exceeded
+        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+            break
\ No newline at end of file
diff --git a/applications/Colossal-LLaMA-2/stream_chat_example.py b/applications/Colossal-LLaMA-2/stream_chat_example.py
new file mode 100644
index 000000000..3e45c690f
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/stream_chat_example.py
@@ -0,0 +1,55 @@
+import os
+import argparse
+
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from colossal_llama2.utils.stream_chat_patch import streaming_chat
+
+SYSTEM = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
+
+def main(args):
+    model = AutoModelForCausalLM.from_pretrained(args.model_path).cuda().eval()
+    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
+
+    past_key_values, history = None, []
+    roles = ["", "Human", "Assistant"]
+
+    # seed the conversation with the system prompt
+    history.append({"role": roles[0], "message": SYSTEM})
+
+    while True:
+        input_query = input(f"\n{roles[1]}: ")
+        if input_query.strip() == "exit":
+            break
+        if input_query.strip() == "clear":
+            past_key_values, history = None, [{"role": roles[0], "message": SYSTEM}]
+            continue
+
+        print(f"\n{roles[2]}: ", end="")
+        gen_len = 0
+        for response, history, past_key_values in streaming_chat(
+            model, tokenizer, input_query, history=history, roles=roles,
+            temperature=args.temperature,
+            top_p=args.top_p,
+            top_k=args.top_k,
+            do_sample=args.do_sample,
+            length_penalty=args.length_penalty,
+            max_new_tokens=args.max_new_tokens,
+            past_key_values=past_key_values,
+            return_past_key_values=True):
+
+            output = response[gen_len:]
+            print(output, end="", flush=True)
+            gen_len = len(response)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model_path', type=str, default=None, help="path to the chat version model")
+    parser.add_argument('--tokenizer_path', type=str, default=None, help="path to the chat version tokenizer")
+    parser.add_argument('--temperature', type=float, default=0.8, help="set temperature")
+    parser.add_argument('--top_p', type=float, default=0.95, help="set top p value")
+    parser.add_argument('--top_k', type=int, default=50, help="set top k value")
+    parser.add_argument('--do_sample', type=lambda x: str(x).lower() in ("true", "1"), default=True, help="whether to sample (True/False)")
+    parser.add_argument('--length_penalty', type=float, default=1.2, help="set length penalty")
+    parser.add_argument('--max_new_tokens', type=int, default=512, help="set max new tokens")
+    args = parser.parse_args()
+    main(args)
\ No newline at end of file
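
Usage sketch for the new streaming_chat helper, minimal and non-interactive (assumptions: a local Colossal-LLaMA-2 chat checkpoint at a placeholder path, a CUDA device, and a hypothetical system message; the tokenizer is set to left padding, which streaming_chat asserts):

from transformers import AutoModelForCausalLM, AutoTokenizer
from colossal_llama2.utils.stream_chat_patch import streaming_chat

MODEL_PATH = "path/to/colossal-llama2-chat"  # placeholder: point this at your own chat checkpoint

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
tokenizer.padding_side = "left"  # streaming_chat only supports left padding
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).cuda().eval()

# Seed the history with the system role ("") so get_prompt_template emits the system prompt first.
history = [{"role": "", "message": "You are a helpful assistant."}]

printed = 0
for response, history in streaming_chat(model, tokenizer, "Hello!", history=history):
    # Each yield carries the full response so far; print only the newly generated suffix.
    print(response[printed:], end="", flush=True)
    printed = len(response)
print()

The bundled stream_chat_example.py wraps the same loop in an interactive prompt and passes past_key_values back in with return_past_key_values=True so the KV cache is reused across turns.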