From 4ffd054a2ac3148edc952d69c9d93c10be2a538a Mon Sep 17 00:00:00 2001
From: FangYin Cheng
Date: Sat, 18 Nov 2023 17:15:26 +0800
Subject: [PATCH] fix(model): Fix benchmarks bugs

---
 pilot/model/cluster/worker/default_worker.py |   6 +-
 pilot/model/llm_out/vllm_llm.py              |  22 +-
 pilot/model/model_adapter.py                 |  13 +-
 .../llm/fastchat_benchmarks_inference.py     | 565 ++++++++++++++++++
 pilot/utils/benchmarks/llm/llm_benchmarks.py |  23 +-
 scripts/run_llm_benchmarks.sh                |   2 +-
 6 files changed, 611 insertions(+), 20 deletions(-)
 create mode 100644 pilot/utils/benchmarks/llm/fastchat_benchmarks_inference.py

diff --git a/pilot/model/cluster/worker/default_worker.py b/pilot/model/cluster/worker/default_worker.py
index 3d7b6be98..c798e3075 100644
--- a/pilot/model/cluster/worker/default_worker.py
+++ b/pilot/model/cluster/worker/default_worker.py
@@ -1,9 +1,9 @@
 import os
 import logging
 
-from typing import Dict, Iterator, List, Optional
+from typing import Dict, Iterator, List, Optional
 import time
-import copy
+import traceback
 
 from pilot.configs.model_config import get_device
 from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
@@ -332,6 +332,8 @@ class DefaultModelWorker(ModelWorker):
                 text="**GPU OutOfMemory, Please Refresh.**", error_code=1
             )
         else:
+            msg = traceback.format_exc()
+            logger.error(f"Model inference error, detail: {msg}")
             model_output = ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                 error_code=1,
diff --git a/pilot/model/llm_out/vllm_llm.py b/pilot/model/llm_out/vllm_llm.py
index 375e9589a..de108c87c 100644
--- a/pilot/model/llm_out/vllm_llm.py
+++ b/pilot/model/llm_out/vllm_llm.py
@@ -1,9 +1,13 @@
 from typing import Dict
+import os
 
 from vllm import AsyncLLMEngine
 from vllm.utils import random_uuid
 from vllm.sampling_params import SamplingParams
 
+_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
+
+
 async def generate_stream(
     model: AsyncLLMEngine, tokenizer, params: Dict, device: str, context_len: int
 ):
@@ -37,15 +41,29 @@ async def generate_stream(
     top_p = max(top_p, 1e-5)
     if temperature <= 1e-5:
         top_p = 1.0
+    gen_params = {
+        "stop": list(stop),
+        "ignore_eos": False,
+    }
+    prompt_token_ids = None
+    if _IS_BENCHMARK:
+        gen_params["stop"] = []
+        gen_params["ignore_eos"] = True
+        prompt_len = context_len - max_new_tokens - 2
+        prompt_token_ids = tokenizer([prompt]).input_ids[0]
+        prompt_token_ids = prompt_token_ids[-prompt_len:]
     sampling_params = SamplingParams(
         n=1,
         temperature=temperature,
         top_p=top_p,
         use_beam_search=False,
-        stop=list(stop),
         max_tokens=max_new_tokens,
+        **gen_params
+    )
+
+    results_generator = model.generate(
+        prompt, sampling_params, request_id, prompt_token_ids=prompt_token_ids
     )
-    results_generator = model.generate(prompt, sampling_params, request_id)
     async for request_output in results_generator:
         prompt = request_output.prompt
         if echo:
diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py
index 8fd242882..5f328c554 100644
--- a/pilot/model/model_adapter.py
+++ b/pilot/model/model_adapter.py
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 thread_local = threading.local()
-
+_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
 
 _OLD_MODELS = [
     "llama-cpp",
@@ -228,9 +228,16 @@ class FastChatLLMModelAdaperWrapper(LLMModelAdaper):
         return self._adapter.load_model(model_path, from_pretrained_kwargs)
 
     def get_generate_stream_function(self, model: "TorchNNModule", model_path: str):
-        from fastchat.model.model_adapter import 
get_generate_stream_function + if _IS_BENCHMARK: + from pilot.utils.benchmarks.llm.fastchat_benchmarks_inference import ( + generate_stream, + ) - return get_generate_stream_function(model, model_path) + return generate_stream + else: + from fastchat.model.model_adapter import get_generate_stream_function + + return get_generate_stream_function(model, model_path) def get_default_conv_template( self, model_name: str, model_path: str diff --git a/pilot/utils/benchmarks/llm/fastchat_benchmarks_inference.py b/pilot/utils/benchmarks/llm/fastchat_benchmarks_inference.py new file mode 100644 index 000000000..67b11357e --- /dev/null +++ b/pilot/utils/benchmarks/llm/fastchat_benchmarks_inference.py @@ -0,0 +1,565 @@ +""" +Adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py. +For benchmarks. + +""" +import abc +import gc +import json +import math +import os +import sys +import time +from typing import Iterable, Optional, Dict, TYPE_CHECKING +import warnings + +import psutil +import torch +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LlamaTokenizer, + LlamaForCausalLM, + AutoModel, + AutoModelForSeq2SeqLM, + T5Tokenizer, + AutoConfig, +) +from transformers.generation.logits_process import ( + LogitsProcessorList, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) + +from fastchat.conversation import get_conv_template, SeparatorStyle +from fastchat.model.model_adapter import ( + load_model, + get_conversation_template, + get_generate_stream_function, +) +from fastchat.modules.awq import AWQConfig +from fastchat.modules.gptq import GptqConfig + +if TYPE_CHECKING: + from fastchat.modules.exllama import ExllamaConfig + from fastchat.modules.xfastertransformer import XftConfig + +from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length + + +def prepare_logits_processor( + temperature: float, repetition_penalty: float, top_p: float, top_k: int +) -> LogitsProcessorList: + processor_list = LogitsProcessorList() + # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases. + if temperature >= 1e-5 and temperature != 1.0: + processor_list.append(TemperatureLogitsWarper(temperature)) + if repetition_penalty > 1.0: + processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) + if 1e-8 <= top_p < 1.0: + processor_list.append(TopPLogitsWarper(top_p)) + if top_k > 0: + processor_list.append(TopKLogitsWarper(top_k)) + return processor_list + + +@torch.inference_mode() +def generate_stream( + model, + tokenizer, + params: Dict, + device: str, + context_len: int, + stream_interval: int = 2, + judge_sent_end: bool = False, +): + if hasattr(model, "device"): + device = model.device + + # Read parameters + prompt = params["prompt"] + len_prompt = len(prompt) + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = int(params.get("top_k", -1)) # -1 means disable + max_new_tokens = int(params.get("max_new_tokens", 256)) + logprobs = params.get("logprobs", None) # FIXME: Support logprobs>1. 
+ echo = bool(params.get("echo", True)) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + if tokenizer.eos_token_id not in stop_token_ids: + stop_token_ids.append(tokenizer.eos_token_id) + + logits_processor = prepare_logits_processor( + temperature, repetition_penalty, top_p, top_k + ) + input_ids = tokenizer(prompt).input_ids + + if model.config.is_encoder_decoder: + max_src_len = context_len + else: # truncate + max_src_len = context_len - max_new_tokens - 1 + + input_ids = input_ids[-max_src_len:] + output_ids = list(input_ids) + input_echo_len = len(input_ids) + + # Don't stop generate until max_new_tokens is reached. + stop_token_ids = [] + stop_str = None + + if model.config.is_encoder_decoder: + if logprobs is not None: # FIXME: Support logprobs for encoder-decoder models. + raise NotImplementedError + encoder_output = model.encoder( + input_ids=torch.as_tensor([input_ids], device=device) + )[0] + start_ids = torch.as_tensor( + [[model.generation_config.decoder_start_token_id]], + dtype=torch.int64, + device=device, + ) + else: + start_ids = torch.as_tensor([input_ids], device=device) + + past_key_values = out = None + token_logprobs = [None] # The first token has no logprobs. + sent_interrupt = False + finish_reason = None + for i in range(max_new_tokens): + if i == 0: # prefill + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=start_ids, + encoder_hidden_states=encoder_output, + use_cache=True, + ) + logits = model.lm_head(out[0]) + else: + out = model(input_ids=start_ids, use_cache=True) + logits = out.logits + past_key_values = out.past_key_values + + if logprobs is not None: + # Prefull logprobs for the prompt. + shift_input_ids = start_ids[..., 1:].contiguous() + shift_logits = logits[..., :-1, :].contiguous() + shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist() + for label_id, logit in zip( + shift_input_ids[0].tolist(), shift_logits[0] + ): + token_logprobs.append(logit[label_id]) + else: # decoding + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=torch.as_tensor( + [[token] if not sent_interrupt else output_ids], + device=device, + ), + encoder_hidden_states=encoder_output, + use_cache=True, + past_key_values=past_key_values if not sent_interrupt else None, + ) + sent_interrupt = False + + logits = model.lm_head(out[0]) + else: + out = model( + input_ids=torch.as_tensor( + [[token] if not sent_interrupt else output_ids], + device=device, + ), + use_cache=True, + past_key_values=past_key_values if not sent_interrupt else None, + ) + sent_interrupt = False + logits = out.logits + past_key_values = out.past_key_values + + if logits_processor: + if repetition_penalty > 1.0: + tmp_output_ids = torch.as_tensor([output_ids], device=logits.device) + else: + tmp_output_ids = None + last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0] + else: + last_token_logits = logits[0, -1, :] + + if device == "mps": + # Switch to CPU by avoiding some bugs in mps backend. 
+ last_token_logits = last_token_logits.float().to("cpu") + + if temperature < 1e-5 or top_p < 1e-8: # greedy + _, indices = torch.topk(last_token_logits, 2) + tokens = [int(index) for index in indices.tolist()] + else: + probs = torch.softmax(last_token_logits, dim=-1) + indices = torch.multinomial(probs, num_samples=2) + tokens = [int(token) for token in indices.tolist()] + token = tokens[0] + output_ids.append(token) + if logprobs is not None: + # Cannot use last_token_logits because logprobs is based on raw logits. + token_logprobs.append( + torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist() + ) + + if token in stop_token_ids: + stopped = True + else: + stopped = False + + # Yield the output tokens + if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped: + if echo: + tmp_output_ids = output_ids + rfind_start = len_prompt + else: + tmp_output_ids = output_ids[input_echo_len:] + rfind_start = 0 + + output = tokenizer.decode( + tmp_output_ids, + skip_special_tokens=True, + spaces_between_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + ret_logprobs = None + if logprobs is not None: + ret_logprobs = { + "text_offset": [], + "tokens": [ + tokenizer.decode(token) + for token in ( + output_ids if echo else output_ids[input_echo_len:] + ) + ], + "token_logprobs": token_logprobs + if echo + else token_logprobs[input_echo_len:], + "top_logprobs": [{}] + * len(token_logprobs if echo else token_logprobs[input_echo_len:]), + } + # Compute text_offset + curr_pos = 0 + for text in ret_logprobs["tokens"]: + ret_logprobs["text_offset"].append(curr_pos) + curr_pos += len(text) + + # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way + if judge_sent_end and stopped and not is_sentence_complete(output): + if len(tokens) > 1: + token = tokens[1] + output_ids[-1] = token + else: + output_ids.pop() + stopped = False + sent_interrupt = True + + partially_stopped = False + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + else: + partially_stopped = is_partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError("Invalid stop field type.") + + # Prevent yielding partial stop sequence + if not partially_stopped: + yield { + "text": output, + "logprobs": ret_logprobs, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + + if stopped: + break + + # Finish stream event, which contains finish reason + else: + finish_reason = "length" + + if stopped: + finish_reason = "stop" + + yield { + "text": output, + "logprobs": ret_logprobs, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + + # Clean + del past_key_values, out + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() + + +class ChatIO(abc.ABC): + @abc.abstractmethod + def prompt_for_input(self, role: str) -> str: + """Prompt for input from a role.""" + + @abc.abstractmethod + def prompt_for_output(self, 
role: str): + """Prompt for output from a role.""" + + @abc.abstractmethod + def stream_output(self, output_stream): + """Stream output.""" + + @abc.abstractmethod + def print_output(self, text: str): + """Print output.""" + + +def chat_loop( + model_path: str, + device: str, + num_gpus: int, + max_gpu_memory: str, + dtype: Optional[torch.dtype], + load_8bit: bool, + cpu_offloading: bool, + conv_template: Optional[str], + conv_system_msg: Optional[str], + temperature: float, + repetition_penalty: float, + max_new_tokens: int, + chatio: ChatIO, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + exllama_config: Optional["ExllamaConfig"] = None, + xft_config: Optional["XftConfig"] = None, + revision: str = "main", + judge_sent_end: bool = True, + debug: bool = True, + history: bool = True, +): + # Model + model, tokenizer = load_model( + model_path, + device=device, + num_gpus=num_gpus, + max_gpu_memory=max_gpu_memory, + dtype=dtype, + load_8bit=load_8bit, + cpu_offloading=cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + exllama_config=exllama_config, + xft_config=xft_config, + revision=revision, + debug=debug, + ) + generate_stream_func = get_generate_stream_function(model, model_path) + + model_type = str(type(model)).lower() + is_t5 = "t5" in model_type + is_codet5p = "codet5p" in model_type + is_xft = "xft" in model_type + + # Hardcode T5's default repetition penalty to be 1.2 + if is_t5 and repetition_penalty == 1.0: + repetition_penalty = 1.2 + + # Set context length + context_len = get_context_length(model.config) + + # Chat + def new_chat(): + if conv_template: + conv = get_conv_template(conv_template) + else: + conv = get_conversation_template(model_path) + if conv_system_msg is not None: + conv.set_system_message(conv_system_msg) + return conv + + def reload_conv(conv): + """ + Reprints the conversation from the start. + """ + for message in conv.messages[conv.offset :]: + chatio.prompt_for_output(message[0]) + chatio.print_output(message[1]) + + conv = None + + while True: + if not history or not conv: + conv = new_chat() + + try: + inp = chatio.prompt_for_input(conv.roles[0]) + except EOFError: + inp = "" + + if inp == "!!exit" or not inp: + print("exit...") + break + elif inp == "!!reset": + print("resetting...") + conv = new_chat() + continue + elif inp == "!!remove": + print("removing last message...") + if len(conv.messages) > conv.offset: + # Assistant + if conv.messages[-1][0] == conv.roles[1]: + conv.messages.pop() + # User + if conv.messages[-1][0] == conv.roles[0]: + conv.messages.pop() + reload_conv(conv) + else: + print("No messages to remove.") + continue + elif inp == "!!regen": + print("regenerating last message...") + if len(conv.messages) > conv.offset: + # Assistant + if conv.messages[-1][0] == conv.roles[1]: + conv.messages.pop() + # User + if conv.messages[-1][0] == conv.roles[0]: + reload_conv(conv) + # Set inp to previous message + inp = conv.messages.pop()[1] + else: + # Shouldn't happen in normal circumstances + print("No user message to regenerate from.") + continue + else: + print("No messages to regenerate.") + continue + elif inp.startswith("!!save"): + args = inp.split(" ", 1) + + if len(args) != 2: + print("usage: !!save ") + continue + else: + filename = args[1] + + # Add .json if extension not present + if not "." 
in filename: + filename += ".json" + + print("saving...", filename) + with open(filename, "w") as outfile: + json.dump(conv.dict(), outfile) + continue + elif inp.startswith("!!load"): + args = inp.split(" ", 1) + + if len(args) != 2: + print("usage: !!load ") + continue + else: + filename = args[1] + + # Check if file exists and add .json if needed + if not os.path.exists(filename): + if (not filename.endswith(".json")) and os.path.exists( + filename + ".json" + ): + filename += ".json" + else: + print("file not found:", filename) + continue + + print("loading...", filename) + with open(filename, "r") as infile: + new_conv = json.load(infile) + + conv = get_conv_template(new_conv["template_name"]) + conv.set_system_message(new_conv["system_message"]) + conv.messages = new_conv["messages"] + reload_conv(conv) + continue + + conv.append_message(conv.roles[0], inp) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + if is_codet5p: # codet5p is a code completion model. + prompt = inp + + gen_params = { + "model": model_path, + "prompt": prompt, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + "max_new_tokens": max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": False, + } + + try: + chatio.prompt_for_output(conv.roles[1]) + output_stream = generate_stream_func( + model, + tokenizer, + gen_params, + device, + context_len=context_len, + judge_sent_end=judge_sent_end, + ) + t = time.time() + outputs = chatio.stream_output(output_stream) + duration = time.time() - t + conv.update_last_message(outputs.strip()) + + if debug: + num_tokens = len(tokenizer.encode(outputs)) + msg = { + "conv_template": conv.name, + "prompt": prompt, + "outputs": outputs, + "speed (token/s)": round(num_tokens / duration, 2), + } + print(f"\n{msg}\n") + + except KeyboardInterrupt: + print("stopped generation.") + # If generation didn't finish + if conv.messages[-1][1] is None: + conv.messages.pop() + # Remove last user message, so there isn't a double up + if conv.messages[-1][0] == conv.roles[0]: + conv.messages.pop() + + reload_conv(conv) diff --git a/pilot/utils/benchmarks/llm/llm_benchmarks.py b/pilot/utils/benchmarks/llm/llm_benchmarks.py index 7e4440b76..1d651fff6 100644 --- a/pilot/utils/benchmarks/llm/llm_benchmarks.py +++ b/pilot/utils/benchmarks/llm/llm_benchmarks.py @@ -40,15 +40,6 @@ def get_result_csv_file() -> str: ) -input_output_length_pair = [ - [64, 256], - [64, 512], - [64, 1024], - [512, 1024], - [1024, 1024], - [1024, 2048], - [2048, 2048], -] input_lens = [64, 64] output_lens = [256, 512] @@ -96,8 +87,8 @@ def build_param( system_prompt: str = None, ) -> Dict: hist = [] - if system_prompt: - hist.append()( + if system_prompt is not None: + hist.append( ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_prompt) ) hist.append(ModelMessage(role=ModelMessageRoleType.HUMAN, content=user_input)) @@ -119,8 +110,15 @@ async def run_batch( ): tasks = [] prompt = read_prompt_from_file("11k") + if model_type == "vllm": + max_input_str_len = input_len + if "baichuan" in model_name: + # TODO prompt handle first + max_input_str_len *= 2 + prompt = prompt[-max_input_str_len:] + for _ in range(parallel_num): - params = build_param(input_len, output_len, prompt) + params = build_param(input_len, output_len, prompt, system_prompt="") tasks.append(wh.generate(params)) print( f"Begin run benchmarks, model name: {model_name}, input_len: {input_len}, output_len: {output_len}, parallel_num: {parallel_num}, save result to 
{output_file}"
     )
@@ -136,6 +134,7 @@ async def run_batch(
         metrics = r.metrics
         if isinstance(metrics, dict):
             metrics = ModelInferenceMetrics(**metrics)
+        print(r)
         test_total_tokens += metrics.total_tokens
         row_data = metrics.to_dict()
         rows.append(row_data)
diff --git a/scripts/run_llm_benchmarks.sh b/scripts/run_llm_benchmarks.sh
index 578de54ee..ffa9ac6da 100755
--- a/scripts/run_llm_benchmarks.sh
+++ b/scripts/run_llm_benchmarks.sh
@@ -11,7 +11,7 @@ parallel_nums=${3:-$default_parallel_nums}
 run_benchmark() {
     local model_name=$1
     local model_type=$2
-    python pilot/utils/benchmarks/llm/llm_benchmarks.py --model_name ${model_name} --model_type ${model_type} --input_lens ${input_lens} --output_lens ${output_lens} --parallel_nums ${parallel_nums}
+    DB_GPT_MODEL_BENCHMARK=true python pilot/utils/benchmarks/llm/llm_benchmarks.py --model_name ${model_name} --model_type ${model_type} --input_lens ${input_lens} --output_lens ${output_lens} --parallel_nums ${parallel_nums}
 }
 
 run_benchmark "vicuna-7b-v1.5" "huggingface"
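
Usage sketch (illustrative note, not part of the diff above): the new DB_GPT_MODEL_BENCHMARK environment variable is what switches inference onto the benchmark code paths introduced here (the vendored fastchat generate_stream for HuggingFace models, and an empty stop list plus ignore_eos for vLLM), so every request runs to max_new_tokens. A minimal way to exercise it from the repository root, assuming the omitted length/parallelism flags fall back to the defaults defined in llm_benchmarks.py:

    # Run the benchmark matrix via the wrapper script; run_benchmark() now sets
    # DB_GPT_MODEL_BENCHMARK=true on the python command itself.
    bash scripts/run_llm_benchmarks.sh

    # Direct invocation for a single model, mirroring what run_benchmark() does
    # for the script's default "vicuna-7b-v1.5" / "huggingface" pair.
    DB_GPT_MODEL_BENCHMARK=true python pilot/utils/benchmarks/llm/llm_benchmarks.py \
        --model_name vicuna-7b-v1.5 --model_type huggingface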