chore: Modify benchmarks for less times

FangYin Cheng 2023-11-18 23:20:39 +08:00
parent 4ffd054a2a
commit d6318c21ee
3 changed files with 21 additions and 281 deletions

View File

@@ -635,7 +635,7 @@ def _build_model_operator(
model_task_name="llm_model_node",
cache_task_name="llm_model_cache_node",
)
# Create a join node to merge outputs from the model and cache nodes, just keep the fist not empty output
# Create a join node to merge outputs from the model and cache nodes, just keep the first not empty output
join_node = JoinOperator(
combine_function=lambda model_out, cache_out: cache_out or model_out
)
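The combine_function above relies on Python's "or" returning its first truthy operand, so a non-empty cache hit wins and the model output is only used as a fallback. A minimal sketch of that semantics (standalone and illustrative; the values are not from the repository):

def combine(model_out, cache_out):
    # Same behavior as the lambda passed to JoinOperator above:
    # return the cache output when it is non-empty, otherwise the model output.
    return cache_out or model_out

print(combine("fresh model output", "cached output"))  # -> cached output
print(combine("fresh model output", None))             # -> fresh model output
print(combine("fresh model output", ""))               # -> fresh model output (empty string is falsy)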

View File

@@ -3,28 +3,10 @@ Adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/ser
For benchmarks.
"""
import abc
import gc
import json
import math
import os
import sys
import time
from typing import Iterable, Optional, Dict, TYPE_CHECKING
import warnings
from typing import Iterable, Dict
import psutil
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
LlamaTokenizer,
LlamaForCausalLM,
AutoModel,
AutoModelForSeq2SeqLM,
T5Tokenizer,
AutoConfig,
)
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
@@ -33,18 +15,6 @@ from transformers.generation.logits_process import (
TopPLogitsWarper,
)
from fastchat.conversation import get_conv_template, SeparatorStyle
from fastchat.model.model_adapter import (
load_model,
get_conversation_template,
get_generate_stream_function,
)
from fastchat.modules.awq import AWQConfig
from fastchat.modules.gptq import GptqConfig
if TYPE_CHECKING:
from fastchat.modules.exllama import ExllamaConfig
from fastchat.modules.xfastertransformer import XftConfig
from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length
@@ -324,242 +294,3 @@ def generate_stream(
torch.xpu.empty_cache()
if device == "npu":
torch.npu.empty_cache()
class ChatIO(abc.ABC):
@abc.abstractmethod
def prompt_for_input(self, role: str) -> str:
"""Prompt for input from a role."""
@abc.abstractmethod
def prompt_for_output(self, role: str):
"""Prompt for output from a role."""
@abc.abstractmethod
def stream_output(self, output_stream):
"""Stream output."""
@abc.abstractmethod
def print_output(self, text: str):
"""Print output."""
def chat_loop(
model_path: str,
device: str,
num_gpus: int,
max_gpu_memory: str,
dtype: Optional[torch.dtype],
load_8bit: bool,
cpu_offloading: bool,
conv_template: Optional[str],
conv_system_msg: Optional[str],
temperature: float,
repetition_penalty: float,
max_new_tokens: int,
chatio: ChatIO,
gptq_config: Optional[GptqConfig] = None,
awq_config: Optional[AWQConfig] = None,
exllama_config: Optional["ExllamaConfig"] = None,
xft_config: Optional["XftConfig"] = None,
revision: str = "main",
judge_sent_end: bool = True,
debug: bool = True,
history: bool = True,
):
# Model
model, tokenizer = load_model(
model_path,
device=device,
num_gpus=num_gpus,
max_gpu_memory=max_gpu_memory,
dtype=dtype,
load_8bit=load_8bit,
cpu_offloading=cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
exllama_config=exllama_config,
xft_config=xft_config,
revision=revision,
debug=debug,
)
generate_stream_func = get_generate_stream_function(model, model_path)
model_type = str(type(model)).lower()
is_t5 = "t5" in model_type
is_codet5p = "codet5p" in model_type
is_xft = "xft" in model_type
# Hardcode T5's default repetition penalty to be 1.2
if is_t5 and repetition_penalty == 1.0:
repetition_penalty = 1.2
# Set context length
context_len = get_context_length(model.config)
# Chat
def new_chat():
if conv_template:
conv = get_conv_template(conv_template)
else:
conv = get_conversation_template(model_path)
if conv_system_msg is not None:
conv.set_system_message(conv_system_msg)
return conv
def reload_conv(conv):
"""
Reprints the conversation from the start.
"""
for message in conv.messages[conv.offset :]:
chatio.prompt_for_output(message[0])
chatio.print_output(message[1])
conv = None
while True:
if not history or not conv:
conv = new_chat()
try:
inp = chatio.prompt_for_input(conv.roles[0])
except EOFError:
inp = ""
if inp == "!!exit" or not inp:
print("exit...")
break
elif inp == "!!reset":
print("resetting...")
conv = new_chat()
continue
elif inp == "!!remove":
print("removing last message...")
if len(conv.messages) > conv.offset:
# Assistant
if conv.messages[-1][0] == conv.roles[1]:
conv.messages.pop()
# User
if conv.messages[-1][0] == conv.roles[0]:
conv.messages.pop()
reload_conv(conv)
else:
print("No messages to remove.")
continue
elif inp == "!!regen":
print("regenerating last message...")
if len(conv.messages) > conv.offset:
# Assistant
if conv.messages[-1][0] == conv.roles[1]:
conv.messages.pop()
# User
if conv.messages[-1][0] == conv.roles[0]:
reload_conv(conv)
# Set inp to previous message
inp = conv.messages.pop()[1]
else:
# Shouldn't happen in normal circumstances
print("No user message to regenerate from.")
continue
else:
print("No messages to regenerate.")
continue
elif inp.startswith("!!save"):
args = inp.split(" ", 1)
if len(args) != 2:
print("usage: !!save <filename>")
continue
else:
filename = args[1]
# Add .json if extension not present
if not "." in filename:
filename += ".json"
print("saving...", filename)
with open(filename, "w") as outfile:
json.dump(conv.dict(), outfile)
continue
elif inp.startswith("!!load"):
args = inp.split(" ", 1)
if len(args) != 2:
print("usage: !!load <filename>")
continue
else:
filename = args[1]
# Check if file exists and add .json if needed
if not os.path.exists(filename):
if (not filename.endswith(".json")) and os.path.exists(
filename + ".json"
):
filename += ".json"
else:
print("file not found:", filename)
continue
print("loading...", filename)
with open(filename, "r") as infile:
new_conv = json.load(infile)
conv = get_conv_template(new_conv["template_name"])
conv.set_system_message(new_conv["system_message"])
conv.messages = new_conv["messages"]
reload_conv(conv)
continue
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
if is_codet5p: # codet5p is a code completion model.
prompt = inp
gen_params = {
"model": model_path,
"prompt": prompt,
"temperature": temperature,
"repetition_penalty": repetition_penalty,
"max_new_tokens": max_new_tokens,
"stop": conv.stop_str,
"stop_token_ids": conv.stop_token_ids,
"echo": False,
}
try:
chatio.prompt_for_output(conv.roles[1])
output_stream = generate_stream_func(
model,
tokenizer,
gen_params,
device,
context_len=context_len,
judge_sent_end=judge_sent_end,
)
t = time.time()
outputs = chatio.stream_output(output_stream)
duration = time.time() - t
conv.update_last_message(outputs.strip())
if debug:
num_tokens = len(tokenizer.encode(outputs))
msg = {
"conv_template": conv.name,
"prompt": prompt,
"outputs": outputs,
"speed (token/s)": round(num_tokens / duration, 2),
}
print(f"\n{msg}\n")
except KeyboardInterrupt:
print("stopped generation.")
# If generation didn't finish
if conv.messages[-1][1] is None:
conv.messages.pop()
# Remove last user message, so there isn't a double up
if conv.messages[-1][0] == conv.roles[0]:
conv.messages.pop()
reload_conv(conv)

View File

@@ -5,6 +5,8 @@ import sys
import time
import csv
import argparse
import logging
import traceback
from pilot.configs.model_config import ROOT_PATH, LLM_MODEL_CONFIG
from pilot.model.cluster.worker.manager import (
@@ -19,14 +21,12 @@ from pilot.model.cluster import PromptRequest
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
# model_name = "chatglm2-6b"
# model_name = "vicuna-7b-v1.5"
model_name = "baichuan2-7b"
model_name = "vicuna-7b-v1.5"
model_path = LLM_MODEL_CONFIG[model_name]
# or vllm
model_type = "huggingface"
controller_addr = "http://127.0.0.1:5005"
controller_addr = "http://127.0.0.1:5000"
result_csv_file = None
@@ -59,7 +59,7 @@ METRICS_HEADERS = [
# Merge parallel result
"test_time_cost_ms",
"test_total_tokens",
"test_speed_per_second",
"test_speed_per_second", # (tokens / s)
# Detail for each task
"start_time_ms",
"end_time_ms",
@@ -93,7 +93,7 @@ def build_param(
)
hist.append(ModelMessage(role=ModelMessageRoleType.HUMAN, content=user_input))
hist = list(h.dict() for h in hist)
context_len = input_len + output_len
context_len = input_len + output_len + 2
params = {
"prompt": user_input,
"messages": hist,
@@ -167,7 +167,15 @@ async def run_model(wh: WorkerManager) -> None:
os.rename(result_csv_file, f"{result_csv_file}.bak.csv")
for parallel_num in parallel_nums:
for input_len, output_len in zip(input_lens, output_lens):
await run_batch(wh, input_len, output_len, parallel_num, result_csv_file)
try:
await run_batch(
wh, input_len, output_len, parallel_num, result_csv_file
)
except Exception:
msg = traceback.format_exc()
logging.error(
f"Run benchmarks error, input_len: {input_len}, output_len: {output_len}, parallel_num: {parallel_num}, error message: {msg}"
)
sys.exit(0)
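run_batch itself is not shown in this hunk; if it appends one row per task to result_csv_file, a writer along the following lines would match the METRICS_HEADERS columns. This is a hypothetical sketch, not the repository's implementation:

import csv
import os

def append_metrics_row(result_csv_file: str, row: dict, headers: list) -> None:
    # Write the header line once, then append one metrics row per benchmark task.
    write_header = not os.path.exists(result_csv_file)
    with open(result_csv_file, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=headers)
        if write_header:
            writer.writeheader()
        writer.writerow(row)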
@@ -184,7 +192,6 @@ def startup_llm_env():
controller_addr=controller_addr,
local_port=6000,
start_listener=run_model,
# system_app=system_app,
)
@@ -198,9 +205,9 @@ if __name__ == "__main__":
parser.add_argument("--model_path", type=str, default=None)
parser.add_argument("--model_type", type=str, default="huggingface")
parser.add_argument("--result_csv_file", type=str, default=None)
parser.add_argument("--input_lens", type=str, default="64,64,64,512,1024,1024,2048")
parser.add_argument("--input_lens", type=str, default="8,8,256,1024")
parser.add_argument(
"--output_lens", type=str, default="256,512,1024,1024,1024,2048,2048"
"--output_lens", type=str, default="256,512,1024,1024"
)
parser.add_argument("--parallel_nums", type=str, default="1,2,4,16,32")
parser.add_argument(
@@ -225,8 +232,10 @@
raise ValueError("input_lens size must equal output_lens size")
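The length check above implies the comma-separated CLI defaults are split into integer lists before use. A hypothetical illustration of that parsing (the helper is not from the repository; the default strings are the ones shown earlier):

def parse_lens(value: str) -> list:
    return [int(x) for x in value.split(",")]

input_lens = parse_lens("8,8,256,1024")        # --input_lens default
output_lens = parse_lens("256,512,1024,1024")  # --output_lens default
assert len(input_lens) == len(output_lens)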
if remote_model:
# Connect to remote model and run benchmarks
connect_to_remote_model()
else:
# Start worker manager and run benchmarks
run_worker_manager(
model_name=model_name,
model_path=model_path,