Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-01 08:11:45 +00:00)

Commit: 4ffd054a2a
Parent: 1f22459fbe

fix(model): Fix benchmarks bugs

@@ -1,9 +1,9 @@
 import os
 import logging
-from typing import Dict, Iterator, List, Optional
 
+from typing import Dict, Iterator, List, Optional
 import time
-import copy
+import traceback
 
 from pilot.configs.model_config import get_device
 from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper

@@ -332,6 +332,8 @@ class DefaultModelWorker(ModelWorker):
                     text="**GPU OutOfMemory, Please Refresh.**", error_code=1
                 )
             else:
+                msg = traceback.format_exc()
+                logger.error(f"Model inference error, detail: {msg}")
                 model_output = ModelOutput(
                     text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                     error_code=1,

@@ -1,9 +1,13 @@
 from typing import Dict
+import os
 from vllm import AsyncLLMEngine
 from vllm.utils import random_uuid
 from vllm.sampling_params import SamplingParams
 
 
+_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
+
+
 async def generate_stream(
     model: AsyncLLMEngine, tokenizer, params: Dict, device: str, context_len: int
 ):

@@ -37,15 +41,29 @@ async def generate_stream(
     top_p = max(top_p, 1e-5)
     if temperature <= 1e-5:
         top_p = 1.0
+    gen_params = {
+        "stop": list(stop),
+        "ignore_eos": False,
+    }
+    prompt_token_ids = None
+    if _IS_BENCHMARK:
+        gen_params["stop"] = []
+        gen_params["ignore_eos"] = True
+        prompt_len = context_len - max_new_tokens - 2
+        prompt_token_ids = tokenizer([prompt]).input_ids[0]
+        prompt_token_ids = prompt_token_ids[-prompt_len:]
     sampling_params = SamplingParams(
         n=1,
         temperature=temperature,
         top_p=top_p,
         use_beam_search=False,
-        stop=list(stop),
         max_tokens=max_new_tokens,
+        **gen_params
+    )
+
+    results_generator = model.generate(
+        prompt, sampling_params, request_id, prompt_token_ids=prompt_token_ids
     )
-    results_generator = model.generate(prompt, sampling_params, request_id)
     async for request_output in results_generator:
         prompt = request_output.prompt
         if echo:

@@ -39,7 +39,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 thread_local = threading.local()
-
+_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
 
 _OLD_MODELS = [
     "llama-cpp",

@@ -228,6 +228,13 @@ class FastChatLLMModelAdaperWrapper(LLMModelAdaper):
         return self._adapter.load_model(model_path, from_pretrained_kwargs)
 
     def get_generate_stream_function(self, model: "TorchNNModule", model_path: str):
-        from fastchat.model.model_adapter import get_generate_stream_function
-
-        return get_generate_stream_function(model, model_path)
+        if _IS_BENCHMARK:
+            from pilot.utils.benchmarks.llm.fastchat_benchmarks_inference import (
+                generate_stream,
+            )
+
+            return generate_stream
+        else:
+            from fastchat.model.model_adapter import get_generate_stream_function
+
+            return get_generate_stream_function(model, model_path)

pilot/utils/benchmarks/llm/fastchat_benchmarks_inference.py (new file, 565 lines)
@@ -0,0 +1,565 @@

"""
Adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py.
For benchmarks.

"""
import abc
import gc
import json
import math
import os
import sys
import time
from typing import Iterable, Optional, Dict, TYPE_CHECKING
import warnings

import psutil
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LlamaTokenizer,
    LlamaForCausalLM,
    AutoModel,
    AutoModelForSeq2SeqLM,
    T5Tokenizer,
    AutoConfig,
)
from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

from fastchat.conversation import get_conv_template, SeparatorStyle
from fastchat.model.model_adapter import (
    load_model,
    get_conversation_template,
    get_generate_stream_function,
)
from fastchat.modules.awq import AWQConfig
from fastchat.modules.gptq import GptqConfig

if TYPE_CHECKING:
    from fastchat.modules.exllama import ExllamaConfig
    from fastchat.modules.xfastertransformer import XftConfig

from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length


def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    processor_list = LogitsProcessorList()
    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
    if temperature >= 1e-5 and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if repetition_penalty > 1.0:
        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
    if 1e-8 <= top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    if top_k > 0:
        processor_list.append(TopKLogitsWarper(top_k))
    return processor_list


@torch.inference_mode()
def generate_stream(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
    if hasattr(model, "device"):
        device = model.device

    # Read parameters
    prompt = params["prompt"]
    len_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", -1))  # -1 means disable
    max_new_tokens = int(params.get("max_new_tokens", 256))
    logprobs = params.get("logprobs", None)  # FIXME: Support logprobs>1.
    echo = bool(params.get("echo", True))
    stop_str = params.get("stop", None)
    stop_token_ids = params.get("stop_token_ids", None) or []
    if tokenizer.eos_token_id not in stop_token_ids:
        stop_token_ids.append(tokenizer.eos_token_id)

    logits_processor = prepare_logits_processor(
        temperature, repetition_penalty, top_p, top_k
    )
    input_ids = tokenizer(prompt).input_ids

    if model.config.is_encoder_decoder:
        max_src_len = context_len
    else:  # truncate
        max_src_len = context_len - max_new_tokens - 1

    input_ids = input_ids[-max_src_len:]
    output_ids = list(input_ids)
    input_echo_len = len(input_ids)

    # Don't stop generate until max_new_tokens is reached.
    stop_token_ids = []
    stop_str = None

    if model.config.is_encoder_decoder:
        if logprobs is not None:  # FIXME: Support logprobs for encoder-decoder models.
            raise NotImplementedError
        encoder_output = model.encoder(
            input_ids=torch.as_tensor([input_ids], device=device)
        )[0]
        start_ids = torch.as_tensor(
            [[model.generation_config.decoder_start_token_id]],
            dtype=torch.int64,
            device=device,
        )
    else:
        start_ids = torch.as_tensor([input_ids], device=device)

    past_key_values = out = None
    token_logprobs = [None]  # The first token has no logprobs.
    sent_interrupt = False
    finish_reason = None
    for i in range(max_new_tokens):
        if i == 0:  # prefill
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=start_ids,
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                )
                logits = model.lm_head(out[0])
            else:
                out = model(input_ids=start_ids, use_cache=True)
                logits = out.logits
            past_key_values = out.past_key_values

            if logprobs is not None:
                # Prefull logprobs for the prompt.
                shift_input_ids = start_ids[..., 1:].contiguous()
                shift_logits = logits[..., :-1, :].contiguous()
                shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
                for label_id, logit in zip(
                    shift_input_ids[0].tolist(), shift_logits[0]
                ):
                    token_logprobs.append(logit[label_id])
        else:  # decoding
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=torch.as_tensor(
                        [[token] if not sent_interrupt else output_ids],
                        device=device,
                    ),
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                    past_key_values=past_key_values if not sent_interrupt else None,
                )
                sent_interrupt = False

                logits = model.lm_head(out[0])
            else:
                out = model(
                    input_ids=torch.as_tensor(
                        [[token] if not sent_interrupt else output_ids],
                        device=device,
                    ),
                    use_cache=True,
                    past_key_values=past_key_values if not sent_interrupt else None,
                )
                sent_interrupt = False
                logits = out.logits
            past_key_values = out.past_key_values

        if logits_processor:
            if repetition_penalty > 1.0:
                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
            else:
                tmp_output_ids = None
            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
        else:
            last_token_logits = logits[0, -1, :]

        if device == "mps":
            # Switch to CPU by avoiding some bugs in mps backend.
            last_token_logits = last_token_logits.float().to("cpu")

        if temperature < 1e-5 or top_p < 1e-8:  # greedy
            _, indices = torch.topk(last_token_logits, 2)
            tokens = [int(index) for index in indices.tolist()]
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            indices = torch.multinomial(probs, num_samples=2)
            tokens = [int(token) for token in indices.tolist()]
        token = tokens[0]
        output_ids.append(token)
        if logprobs is not None:
            # Cannot use last_token_logits because logprobs is based on raw logits.
            token_logprobs.append(
                torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
            )

        if token in stop_token_ids:
            stopped = True
        else:
            stopped = False

        # Yield the output tokens
        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            if echo:
                tmp_output_ids = output_ids
                rfind_start = len_prompt
            else:
                tmp_output_ids = output_ids[input_echo_len:]
                rfind_start = 0

            output = tokenizer.decode(
                tmp_output_ids,
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=True,
            )
            ret_logprobs = None
            if logprobs is not None:
                ret_logprobs = {
                    "text_offset": [],
                    "tokens": [
                        tokenizer.decode(token)
                        for token in (
                            output_ids if echo else output_ids[input_echo_len:]
                        )
                    ],
                    "token_logprobs": token_logprobs
                    if echo
                    else token_logprobs[input_echo_len:],
                    "top_logprobs": [{}]
                    * len(token_logprobs if echo else token_logprobs[input_echo_len:]),
                }
                # Compute text_offset
                curr_pos = 0
                for text in ret_logprobs["tokens"]:
                    ret_logprobs["text_offset"].append(curr_pos)
                    curr_pos += len(text)

            # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way
            if judge_sent_end and stopped and not is_sentence_complete(output):
                if len(tokens) > 1:
                    token = tokens[1]
                    output_ids[-1] = token
                else:
                    output_ids.pop()
                stopped = False
                sent_interrupt = True

            partially_stopped = False
            if stop_str:
                if isinstance(stop_str, str):
                    pos = output.rfind(stop_str, rfind_start)
                    if pos != -1:
                        output = output[:pos]
                        stopped = True
                    else:
                        partially_stopped = is_partial_stop(output, stop_str)
                elif isinstance(stop_str, Iterable):
                    for each_stop in stop_str:
                        pos = output.rfind(each_stop, rfind_start)
                        if pos != -1:
                            output = output[:pos]
                            stopped = True
                            break
                        else:
                            partially_stopped = is_partial_stop(output, each_stop)
                            if partially_stopped:
                                break
                else:
                    raise ValueError("Invalid stop field type.")

            # Prevent yielding partial stop sequence
            if not partially_stopped:
                yield {
                    "text": output,
                    "logprobs": ret_logprobs,
                    "usage": {
                        "prompt_tokens": input_echo_len,
                        "completion_tokens": i,
                        "total_tokens": input_echo_len + i,
                    },
                    "finish_reason": None,
                }

        if stopped:
            break

    # Finish stream event, which contains finish reason
    else:
        finish_reason = "length"

    if stopped:
        finish_reason = "stop"

    yield {
        "text": output,
        "logprobs": ret_logprobs,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
        "finish_reason": finish_reason,
    }

    # Clean
    del past_key_values, out
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()


class ChatIO(abc.ABC):
    @abc.abstractmethod
    def prompt_for_input(self, role: str) -> str:
        """Prompt for input from a role."""

    @abc.abstractmethod
    def prompt_for_output(self, role: str):
        """Prompt for output from a role."""

    @abc.abstractmethod
    def stream_output(self, output_stream):
        """Stream output."""

    @abc.abstractmethod
    def print_output(self, text: str):
        """Print output."""


def chat_loop(
    model_path: str,
    device: str,
    num_gpus: int,
    max_gpu_memory: str,
    dtype: Optional[torch.dtype],
    load_8bit: bool,
    cpu_offloading: bool,
    conv_template: Optional[str],
    conv_system_msg: Optional[str],
    temperature: float,
    repetition_penalty: float,
    max_new_tokens: int,
    chatio: ChatIO,
    gptq_config: Optional[GptqConfig] = None,
    awq_config: Optional[AWQConfig] = None,
    exllama_config: Optional["ExllamaConfig"] = None,
    xft_config: Optional["XftConfig"] = None,
    revision: str = "main",
    judge_sent_end: bool = True,
    debug: bool = True,
    history: bool = True,
):
    # Model
    model, tokenizer = load_model(
        model_path,
        device=device,
        num_gpus=num_gpus,
        max_gpu_memory=max_gpu_memory,
        dtype=dtype,
        load_8bit=load_8bit,
        cpu_offloading=cpu_offloading,
        gptq_config=gptq_config,
        awq_config=awq_config,
        exllama_config=exllama_config,
        xft_config=xft_config,
        revision=revision,
        debug=debug,
    )
    generate_stream_func = get_generate_stream_function(model, model_path)

    model_type = str(type(model)).lower()
    is_t5 = "t5" in model_type
    is_codet5p = "codet5p" in model_type
    is_xft = "xft" in model_type

    # Hardcode T5's default repetition penalty to be 1.2
    if is_t5 and repetition_penalty == 1.0:
        repetition_penalty = 1.2

    # Set context length
    context_len = get_context_length(model.config)

    # Chat
    def new_chat():
        if conv_template:
            conv = get_conv_template(conv_template)
        else:
            conv = get_conversation_template(model_path)
        if conv_system_msg is not None:
            conv.set_system_message(conv_system_msg)
        return conv

    def reload_conv(conv):
        """
        Reprints the conversation from the start.
        """
        for message in conv.messages[conv.offset :]:
            chatio.prompt_for_output(message[0])
            chatio.print_output(message[1])

    conv = None

    while True:
        if not history or not conv:
            conv = new_chat()

        try:
            inp = chatio.prompt_for_input(conv.roles[0])
        except EOFError:
            inp = ""

        if inp == "!!exit" or not inp:
            print("exit...")
            break
        elif inp == "!!reset":
            print("resetting...")
            conv = new_chat()
            continue
        elif inp == "!!remove":
            print("removing last message...")
            if len(conv.messages) > conv.offset:
                # Assistant
                if conv.messages[-1][0] == conv.roles[1]:
                    conv.messages.pop()
                # User
                if conv.messages[-1][0] == conv.roles[0]:
                    conv.messages.pop()
                reload_conv(conv)
            else:
                print("No messages to remove.")
            continue
        elif inp == "!!regen":
            print("regenerating last message...")
            if len(conv.messages) > conv.offset:
                # Assistant
                if conv.messages[-1][0] == conv.roles[1]:
                    conv.messages.pop()
                # User
                if conv.messages[-1][0] == conv.roles[0]:
                    reload_conv(conv)
                    # Set inp to previous message
                    inp = conv.messages.pop()[1]
                else:
                    # Shouldn't happen in normal circumstances
                    print("No user message to regenerate from.")
                    continue
            else:
                print("No messages to regenerate.")
                continue
        elif inp.startswith("!!save"):
            args = inp.split(" ", 1)

            if len(args) != 2:
                print("usage: !!save <filename>")
                continue
            else:
                filename = args[1]

            # Add .json if extension not present
            if not "." in filename:
                filename += ".json"

            print("saving...", filename)
            with open(filename, "w") as outfile:
                json.dump(conv.dict(), outfile)
            continue
        elif inp.startswith("!!load"):
            args = inp.split(" ", 1)

            if len(args) != 2:
                print("usage: !!load <filename>")
                continue
            else:
                filename = args[1]

            # Check if file exists and add .json if needed
            if not os.path.exists(filename):
                if (not filename.endswith(".json")) and os.path.exists(
                    filename + ".json"
                ):
                    filename += ".json"
                else:
                    print("file not found:", filename)
                    continue

            print("loading...", filename)
            with open(filename, "r") as infile:
                new_conv = json.load(infile)

            conv = get_conv_template(new_conv["template_name"])
            conv.set_system_message(new_conv["system_message"])
            conv.messages = new_conv["messages"]
            reload_conv(conv)
            continue

        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        if is_codet5p:  # codet5p is a code completion model.
            prompt = inp

        gen_params = {
            "model": model_path,
            "prompt": prompt,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "max_new_tokens": max_new_tokens,
            "stop": conv.stop_str,
            "stop_token_ids": conv.stop_token_ids,
            "echo": False,
        }

        try:
            chatio.prompt_for_output(conv.roles[1])
            output_stream = generate_stream_func(
                model,
                tokenizer,
                gen_params,
                device,
                context_len=context_len,
                judge_sent_end=judge_sent_end,
            )
            t = time.time()
            outputs = chatio.stream_output(output_stream)
            duration = time.time() - t
            conv.update_last_message(outputs.strip())

            if debug:
                num_tokens = len(tokenizer.encode(outputs))
                msg = {
                    "conv_template": conv.name,
                    "prompt": prompt,
                    "outputs": outputs,
                    "speed (token/s)": round(num_tokens / duration, 2),
                }
                print(f"\n{msg}\n")

        except KeyboardInterrupt:
            print("stopped generation.")
            # If generation didn't finish
            if conv.messages[-1][1] is None:
                conv.messages.pop()
                # Remove last user message, so there isn't a double up
                if conv.messages[-1][0] == conv.roles[0]:
                    conv.messages.pop()

                reload_conv(conv)
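
For orientation, generate_stream above is a generator that yields dicts carrying text, logprobs, usage, and finish_reason. A minimal sketch of a caller, assuming a model and tokenizer already loaded through fastchat's load_model; the model path, device, and sampling values below are placeholders, not part of the commit:

    from fastchat.model.model_adapter import load_model
    from pilot.utils.benchmarks.llm.fastchat_benchmarks_inference import generate_stream

    # Placeholder model path and device; adjust for a real benchmark run.
    model, tokenizer = load_model("lmsys/vicuna-7b-v1.5", device="cuda", num_gpus=1)
    params = {"prompt": "Hello", "temperature": 0.7, "max_new_tokens": 32, "echo": False}
    last_chunk = None
    for chunk in generate_stream(model, tokenizer, params, "cuda", context_len=2048):
        last_chunk = chunk  # each chunk holds the full text decoded so far
    print(last_chunk["text"])
    print(last_chunk["usage"])  # prompt/completion/total token counts

Note that this benchmark variant clears stop_str and stop_token_ids internally, so a call always generates the full max_new_tokens.
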
@@ -40,15 +40,6 @@ def get_result_csv_file() -> str:
 )
 
 
-input_output_length_pair = [
-    [64, 256],
-    [64, 512],
-    [64, 1024],
-    [512, 1024],
-    [1024, 1024],
-    [1024, 2048],
-    [2048, 2048],
-]
 input_lens = [64, 64]
 output_lens = [256, 512]
 

@@ -96,8 +87,8 @@ def build_param(
     system_prompt: str = None,
 ) -> Dict:
     hist = []
-    if system_prompt:
-        hist.append()(
+    if system_prompt is not None:
+        hist.append(
             ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_prompt)
         )
     hist.append(ModelMessage(role=ModelMessageRoleType.HUMAN, content=user_input))

@@ -119,8 +110,15 @@ async def run_batch(
 ):
     tasks = []
     prompt = read_prompt_from_file("11k")
+    if model_type == "vllm":
+        max_input_str_len = input_len
+        if "baichuan" in model_name:
+            # TODO prompt handle first
+            max_input_str_len *= 2
+        prompt = prompt[-max_input_str_len:]
+
     for _ in range(parallel_num):
-        params = build_param(input_len, output_len, prompt)
+        params = build_param(input_len, output_len, prompt, system_prompt="")
         tasks.append(wh.generate(params))
     print(
         f"Begin run benchmarks, model name: {model_name}, input_len: {input_len}, output_len: {output_len}, parallel_num: {parallel_num}, save result to {output_file}"

@@ -136,6 +134,7 @@ async def run_batch(
         metrics = r.metrics
         if isinstance(metrics, dict):
            metrics = ModelInferenceMetrics(**metrics)
+        print(r)
         test_total_tokens += metrics.total_tokens
         row_data = metrics.to_dict()
         rows.append(row_data)

@@ -11,7 +11,7 @@ parallel_nums=${3:-$default_parallel_nums}
 run_benchmark() {
     local model_name=$1
     local model_type=$2
-    python pilot/utils/benchmarks/llm/llm_benchmarks.py --model_name ${model_name} --model_type ${model_type} --input_lens ${input_lens} --output_lens ${output_lens} --parallel_nums ${parallel_nums}
+    DB_GPT_MODEL_BENCHMARK=true python pilot/utils/benchmarks/llm/llm_benchmarks.py --model_name ${model_name} --model_type ${model_type} --input_lens ${input_lens} --output_lens ${output_lens} --parallel_nums ${parallel_nums}
 }
 
 run_benchmark "vicuna-7b-v1.5" "huggingface"