diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 9a19f3255..7a8a37cae 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -635,7 +635,7 @@ def _build_model_operator(
         model_task_name="llm_model_node",
         cache_task_name="llm_model_cache_node",
     )
-    # Create a join node to merge outputs from the model and cache nodes, just keep the fist not empty output
+    # Create a join node to merge outputs from the model and cache nodes, just keep the first non-empty output
     join_node = JoinOperator(
         combine_function=lambda model_out, cache_out: cache_out or model_out
     )
diff --git a/pilot/utils/benchmarks/llm/fastchat_benchmarks_inference.py b/pilot/utils/benchmarks/llm/fastchat_benchmarks_inference.py
index 67b11357e..cb05ab33f 100644
--- a/pilot/utils/benchmarks/llm/fastchat_benchmarks_inference.py
+++ b/pilot/utils/benchmarks/llm/fastchat_benchmarks_inference.py
@@ -3,28 +3,10 @@ Adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/ser
 For benchmarks.
 """
-import abc
 import gc
-import json
-import math
-import os
-import sys
-import time
-from typing import Iterable, Optional, Dict, TYPE_CHECKING
-import warnings
+from typing import Iterable, Dict
 
-import psutil
 import torch
-from transformers import (
-    AutoTokenizer,
-    AutoModelForCausalLM,
-    LlamaTokenizer,
-    LlamaForCausalLM,
-    AutoModel,
-    AutoModelForSeq2SeqLM,
-    T5Tokenizer,
-    AutoConfig,
-)
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     RepetitionPenaltyLogitsProcessor,
@@ -33,18 +15,6 @@ from transformers.generation.logits_process import (
     TopPLogitsWarper,
 )
 
-from fastchat.conversation import get_conv_template, SeparatorStyle
-from fastchat.model.model_adapter import (
-    load_model,
-    get_conversation_template,
-    get_generate_stream_function,
-)
-from fastchat.modules.awq import AWQConfig
-from fastchat.modules.gptq import GptqConfig
-
-if TYPE_CHECKING:
-    from fastchat.modules.exllama import ExllamaConfig
-    from fastchat.modules.xfastertransformer import XftConfig
 from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length
 
 
@@ -324,242 +294,3 @@ def generate_stream(
         torch.xpu.empty_cache()
     if device == "npu":
         torch.npu.empty_cache()
-
-
-class ChatIO(abc.ABC):
-    @abc.abstractmethod
-    def prompt_for_input(self, role: str) -> str:
-        """Prompt for input from a role."""
-
-    @abc.abstractmethod
-    def prompt_for_output(self, role: str):
-        """Prompt for output from a role."""
-
-    @abc.abstractmethod
-    def stream_output(self, output_stream):
-        """Stream output."""
-
-    @abc.abstractmethod
-    def print_output(self, text: str):
-        """Print output."""
-
-
-def chat_loop(
-    model_path: str,
-    device: str,
-    num_gpus: int,
-    max_gpu_memory: str,
-    dtype: Optional[torch.dtype],
-    load_8bit: bool,
-    cpu_offloading: bool,
-    conv_template: Optional[str],
-    conv_system_msg: Optional[str],
-    temperature: float,
-    repetition_penalty: float,
-    max_new_tokens: int,
-    chatio: ChatIO,
-    gptq_config: Optional[GptqConfig] = None,
-    awq_config: Optional[AWQConfig] = None,
-    exllama_config: Optional["ExllamaConfig"] = None,
-    xft_config: Optional["XftConfig"] = None,
-    revision: str = "main",
-    judge_sent_end: bool = True,
-    debug: bool = True,
-    history: bool = True,
-):
-    # Model
-    model, tokenizer = load_model(
-        model_path,
-        device=device,
-        num_gpus=num_gpus,
-        max_gpu_memory=max_gpu_memory,
-        dtype=dtype,
-        load_8bit=load_8bit,
-        cpu_offloading=cpu_offloading,
-        gptq_config=gptq_config,
-        awq_config=awq_config,
-        exllama_config=exllama_config,
-        xft_config=xft_config,
-        revision=revision,
-        debug=debug,
-    )
-    generate_stream_func = get_generate_stream_function(model, model_path)
-
-    model_type = str(type(model)).lower()
-    is_t5 = "t5" in model_type
-    is_codet5p = "codet5p" in model_type
-    is_xft = "xft" in model_type
-
-    # Hardcode T5's default repetition penalty to be 1.2
-    if is_t5 and repetition_penalty == 1.0:
-        repetition_penalty = 1.2
-
-    # Set context length
-    context_len = get_context_length(model.config)
-
-    # Chat
-    def new_chat():
-        if conv_template:
-            conv = get_conv_template(conv_template)
-        else:
-            conv = get_conversation_template(model_path)
-        if conv_system_msg is not None:
-            conv.set_system_message(conv_system_msg)
-        return conv
-
-    def reload_conv(conv):
-        """
-        Reprints the conversation from the start.
-        """
-        for message in conv.messages[conv.offset :]:
-            chatio.prompt_for_output(message[0])
-            chatio.print_output(message[1])
-
-    conv = None
-
-    while True:
-        if not history or not conv:
-            conv = new_chat()
-
-        try:
-            inp = chatio.prompt_for_input(conv.roles[0])
-        except EOFError:
-            inp = ""
-
-        if inp == "!!exit" or not inp:
-            print("exit...")
-            break
-        elif inp == "!!reset":
-            print("resetting...")
-            conv = new_chat()
-            continue
-        elif inp == "!!remove":
-            print("removing last message...")
-            if len(conv.messages) > conv.offset:
-                # Assistant
-                if conv.messages[-1][0] == conv.roles[1]:
-                    conv.messages.pop()
-                # User
-                if conv.messages[-1][0] == conv.roles[0]:
-                    conv.messages.pop()
-                reload_conv(conv)
-            else:
-                print("No messages to remove.")
-            continue
-        elif inp == "!!regen":
-            print("regenerating last message...")
-            if len(conv.messages) > conv.offset:
-                # Assistant
-                if conv.messages[-1][0] == conv.roles[1]:
-                    conv.messages.pop()
-                # User
-                if conv.messages[-1][0] == conv.roles[0]:
-                    reload_conv(conv)
-                    # Set inp to previous message
-                    inp = conv.messages.pop()[1]
-                else:
-                    # Shouldn't happen in normal circumstances
-                    print("No user message to regenerate from.")
-                    continue
-            else:
-                print("No messages to regenerate.")
-                continue
-        elif inp.startswith("!!save"):
-            args = inp.split(" ", 1)
-
-            if len(args) != 2:
-                print("usage: !!save <filename>")
-                continue
-            else:
-                filename = args[1]
-
-            # Add .json if extension not present
-            if not "." in filename:
-                filename += ".json"
-
-            print("saving...", filename)
-            with open(filename, "w") as outfile:
-                json.dump(conv.dict(), outfile)
-            continue
-        elif inp.startswith("!!load"):
-            args = inp.split(" ", 1)
-
-            if len(args) != 2:
-                print("usage: !!load <filename>")
-                continue
-            else:
-                filename = args[1]
-
-            # Check if file exists and add .json if needed
-            if not os.path.exists(filename):
-                if (not filename.endswith(".json")) and os.path.exists(
-                    filename + ".json"
-                ):
-                    filename += ".json"
-                else:
-                    print("file not found:", filename)
-                    continue
-
-            print("loading...", filename)
-            with open(filename, "r") as infile:
-                new_conv = json.load(infile)
-
-            conv = get_conv_template(new_conv["template_name"])
-            conv.set_system_message(new_conv["system_message"])
-            conv.messages = new_conv["messages"]
-            reload_conv(conv)
-            continue
-
-        conv.append_message(conv.roles[0], inp)
-        conv.append_message(conv.roles[1], None)
-        prompt = conv.get_prompt()
-
-        if is_codet5p:  # codet5p is a code completion model.
-            prompt = inp
-
-        gen_params = {
-            "model": model_path,
-            "prompt": prompt,
-            "temperature": temperature,
-            "repetition_penalty": repetition_penalty,
-            "max_new_tokens": max_new_tokens,
-            "stop": conv.stop_str,
-            "stop_token_ids": conv.stop_token_ids,
-            "echo": False,
-        }
-
-        try:
-            chatio.prompt_for_output(conv.roles[1])
-            output_stream = generate_stream_func(
-                model,
-                tokenizer,
-                gen_params,
-                device,
-                context_len=context_len,
-                judge_sent_end=judge_sent_end,
-            )
-            t = time.time()
-            outputs = chatio.stream_output(output_stream)
-            duration = time.time() - t
-            conv.update_last_message(outputs.strip())
-
-            if debug:
-                num_tokens = len(tokenizer.encode(outputs))
-                msg = {
-                    "conv_template": conv.name,
-                    "prompt": prompt,
-                    "outputs": outputs,
-                    "speed (token/s)": round(num_tokens / duration, 2),
-                }
-                print(f"\n{msg}\n")
-
-        except KeyboardInterrupt:
-            print("stopped generation.")
-            # If generation didn't finish
-            if conv.messages[-1][1] is None:
-                conv.messages.pop()
-                # Remove last user message, so there isn't a double up
-                if conv.messages[-1][0] == conv.roles[0]:
-                    conv.messages.pop()
-
-                reload_conv(conv)
diff --git a/pilot/utils/benchmarks/llm/llm_benchmarks.py b/pilot/utils/benchmarks/llm/llm_benchmarks.py
index 1d651fff6..becb78b00 100644
--- a/pilot/utils/benchmarks/llm/llm_benchmarks.py
+++ b/pilot/utils/benchmarks/llm/llm_benchmarks.py
@@ -5,6 +5,8 @@ import sys
 import time
 import csv
 import argparse
+import logging
+import traceback
 
 from pilot.configs.model_config import ROOT_PATH, LLM_MODEL_CONFIG
 from pilot.model.cluster.worker.manager import (
@@ -19,14 +21,12 @@ from pilot.model.cluster import PromptRequest
 from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
 
 
-# model_name = "chatglm2-6b"
-# model_name = "vicuna-7b-v1.5"
-model_name = "baichuan2-7b"
+model_name = "vicuna-7b-v1.5"
 model_path = LLM_MODEL_CONFIG[model_name]
 
 # or vllm
 model_type = "huggingface"
 
-controller_addr = "http://127.0.0.1:5005"
+controller_addr = "http://127.0.0.1:5000"
 
 result_csv_file = None
@@ -59,7 +59,7 @@ METRICS_HEADERS = [
     # Merge parallel result
     "test_time_cost_ms",
     "test_total_tokens",
-    "test_speed_per_second",
+    "test_speed_per_second",  # (tokens / s)
     # Detail for each task
     "start_time_ms",
     "end_time_ms",
@@ -93,7 +93,7 @@ def build_param(
     )
     hist.append(ModelMessage(role=ModelMessageRoleType.HUMAN, content=user_input))
     hist = list(h.dict() for h in hist)
-    context_len = input_len + output_len
+    context_len = input_len + output_len + 2
     params = {
         "prompt": user_input,
         "messages": hist,
@@ -167,7 +167,15 @@ async def run_model(wh: WorkerManager) -> None:
         os.rename(result_csv_file, f"{result_csv_file}.bak.csv")
     for parallel_num in parallel_nums:
         for input_len, output_len in zip(input_lens, output_lens):
-            await run_batch(wh, input_len, output_len, parallel_num, result_csv_file)
+            try:
+                await run_batch(
+                    wh, input_len, output_len, parallel_num, result_csv_file
+                )
+            except Exception:
+                msg = traceback.format_exc()
+                logging.error(
+                    f"Run benchmarks error, input_len: {input_len}, output_len: {output_len}, parallel_num: {parallel_num}, error message: {msg}"
+                )
 
     sys.exit(0)
 
@@ -184,7 +192,6 @@ def startup_llm_env():
         controller_addr=controller_addr,
         local_port=6000,
         start_listener=run_model,
-        # system_app=system_app,
     )
 
 
@@ -198,9 +205,9 @@ if __name__ == "__main__":
     parser.add_argument("--model_path", type=str, default=None)
     parser.add_argument("--model_type", type=str, default="huggingface")
     parser.add_argument("--result_csv_file", type=str, default=None)
parser.add_argument("--input_lens", type=str, default="64,64,64,512,1024,1024,2048") + parser.add_argument("--input_lens", type=str, default="8,8,256,1024") parser.add_argument( - "--output_lens", type=str, default="256,512,1024,1024,1024,2048,2048" + "--output_lens", type=str, default="256,512,1024,1024" ) parser.add_argument("--parallel_nums", type=str, default="1,2,4,16,32") parser.add_argument( @@ -225,8 +232,10 @@ if __name__ == "__main__": raise ValueError("input_lens size must equal output_lens size") if remote_model: + # Connect to remote model and run benchmarks connect_to_remote_model() else: + # Start worker manager and run benchmarks run_worker_manager( model_name=model_name, model_path=model_path,