mirror of https://github.com/csunny/DB-GPT.git
synced 2025-07-30 15:21:02 +00:00

chore: Modify benchmarks for less times

This commit is contained in:
parent 4ffd054a2a
commit d6318c21ee
@@ -635,7 +635,7 @@ def _build_model_operator(
         model_task_name="llm_model_node",
         cache_task_name="llm_model_cache_node",
     )
-    # Create a join node to merge outputs from the model and cache nodes, just keep the fist not empty output
+    # Create a join node to merge outputs from the model and cache nodes, just keep the first not empty output
     join_node = JoinOperator(
         combine_function=lambda model_out, cache_out: cache_out or model_out
     )
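A minimal illustrative sketch (not DB-GPT code) of the combine behavior in the hunk above: the join prefers the cached output and falls back to the freshly generated model output when the cache produced nothing. The helper name is hypothetical.

# Illustrative only: "keep the first non-empty output", preferring the cache.
def combine_first_non_empty(model_out, cache_out):
    return cache_out or model_out

assert combine_first_non_empty("fresh", "") == "fresh"
assert combine_first_non_empty("fresh", "cached") == "cached"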
@@ -3,28 +3,10 @@ Adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/ser
 For benchmarks.
 
 """
-import abc
-import gc
-import json
-import math
-import os
-import sys
-import time
-from typing import Iterable, Optional, Dict, TYPE_CHECKING
-import warnings
+from typing import Iterable, Dict
 
 import psutil
 import torch
-from transformers import (
-    AutoTokenizer,
-    AutoModelForCausalLM,
-    LlamaTokenizer,
-    LlamaForCausalLM,
-    AutoModel,
-    AutoModelForSeq2SeqLM,
-    T5Tokenizer,
-    AutoConfig,
-)
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     RepetitionPenaltyLogitsProcessor,
@@ -33,18 +15,6 @@ from transformers.generation.logits_process import (
     TopPLogitsWarper,
 )
 
-from fastchat.conversation import get_conv_template, SeparatorStyle
-from fastchat.model.model_adapter import (
-    load_model,
-    get_conversation_template,
-    get_generate_stream_function,
-)
-from fastchat.modules.awq import AWQConfig
-from fastchat.modules.gptq import GptqConfig
-
-if TYPE_CHECKING:
-    from fastchat.modules.exllama import ExllamaConfig
-    from fastchat.modules.xfastertransformer import XftConfig
 
 from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length
 
@@ -324,242 +294,3 @@ def generate_stream(
         torch.xpu.empty_cache()
     if device == "npu":
         torch.npu.empty_cache()
-
-
-class ChatIO(abc.ABC):
-    @abc.abstractmethod
-    def prompt_for_input(self, role: str) -> str:
-        """Prompt for input from a role."""
-
-    @abc.abstractmethod
-    def prompt_for_output(self, role: str):
-        """Prompt for output from a role."""
-
-    @abc.abstractmethod
-    def stream_output(self, output_stream):
-        """Stream output."""
-
-    @abc.abstractmethod
-    def print_output(self, text: str):
-        """Print output."""
-
-
-def chat_loop(
-    model_path: str,
-    device: str,
-    num_gpus: int,
-    max_gpu_memory: str,
-    dtype: Optional[torch.dtype],
-    load_8bit: bool,
-    cpu_offloading: bool,
-    conv_template: Optional[str],
-    conv_system_msg: Optional[str],
-    temperature: float,
-    repetition_penalty: float,
-    max_new_tokens: int,
-    chatio: ChatIO,
-    gptq_config: Optional[GptqConfig] = None,
-    awq_config: Optional[AWQConfig] = None,
-    exllama_config: Optional["ExllamaConfig"] = None,
-    xft_config: Optional["XftConfig"] = None,
-    revision: str = "main",
-    judge_sent_end: bool = True,
-    debug: bool = True,
-    history: bool = True,
-):
-    # Model
-    model, tokenizer = load_model(
-        model_path,
-        device=device,
-        num_gpus=num_gpus,
-        max_gpu_memory=max_gpu_memory,
-        dtype=dtype,
-        load_8bit=load_8bit,
-        cpu_offloading=cpu_offloading,
-        gptq_config=gptq_config,
-        awq_config=awq_config,
-        exllama_config=exllama_config,
-        xft_config=xft_config,
-        revision=revision,
-        debug=debug,
-    )
-    generate_stream_func = get_generate_stream_function(model, model_path)
-
-    model_type = str(type(model)).lower()
-    is_t5 = "t5" in model_type
-    is_codet5p = "codet5p" in model_type
-    is_xft = "xft" in model_type
-
-    # Hardcode T5's default repetition penalty to be 1.2
-    if is_t5 and repetition_penalty == 1.0:
-        repetition_penalty = 1.2
-
-    # Set context length
-    context_len = get_context_length(model.config)
-
-    # Chat
-    def new_chat():
-        if conv_template:
-            conv = get_conv_template(conv_template)
-        else:
-            conv = get_conversation_template(model_path)
-        if conv_system_msg is not None:
-            conv.set_system_message(conv_system_msg)
-        return conv
-
-    def reload_conv(conv):
-        """
-        Reprints the conversation from the start.
-        """
-        for message in conv.messages[conv.offset :]:
-            chatio.prompt_for_output(message[0])
-            chatio.print_output(message[1])
-
-    conv = None
-
-    while True:
-        if not history or not conv:
-            conv = new_chat()
-
-        try:
-            inp = chatio.prompt_for_input(conv.roles[0])
-        except EOFError:
-            inp = ""
-
-        if inp == "!!exit" or not inp:
-            print("exit...")
-            break
-        elif inp == "!!reset":
-            print("resetting...")
-            conv = new_chat()
-            continue
-        elif inp == "!!remove":
-            print("removing last message...")
-            if len(conv.messages) > conv.offset:
-                # Assistant
-                if conv.messages[-1][0] == conv.roles[1]:
-                    conv.messages.pop()
-                # User
-                if conv.messages[-1][0] == conv.roles[0]:
-                    conv.messages.pop()
-                reload_conv(conv)
-            else:
-                print("No messages to remove.")
-            continue
-        elif inp == "!!regen":
-            print("regenerating last message...")
-            if len(conv.messages) > conv.offset:
-                # Assistant
-                if conv.messages[-1][0] == conv.roles[1]:
-                    conv.messages.pop()
-                # User
-                if conv.messages[-1][0] == conv.roles[0]:
-                    reload_conv(conv)
-                    # Set inp to previous message
-                    inp = conv.messages.pop()[1]
-                else:
-                    # Shouldn't happen in normal circumstances
-                    print("No user message to regenerate from.")
-                    continue
-            else:
-                print("No messages to regenerate.")
-                continue
-        elif inp.startswith("!!save"):
-            args = inp.split(" ", 1)
-
-            if len(args) != 2:
-                print("usage: !!save <filename>")
-                continue
-            else:
-                filename = args[1]
-
-            # Add .json if extension not present
-            if not "." in filename:
-                filename += ".json"
-
-            print("saving...", filename)
-            with open(filename, "w") as outfile:
-                json.dump(conv.dict(), outfile)
-            continue
-        elif inp.startswith("!!load"):
-            args = inp.split(" ", 1)
-
-            if len(args) != 2:
-                print("usage: !!load <filename>")
-                continue
-            else:
-                filename = args[1]
-
-            # Check if file exists and add .json if needed
-            if not os.path.exists(filename):
-                if (not filename.endswith(".json")) and os.path.exists(
-                    filename + ".json"
-                ):
-                    filename += ".json"
-                else:
-                    print("file not found:", filename)
-                    continue
-
-            print("loading...", filename)
-            with open(filename, "r") as infile:
-                new_conv = json.load(infile)
-
-            conv = get_conv_template(new_conv["template_name"])
-            conv.set_system_message(new_conv["system_message"])
-            conv.messages = new_conv["messages"]
-            reload_conv(conv)
-            continue
-
-        conv.append_message(conv.roles[0], inp)
-        conv.append_message(conv.roles[1], None)
-        prompt = conv.get_prompt()
-
-        if is_codet5p:  # codet5p is a code completion model.
-            prompt = inp
-
-        gen_params = {
-            "model": model_path,
-            "prompt": prompt,
-            "temperature": temperature,
-            "repetition_penalty": repetition_penalty,
-            "max_new_tokens": max_new_tokens,
-            "stop": conv.stop_str,
-            "stop_token_ids": conv.stop_token_ids,
-            "echo": False,
-        }
-
-        try:
-            chatio.prompt_for_output(conv.roles[1])
-            output_stream = generate_stream_func(
-                model,
-                tokenizer,
-                gen_params,
-                device,
-                context_len=context_len,
-                judge_sent_end=judge_sent_end,
-            )
-            t = time.time()
-            outputs = chatio.stream_output(output_stream)
-            duration = time.time() - t
-            conv.update_last_message(outputs.strip())
-
-            if debug:
-                num_tokens = len(tokenizer.encode(outputs))
-                msg = {
-                    "conv_template": conv.name,
-                    "prompt": prompt,
-                    "outputs": outputs,
-                    "speed (token/s)": round(num_tokens / duration, 2),
-                }
-                print(f"\n{msg}\n")
-
-        except KeyboardInterrupt:
-            print("stopped generation.")
-            # If generation didn't finish
-            if conv.messages[-1][1] is None:
-                conv.messages.pop()
-                # Remove last user message, so there isn't a double up
-                if conv.messages[-1][0] == conv.roles[0]:
-                    conv.messages.pop()
-
-                reload_conv(conv)
@@ -5,6 +5,8 @@ import sys
 import time
 import csv
 import argparse
+import logging
+import traceback
 from pilot.configs.model_config import ROOT_PATH, LLM_MODEL_CONFIG
 
 from pilot.model.cluster.worker.manager import (
@@ -19,14 +21,12 @@ from pilot.model.cluster import PromptRequest
 from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
 
 
-# model_name = "chatglm2-6b"
-# model_name = "vicuna-7b-v1.5"
-model_name = "baichuan2-7b"
+model_name = "vicuna-7b-v1.5"
 model_path = LLM_MODEL_CONFIG[model_name]
 # or vllm
 model_type = "huggingface"
 
-controller_addr = "http://127.0.0.1:5005"
+controller_addr = "http://127.0.0.1:5000"
 
 result_csv_file = None
 
@@ -59,7 +59,7 @@ METRICS_HEADERS = [
     # Merge parallel result
     "test_time_cost_ms",
     "test_total_tokens",
-    "test_speed_per_second",
+    "test_speed_per_second",  # (tokens / s)
     # Detail for each task
     "start_time_ms",
     "end_time_ms",
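The annotated column is throughput in tokens per second. A hedged sketch of how such a figure relates to the other merged columns; the helper name is hypothetical, not part of the benchmark script.

# Hypothetical helper: derive tokens/s from the merged totals.
def tokens_per_second(test_total_tokens: int, test_time_cost_ms: float) -> float:
    return test_total_tokens / (test_time_cost_ms / 1000.0)

# e.g. 2048 tokens generated in 16_000 ms -> 128.0 tokens/s
print(tokens_per_second(2048, 16_000))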
@@ -93,7 +93,7 @@ def build_param(
     )
     hist.append(ModelMessage(role=ModelMessageRoleType.HUMAN, content=user_input))
     hist = list(h.dict() for h in hist)
-    context_len = input_len + output_len
+    context_len = input_len + output_len + 2
     params = {
         "prompt": user_input,
         "messages": hist,
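The extra 2 tokens presumably leave a little headroom beyond the prompt plus the requested completion (for special tokens, for example); a small sketch of that budget with example values only:

# Example values only; mirrors the arithmetic in the hunk above.
input_len, output_len = 64, 256
context_len = input_len + output_len + 2  # small headroom, e.g. for special tokens
max_new_tokens = output_len               # tokens the benchmark asks the model to emit
assert context_len >= input_len + max_new_tokens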
@@ -167,7 +167,15 @@ async def run_model(wh: WorkerManager) -> None:
         os.rename(result_csv_file, f"{result_csv_file}.bak.csv")
     for parallel_num in parallel_nums:
         for input_len, output_len in zip(input_lens, output_lens):
-            await run_batch(wh, input_len, output_len, parallel_num, result_csv_file)
+            try:
+                await run_batch(
+                    wh, input_len, output_len, parallel_num, result_csv_file
+                )
+            except Exception:
+                msg = traceback.format_exc()
+                logging.error(
+                    f"Run benchmarks error, input_len: {input_len}, output_len: {output_len}, parallel_num: {parallel_num}, error message: {msg}"
+                )
 
     sys.exit(0)
 
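With this change, one failing (input_len, output_len, parallel_num) combination no longer aborts the whole sweep: the traceback is logged and the loop moves on. A standalone sketch of the same "log and continue" pattern; run_one_case is a placeholder, not the real run_batch.

import asyncio
import logging
import traceback


async def run_one_case(input_len: int, output_len: int, parallel_num: int) -> None:
    # Placeholder workload that fails for large parallelism, to show the pattern.
    if parallel_num > 8:
        raise RuntimeError("simulated OOM")


async def sweep():
    for parallel_num in (1, 2, 4, 16):
        for input_len, output_len in zip((8, 8, 256), (256, 512, 1024)):
            try:
                await run_one_case(input_len, output_len, parallel_num)
            except Exception:
                # Record the failing configuration and keep benchmarking.
                logging.error(
                    "Benchmark case failed (input_len=%s, output_len=%s, parallel_num=%s): %s",
                    input_len,
                    output_len,
                    parallel_num,
                    traceback.format_exc(),
                )


asyncio.run(sweep())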
@@ -184,7 +192,6 @@ def startup_llm_env():
         controller_addr=controller_addr,
         local_port=6000,
         start_listener=run_model,
-        # system_app=system_app,
     )
 
 
@@ -198,9 +205,9 @@ if __name__ == "__main__":
     parser.add_argument("--model_path", type=str, default=None)
     parser.add_argument("--model_type", type=str, default="huggingface")
     parser.add_argument("--result_csv_file", type=str, default=None)
-    parser.add_argument("--input_lens", type=str, default="64,64,64,512,1024,1024,2048")
+    parser.add_argument("--input_lens", type=str, default="8,8,256,1024")
     parser.add_argument(
-        "--output_lens", type=str, default="256,512,1024,1024,1024,2048,2048"
+        "--output_lens", type=str, default="256,512,1024,1024"
     )
     parser.add_argument("--parallel_nums", type=str, default="1,2,4,16,32")
     parser.add_argument(
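The shorter defaults pair up one-to-one into benchmark cases. A sketch of how comma-separated length strings become paired cases; the parsing shown here is illustrative and the script's own wiring may differ, though the equal-length check mirrors the ValueError in the next hunk.

# Illustrative parsing of the new, smaller defaults into (input_len, output_len) cases.
input_lens = [int(x) for x in "8,8,256,1024".split(",")]
output_lens = [int(x) for x in "256,512,1024,1024".split(",")]
if len(input_lens) != len(output_lens):
    raise ValueError("input_lens size must equal output_lens size")
cases = list(zip(input_lens, output_lens))  # [(8, 256), (8, 512), (256, 1024), (1024, 1024)]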
@@ -225,8 +232,10 @@ if __name__ == "__main__":
         raise ValueError("input_lens size must equal output_lens size")
 
     if remote_model:
+        # Connect to remote model and run benchmarks
         connect_to_remote_model()
     else:
+        # Start worker manager and run benchmarks
         run_worker_manager(
             model_name=model_name,
             model_path=model_path,