chore: Modify benchmarks to run fewer cases

FangYin Cheng 2023-11-18 23:20:39 +08:00
parent 4ffd054a2a
commit d6318c21ee
3 changed files with 21 additions and 281 deletions

File 1 of 3

@@ -635,7 +635,7 @@ def _build_model_operator(
         model_task_name="llm_model_node",
         cache_task_name="llm_model_cache_node",
     )
-    # Create a join node to merge outputs from the model and cache nodes, just keep the fist not empty output
+    # Create a join node to merge outputs from the model and cache nodes, just keep the first not empty output
     join_node = JoinOperator(
         combine_function=lambda model_out, cache_out: cache_out or model_out
     )
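The combine_function above keeps the cached result when it is non-empty and otherwise falls back to the live model output. A minimal standalone sketch of that behavior (variable names are illustrative only, not taken from the repository):

    combine = lambda model_out, cache_out: cache_out or model_out
    print(combine("model answer", None))      # cache miss -> "model answer"
    print(combine("model answer", "cached"))  # cache hit  -> "cached"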

File 2 of 3

@@ -3,28 +3,10 @@ Adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/ser
 For benchmarks.
 """
-import abc
 import gc
-import json
-import math
-import os
-import sys
-import time
-from typing import Iterable, Optional, Dict, TYPE_CHECKING
-import warnings
-import psutil
+from typing import Iterable, Dict
 import torch
-from transformers import (
-    AutoTokenizer,
-    AutoModelForCausalLM,
-    LlamaTokenizer,
-    LlamaForCausalLM,
-    AutoModel,
-    AutoModelForSeq2SeqLM,
-    T5Tokenizer,
-    AutoConfig,
-)
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     RepetitionPenaltyLogitsProcessor,
@@ -33,18 +15,6 @@ from transformers.generation.logits_process import (
     TopPLogitsWarper,
 )
-from fastchat.conversation import get_conv_template, SeparatorStyle
-from fastchat.model.model_adapter import (
-    load_model,
-    get_conversation_template,
-    get_generate_stream_function,
-)
-from fastchat.modules.awq import AWQConfig
-from fastchat.modules.gptq import GptqConfig
-if TYPE_CHECKING:
-    from fastchat.modules.exllama import ExllamaConfig
-    from fastchat.modules.xfastertransformer import XftConfig
 from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length
@@ -324,242 +294,3 @@ def generate_stream(
         torch.xpu.empty_cache()
     if device == "npu":
         torch.npu.empty_cache()
-
-
-class ChatIO(abc.ABC):
-    @abc.abstractmethod
-    def prompt_for_input(self, role: str) -> str:
-        """Prompt for input from a role."""
-
-    @abc.abstractmethod
-    def prompt_for_output(self, role: str):
-        """Prompt for output from a role."""
-
-    @abc.abstractmethod
-    def stream_output(self, output_stream):
-        """Stream output."""
-
-    @abc.abstractmethod
-    def print_output(self, text: str):
-        """Print output."""
-
-
-def chat_loop(
-    model_path: str,
-    device: str,
-    num_gpus: int,
-    max_gpu_memory: str,
-    dtype: Optional[torch.dtype],
-    load_8bit: bool,
-    cpu_offloading: bool,
-    conv_template: Optional[str],
-    conv_system_msg: Optional[str],
-    temperature: float,
-    repetition_penalty: float,
-    max_new_tokens: int,
-    chatio: ChatIO,
-    gptq_config: Optional[GptqConfig] = None,
-    awq_config: Optional[AWQConfig] = None,
-    exllama_config: Optional["ExllamaConfig"] = None,
-    xft_config: Optional["XftConfig"] = None,
-    revision: str = "main",
-    judge_sent_end: bool = True,
-    debug: bool = True,
-    history: bool = True,
-):
-    # Model
-    model, tokenizer = load_model(
-        model_path,
-        device=device,
-        num_gpus=num_gpus,
-        max_gpu_memory=max_gpu_memory,
-        dtype=dtype,
-        load_8bit=load_8bit,
-        cpu_offloading=cpu_offloading,
-        gptq_config=gptq_config,
-        awq_config=awq_config,
-        exllama_config=exllama_config,
-        xft_config=xft_config,
-        revision=revision,
-        debug=debug,
-    )
-    generate_stream_func = get_generate_stream_function(model, model_path)
-
-    model_type = str(type(model)).lower()
-    is_t5 = "t5" in model_type
-    is_codet5p = "codet5p" in model_type
-    is_xft = "xft" in model_type
-
-    # Hardcode T5's default repetition penalty to be 1.2
-    if is_t5 and repetition_penalty == 1.0:
-        repetition_penalty = 1.2
-
-    # Set context length
-    context_len = get_context_length(model.config)
-
-    # Chat
-    def new_chat():
-        if conv_template:
-            conv = get_conv_template(conv_template)
-        else:
-            conv = get_conversation_template(model_path)
-        if conv_system_msg is not None:
-            conv.set_system_message(conv_system_msg)
-        return conv
-
-    def reload_conv(conv):
-        """
-        Reprints the conversation from the start.
-        """
-        for message in conv.messages[conv.offset :]:
-            chatio.prompt_for_output(message[0])
-            chatio.print_output(message[1])
-
-    conv = None
-
-    while True:
-        if not history or not conv:
-            conv = new_chat()
-
-        try:
-            inp = chatio.prompt_for_input(conv.roles[0])
-        except EOFError:
-            inp = ""
-
-        if inp == "!!exit" or not inp:
-            print("exit...")
-            break
-        elif inp == "!!reset":
-            print("resetting...")
-            conv = new_chat()
-            continue
-        elif inp == "!!remove":
-            print("removing last message...")
-            if len(conv.messages) > conv.offset:
-                # Assistant
-                if conv.messages[-1][0] == conv.roles[1]:
-                    conv.messages.pop()
-                # User
-                if conv.messages[-1][0] == conv.roles[0]:
-                    conv.messages.pop()
-                reload_conv(conv)
-            else:
-                print("No messages to remove.")
-            continue
-        elif inp == "!!regen":
-            print("regenerating last message...")
-            if len(conv.messages) > conv.offset:
-                # Assistant
-                if conv.messages[-1][0] == conv.roles[1]:
-                    conv.messages.pop()
-                # User
-                if conv.messages[-1][0] == conv.roles[0]:
-                    reload_conv(conv)
-                    # Set inp to previous message
-                    inp = conv.messages.pop()[1]
-                else:
-                    # Shouldn't happen in normal circumstances
-                    print("No user message to regenerate from.")
-                    continue
-            else:
-                print("No messages to regenerate.")
-                continue
-        elif inp.startswith("!!save"):
-            args = inp.split(" ", 1)
-
-            if len(args) != 2:
-                print("usage: !!save <filename>")
-                continue
-            else:
-                filename = args[1]
-
-            # Add .json if extension not present
-            if not "." in filename:
-                filename += ".json"
-
-            print("saving...", filename)
-            with open(filename, "w") as outfile:
-                json.dump(conv.dict(), outfile)
-            continue
-        elif inp.startswith("!!load"):
-            args = inp.split(" ", 1)
-
-            if len(args) != 2:
-                print("usage: !!load <filename>")
-                continue
-            else:
-                filename = args[1]
-
-            # Check if file exists and add .json if needed
-            if not os.path.exists(filename):
-                if (not filename.endswith(".json")) and os.path.exists(
-                    filename + ".json"
-                ):
-                    filename += ".json"
-                else:
-                    print("file not found:", filename)
-                    continue
-
-            print("loading...", filename)
-            with open(filename, "r") as infile:
-                new_conv = json.load(infile)
-
-            conv = get_conv_template(new_conv["template_name"])
-            conv.set_system_message(new_conv["system_message"])
-            conv.messages = new_conv["messages"]
-            reload_conv(conv)
-            continue
-
-        conv.append_message(conv.roles[0], inp)
-        conv.append_message(conv.roles[1], None)
-        prompt = conv.get_prompt()
-
-        if is_codet5p:  # codet5p is a code completion model.
-            prompt = inp
-
-        gen_params = {
-            "model": model_path,
-            "prompt": prompt,
-            "temperature": temperature,
-            "repetition_penalty": repetition_penalty,
-            "max_new_tokens": max_new_tokens,
-            "stop": conv.stop_str,
-            "stop_token_ids": conv.stop_token_ids,
-            "echo": False,
-        }
-
-        try:
-            chatio.prompt_for_output(conv.roles[1])
-            output_stream = generate_stream_func(
-                model,
-                tokenizer,
-                gen_params,
-                device,
-                context_len=context_len,
-                judge_sent_end=judge_sent_end,
-            )
-            t = time.time()
-            outputs = chatio.stream_output(output_stream)
-            duration = time.time() - t
-            conv.update_last_message(outputs.strip())
-
-            if debug:
-                num_tokens = len(tokenizer.encode(outputs))
-                msg = {
-                    "conv_template": conv.name,
-                    "prompt": prompt,
-                    "outputs": outputs,
-                    "speed (token/s)": round(num_tokens / duration, 2),
-                }
-                print(f"\n{msg}\n")
-        except KeyboardInterrupt:
-            print("stopped generation.")
-            # If generation didn't finish
-            if conv.messages[-1][1] is None:
-                conv.messages.pop()
-                # Remove last user message, so there isn't a double up
-                if conv.messages[-1][0] == conv.roles[0]:
-                    conv.messages.pop()
-
-                reload_conv(conv)

File 3 of 3

@@ -5,6 +5,8 @@ import sys
 import time
 import csv
 import argparse
+import logging
+import traceback
 from pilot.configs.model_config import ROOT_PATH, LLM_MODEL_CONFIG
 from pilot.model.cluster.worker.manager import (
@@ -19,14 +21,12 @@ from pilot.model.cluster import PromptRequest
 from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
-# model_name = "chatglm2-6b"
-# model_name = "vicuna-7b-v1.5"
-model_name = "baichuan2-7b"
+model_name = "vicuna-7b-v1.5"
 model_path = LLM_MODEL_CONFIG[model_name]
 # or vllm
 model_type = "huggingface"
-controller_addr = "http://127.0.0.1:5005"
+controller_addr = "http://127.0.0.1:5000"
 result_csv_file = None
@@ -59,7 +59,7 @@ METRICS_HEADERS = [
     # Merge parallel result
    "test_time_cost_ms",
    "test_total_tokens",
-    "test_speed_per_second",
+    "test_speed_per_second",  # (tokens / s)
     # Detail for each task
    "start_time_ms",
    "end_time_ms",
@@ -93,7 +93,7 @@ def build_param(
     )
     hist.append(ModelMessage(role=ModelMessageRoleType.HUMAN, content=user_input))
     hist = list(h.dict() for h in hist)
-    context_len = input_len + output_len
+    context_len = input_len + output_len + 2
     params = {
         "prompt": user_input,
         "messages": hist,
@@ -167,7 +167,15 @@ async def run_model(wh: WorkerManager) -> None:
         os.rename(result_csv_file, f"{result_csv_file}.bak.csv")
     for parallel_num in parallel_nums:
         for input_len, output_len in zip(input_lens, output_lens):
-            await run_batch(wh, input_len, output_len, parallel_num, result_csv_file)
+            try:
+                await run_batch(
+                    wh, input_len, output_len, parallel_num, result_csv_file
+                )
+            except Exception:
+                msg = traceback.format_exc()
+                logging.error(
+                    f"Run benchmarks error, input_len: {input_len}, output_len: {output_len}, parallel_num: {parallel_num}, error message: {msg}"
+                )
     sys.exit(0)
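Wrapping run_batch in try/except means one failing (input_len, output_len, parallel_num) combination no longer aborts the whole sweep; the error is logged with its traceback and the loop moves on. A self-contained sketch of the same pattern (function and variable names are hypothetical):

    import logging
    import traceback

    def run_all(cases, run_one):
        # Run every benchmark case; log failures and keep going.
        for case in cases:
            try:
                run_one(case)
            except Exception:
                logging.error("case %s failed: %s", case, traceback.format_exc())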
@@ -184,7 +192,6 @@ def startup_llm_env():
         controller_addr=controller_addr,
         local_port=6000,
         start_listener=run_model,
-        # system_app=system_app,
     )
@@ -198,9 +205,9 @@ if __name__ == "__main__":
     parser.add_argument("--model_path", type=str, default=None)
     parser.add_argument("--model_type", type=str, default="huggingface")
     parser.add_argument("--result_csv_file", type=str, default=None)
-    parser.add_argument("--input_lens", type=str, default="64,64,64,512,1024,1024,2048")
+    parser.add_argument("--input_lens", type=str, default="8,8,256,1024")
     parser.add_argument(
-        "--output_lens", type=str, default="256,512,1024,1024,1024,2048,2048"
+        "--output_lens", type=str, default="256,512,1024,1024"
     )
     parser.add_argument("--parallel_nums", type=str, default="1,2,4,16,32")
     parser.add_argument(
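The new defaults shrink the benchmark matrix from seven (input_len, output_len) pairs to four, which is what the commit title refers to. The two options are comma-separated lists that must have the same length, as the check in the next hunk enforces; a hedged sketch of how such values are typically parsed (an assumption, not copied from the script):

    input_lens = [int(x) for x in args.input_lens.split(",")]    # e.g. [8, 8, 256, 1024]
    output_lens = [int(x) for x in args.output_lens.split(",")]  # e.g. [256, 512, 1024, 1024]
    if len(input_lens) != len(output_lens):
        raise ValueError("input_lens size must equal output_lens size")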
@@ -225,8 +232,10 @@ if __name__ == "__main__":
         raise ValueError("input_lens size must equal output_lens size")
     if remote_model:
+        # Connect to remote model and run benchmarks
         connect_to_remote_model()
     else:
+        # Start worker manager and run benchmarks
         run_worker_manager(
             model_name=model_name,
             model_path=model_path,