chore: Modify benchmarks to run fewer times
commit d6318c21ee
parent 4ffd054a2a
@@ -635,7 +635,7 @@ def _build_model_operator(
         model_task_name="llm_model_node",
         cache_task_name="llm_model_cache_node",
     )
-    # Create a join node to merge outputs from the model and cache nodes, just keep the fist not empty output
+    # Create a join node to merge outputs from the model and cache nodes, just keep the first not empty output
     join_node = JoinOperator(
        combine_function=lambda model_out, cache_out: cache_out or model_out
     )
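Note: the join node above merges the cache branch and the model branch by keeping the first non-empty output, i.e. it prefers a cache hit and falls back to the freshly generated model result. A minimal standalone sketch of that merge rule (plain Python, independent of the DB-GPT operator API; names are illustrative only):

    def combine(model_out, cache_out):
        # Prefer the cached result; fall back to the model output when the cache is empty or None.
        return cache_out or model_out

    assert combine("fresh answer", None) == "fresh answer"
    assert combine("fresh answer", "cached answer") == "cached answer"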
@@ -3,28 +3,10 @@ Adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/ser
 For benchmarks.

 """
-import abc
 import gc
-import json
-import math
-import os
-import sys
-import time
-from typing import Iterable, Optional, Dict, TYPE_CHECKING
-import warnings
-
-import psutil
+from typing import Iterable, Dict
 import torch
-from transformers import (
-    AutoTokenizer,
-    AutoModelForCausalLM,
-    LlamaTokenizer,
-    LlamaForCausalLM,
-    AutoModel,
-    AutoModelForSeq2SeqLM,
-    T5Tokenizer,
-    AutoConfig,
-)
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     RepetitionPenaltyLogitsProcessor,
@@ -33,18 +15,6 @@ from transformers.generation.logits_process import (
     TopPLogitsWarper,
 )

-from fastchat.conversation import get_conv_template, SeparatorStyle
-from fastchat.model.model_adapter import (
-    load_model,
-    get_conversation_template,
-    get_generate_stream_function,
-)
-from fastchat.modules.awq import AWQConfig
-from fastchat.modules.gptq import GptqConfig
-
-if TYPE_CHECKING:
-    from fastchat.modules.exllama import ExllamaConfig
-    from fastchat.modules.xfastertransformer import XftConfig
-
 from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length

@@ -324,242 +294,3 @@ def generate_stream(
         torch.xpu.empty_cache()
     if device == "npu":
         torch.npu.empty_cache()
-
-
-class ChatIO(abc.ABC):
-    @abc.abstractmethod
-    def prompt_for_input(self, role: str) -> str:
-        """Prompt for input from a role."""
-
-    @abc.abstractmethod
-    def prompt_for_output(self, role: str):
-        """Prompt for output from a role."""
-
-    @abc.abstractmethod
-    def stream_output(self, output_stream):
-        """Stream output."""
-
-    @abc.abstractmethod
-    def print_output(self, text: str):
-        """Print output."""
-
-
-def chat_loop(
-    model_path: str,
-    device: str,
-    num_gpus: int,
-    max_gpu_memory: str,
-    dtype: Optional[torch.dtype],
-    load_8bit: bool,
-    cpu_offloading: bool,
-    conv_template: Optional[str],
-    conv_system_msg: Optional[str],
-    temperature: float,
-    repetition_penalty: float,
-    max_new_tokens: int,
-    chatio: ChatIO,
-    gptq_config: Optional[GptqConfig] = None,
-    awq_config: Optional[AWQConfig] = None,
-    exllama_config: Optional["ExllamaConfig"] = None,
-    xft_config: Optional["XftConfig"] = None,
-    revision: str = "main",
-    judge_sent_end: bool = True,
-    debug: bool = True,
-    history: bool = True,
-):
-    # Model
-    model, tokenizer = load_model(
-        model_path,
-        device=device,
-        num_gpus=num_gpus,
-        max_gpu_memory=max_gpu_memory,
-        dtype=dtype,
-        load_8bit=load_8bit,
-        cpu_offloading=cpu_offloading,
-        gptq_config=gptq_config,
-        awq_config=awq_config,
-        exllama_config=exllama_config,
-        xft_config=xft_config,
-        revision=revision,
-        debug=debug,
-    )
-    generate_stream_func = get_generate_stream_function(model, model_path)
-
-    model_type = str(type(model)).lower()
-    is_t5 = "t5" in model_type
-    is_codet5p = "codet5p" in model_type
-    is_xft = "xft" in model_type
-
-    # Hardcode T5's default repetition penalty to be 1.2
-    if is_t5 and repetition_penalty == 1.0:
-        repetition_penalty = 1.2
-
-    # Set context length
-    context_len = get_context_length(model.config)
-
-    # Chat
-    def new_chat():
-        if conv_template:
-            conv = get_conv_template(conv_template)
-        else:
-            conv = get_conversation_template(model_path)
-        if conv_system_msg is not None:
-            conv.set_system_message(conv_system_msg)
-        return conv
-
-    def reload_conv(conv):
-        """
-        Reprints the conversation from the start.
-        """
-        for message in conv.messages[conv.offset :]:
-            chatio.prompt_for_output(message[0])
-            chatio.print_output(message[1])
-
-    conv = None
-
-    while True:
-        if not history or not conv:
-            conv = new_chat()
-
-        try:
-            inp = chatio.prompt_for_input(conv.roles[0])
-        except EOFError:
-            inp = ""
-
-        if inp == "!!exit" or not inp:
-            print("exit...")
-            break
-        elif inp == "!!reset":
-            print("resetting...")
-            conv = new_chat()
-            continue
-        elif inp == "!!remove":
-            print("removing last message...")
-            if len(conv.messages) > conv.offset:
-                # Assistant
-                if conv.messages[-1][0] == conv.roles[1]:
-                    conv.messages.pop()
-                # User
-                if conv.messages[-1][0] == conv.roles[0]:
-                    conv.messages.pop()
-                reload_conv(conv)
-            else:
-                print("No messages to remove.")
-            continue
-        elif inp == "!!regen":
-            print("regenerating last message...")
-            if len(conv.messages) > conv.offset:
-                # Assistant
-                if conv.messages[-1][0] == conv.roles[1]:
-                    conv.messages.pop()
-                # User
-                if conv.messages[-1][0] == conv.roles[0]:
-                    reload_conv(conv)
-                    # Set inp to previous message
-                    inp = conv.messages.pop()[1]
-                else:
-                    # Shouldn't happen in normal circumstances
-                    print("No user message to regenerate from.")
-                    continue
-            else:
-                print("No messages to regenerate.")
-                continue
-        elif inp.startswith("!!save"):
-            args = inp.split(" ", 1)
-
-            if len(args) != 2:
-                print("usage: !!save <filename>")
-                continue
-            else:
-                filename = args[1]
-
-            # Add .json if extension not present
-            if not "." in filename:
-                filename += ".json"
-
-            print("saving...", filename)
-            with open(filename, "w") as outfile:
-                json.dump(conv.dict(), outfile)
-            continue
-        elif inp.startswith("!!load"):
-            args = inp.split(" ", 1)
-
-            if len(args) != 2:
-                print("usage: !!load <filename>")
-                continue
-            else:
-                filename = args[1]
-
-            # Check if file exists and add .json if needed
-            if not os.path.exists(filename):
-                if (not filename.endswith(".json")) and os.path.exists(
-                    filename + ".json"
-                ):
-                    filename += ".json"
-                else:
-                    print("file not found:", filename)
-                    continue
-
-            print("loading...", filename)
-            with open(filename, "r") as infile:
-                new_conv = json.load(infile)
-
-            conv = get_conv_template(new_conv["template_name"])
-            conv.set_system_message(new_conv["system_message"])
-            conv.messages = new_conv["messages"]
-            reload_conv(conv)
-            continue
-
-        conv.append_message(conv.roles[0], inp)
-        conv.append_message(conv.roles[1], None)
-        prompt = conv.get_prompt()
-
-        if is_codet5p:  # codet5p is a code completion model.
-            prompt = inp
-
-        gen_params = {
-            "model": model_path,
-            "prompt": prompt,
-            "temperature": temperature,
-            "repetition_penalty": repetition_penalty,
-            "max_new_tokens": max_new_tokens,
-            "stop": conv.stop_str,
-            "stop_token_ids": conv.stop_token_ids,
-            "echo": False,
-        }
-
-        try:
-            chatio.prompt_for_output(conv.roles[1])
-            output_stream = generate_stream_func(
-                model,
-                tokenizer,
-                gen_params,
-                device,
-                context_len=context_len,
-                judge_sent_end=judge_sent_end,
-            )
-            t = time.time()
-            outputs = chatio.stream_output(output_stream)
-            duration = time.time() - t
-            conv.update_last_message(outputs.strip())
-
-            if debug:
-                num_tokens = len(tokenizer.encode(outputs))
-                msg = {
-                    "conv_template": conv.name,
-                    "prompt": prompt,
-                    "outputs": outputs,
-                    "speed (token/s)": round(num_tokens / duration, 2),
-                }
-                print(f"\n{msg}\n")
-
-        except KeyboardInterrupt:
-            print("stopped generation.")
-            # If generation didn't finish
-            if conv.messages[-1][1] is None:
-                conv.messages.pop()
-                # Remove last user message, so there isn't a double up
-                if conv.messages[-1][0] == conv.roles[0]:
-                    conv.messages.pop()
-
-                reload_conv(conv)
@@ -5,6 +5,8 @@ import sys
 import time
 import csv
 import argparse
+import logging
+import traceback
 from pilot.configs.model_config import ROOT_PATH, LLM_MODEL_CONFIG

 from pilot.model.cluster.worker.manager import (
@@ -19,14 +21,12 @@ from pilot.model.cluster import PromptRequest
 from pilot.scene.base_message import ModelMessage, ModelMessageRoleType


-# model_name = "chatglm2-6b"
-# model_name = "vicuna-7b-v1.5"
-model_name = "baichuan2-7b"
+model_name = "vicuna-7b-v1.5"
 model_path = LLM_MODEL_CONFIG[model_name]
 # or vllm
 model_type = "huggingface"

-controller_addr = "http://127.0.0.1:5005"
+controller_addr = "http://127.0.0.1:5000"

 result_csv_file = None

@@ -59,7 +59,7 @@ METRICS_HEADERS = [
     # Merge parallel result
     "test_time_cost_ms",
     "test_total_tokens",
-    "test_speed_per_second",
+    "test_speed_per_second",  # (tokens / s)
     # Detail for each task
     "start_time_ms",
     "end_time_ms",
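Note: the added comment clarifies that test_speed_per_second is measured in tokens per second. Assuming it is derived from the other merged metrics in this header list, the computation is presumably along these lines (illustrative only, not necessarily the script's exact code):

    def speed_per_second(test_total_tokens: int, test_time_cost_ms: float) -> float:
        # tokens / s = total generated tokens divided by elapsed wall time in seconds
        return test_total_tokens / (test_time_cost_ms / 1000.0)

    print(speed_per_second(1024, 8000.0))  # -> 128.0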
@@ -93,7 +93,7 @@ def build_param(
     )
     hist.append(ModelMessage(role=ModelMessageRoleType.HUMAN, content=user_input))
     hist = list(h.dict() for h in hist)
-    context_len = input_len + output_len
+    context_len = input_len + output_len + 2
     params = {
         "prompt": user_input,
         "messages": hist,
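Note: build_param now reserves two extra tokens on top of the prompt and output lengths, presumably as headroom for special tokens (e.g. BOS/EOS) so the requested output length is not clipped. A hypothetical sketch of that budget (the helper name and the reason for the +2 are assumptions, not confirmed by the commit):

    def required_context_len(input_len: int, output_len: int, headroom: int = 2) -> int:
        # Prompt tokens + generated tokens + a small allowance for special tokens.
        return input_len + output_len + headroom

    assert required_context_len(8, 256) == 266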
@@ -167,7 +167,15 @@ async def run_model(wh: WorkerManager) -> None:
         os.rename(result_csv_file, f"{result_csv_file}.bak.csv")
     for parallel_num in parallel_nums:
         for input_len, output_len in zip(input_lens, output_lens):
-            await run_batch(wh, input_len, output_len, parallel_num, result_csv_file)
+            try:
+                await run_batch(
+                    wh, input_len, output_len, parallel_num, result_csv_file
+                )
+            except Exception:
+                msg = traceback.format_exc()
+                logging.error(
+                    f"Run benchmarks error, input_len: {input_len}, output_len: {output_len}, parallel_num: {parallel_num}, error message: {msg}"
+                )

     sys.exit(0)

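Note: wrapping each run_batch call in try/except keeps the benchmark sweep going when one (input_len, output_len, parallel_num) combination fails (for example by exhausting GPU memory), logging the traceback instead of aborting the whole run. A self-contained sketch of the same isolation pattern (hypothetical names, not the DB-GPT worker API):

    import logging
    import traceback

    def run_sweep(configs, run_one):
        # Run every configuration; log failures and continue instead of aborting the sweep.
        for cfg in configs:
            try:
                run_one(cfg)
            except Exception:
                logging.error("benchmark config %s failed:\n%s", cfg, traceback.format_exc())

    run_sweep([(8, 256, 1), (1024, 1024, 32)], lambda cfg: print("run", cfg))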
@@ -184,7 +192,6 @@ def startup_llm_env():
         controller_addr=controller_addr,
         local_port=6000,
         start_listener=run_model,
-        # system_app=system_app,
     )


@@ -198,9 +205,9 @@ if __name__ == "__main__":
     parser.add_argument("--model_path", type=str, default=None)
     parser.add_argument("--model_type", type=str, default="huggingface")
     parser.add_argument("--result_csv_file", type=str, default=None)
-    parser.add_argument("--input_lens", type=str, default="64,64,64,512,1024,1024,2048")
+    parser.add_argument("--input_lens", type=str, default="8,8,256,1024")
     parser.add_argument(
-        "--output_lens", type=str, default="256,512,1024,1024,1024,2048,2048"
+        "--output_lens", type=str, default="256,512,1024,1024"
     )
     parser.add_argument("--parallel_nums", type=str, default="1,2,4,16,32")
     parser.add_argument(
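Note: --input_lens and --output_lens are comma-separated lists that the benchmark consumes pairwise (via zip, as seen in run_model above), so the new defaults describe four benchmark cases instead of seven, which is what shortens the run. A small sketch of how such defaults pair up (the parsing shown here is illustrative, not necessarily the script's exact code):

    input_lens = [int(x) for x in "8,8,256,1024".split(",")]
    output_lens = [int(x) for x in "256,512,1024,1024".split(",")]
    for input_len, output_len in zip(input_lens, output_lens):
        print(f"benchmark case: input={input_len}, output={output_len}")
    # -> (8, 256), (8, 512), (256, 1024), (1024, 1024)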
@@ -225,8 +232,10 @@ if __name__ == "__main__":
         raise ValueError("input_lens size must equal output_lens size")

     if remote_model:
+        # Connect to remote model and run benchmarks
         connect_to_remote_model()
     else:
+        # Start worker manager and run benchmarks
         run_worker_manager(
             model_name=model_name,
             model_path=model_path,