mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-09 04:08:10 +00:00
295 lines
10 KiB
Python
295 lines
10 KiB
Python
"""
Adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py.

For benchmarks.
"""
|
|
import gc
|
|
from typing import Dict, Iterable
|
|
|
|
import torch
|
|
from fastchat.utils import get_context_length, is_partial_stop, is_sentence_complete
|
|
from transformers.generation.logits_process import (
|
|
LogitsProcessorList,
|
|
RepetitionPenaltyLogitsProcessor,
|
|
TemperatureLogitsWarper,
|
|
TopKLogitsWarper,
|
|
TopPLogitsWarper,
|
|
)
|
|
|
|
|
|
def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    """Assemble the logits processors implied by the sampling parameters.

    A processor is only added when its parameter would actually change the
    logits; the order of the resulting list is temperature, repetition
    penalty, top-p, top-k.

    Args:
        temperature: Sampling temperature. Skipped when below ``1e-5`` or
            exactly ``1.0``.
        repetition_penalty: Penalty factor. Only applied when greater than
            ``1.0``.
        top_p: Nucleus sampling threshold. Only applied when in
            ``[1e-8, 1.0)``.
        top_k: Top-k cutoff. Only applied when positive.

    Returns:
        A ``LogitsProcessorList`` (possibly empty) holding the selected
        processors.
    """
    # TemperatureLogitsWarper doesn't accept 0.0, and 1.0 makes it a no-op,
    # so both cases are skipped.
    use_temperature = temperature >= 1e-5 and temperature != 1.0
    use_repetition_penalty = repetition_penalty > 1.0
    use_top_p = top_p >= 1e-8 and top_p < 1.0
    use_top_k = top_k > 0

    processors = LogitsProcessorList()
    if use_temperature:
        processors.append(TemperatureLogitsWarper(temperature))
    if use_repetition_penalty:
        processors.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
    if use_top_p:
        processors.append(TopPLogitsWarper(top_p))
    if use_top_k:
        processors.append(TopKLogitsWarper(top_k))
    return processors
|
|
|
|
|
|
@torch.inference_mode()
def generate_stream(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 1,
    judge_sent_end: bool = False,
):
    """Stream generated text token by token (benchmark variant).

    This variant deliberately ignores the ``"stop"`` and ``"stop_token_ids"``
    entries of ``params`` so that generation always runs until
    ``max_new_tokens`` tokens have been produced.

    Args:
        model: The model to run. ``model.config.is_encoder_decoder`` selects
            between the encoder-decoder path (``model.encoder`` /
            ``model.decoder`` / ``model.lm_head``) and the plain call path.
            If the model has a ``device`` attribute it overrides ``device``.
        tokenizer: Callable returning an object with ``input_ids`` for the
            prompt, and providing ``decode``.
        params: Generation parameters. ``"prompt"`` is required. Optional
            keys: ``"temperature"`` (1.0), ``"repetition_penalty"`` (1.0),
            ``"top_p"`` (1.0), ``"top_k"`` (-1, disabled),
            ``"max_new_tokens"`` (256), ``"logprobs"`` (None) and
            ``"echo"`` (True).
        device: Device used for the input tensors when the model does not
            expose its own ``device``.
        context_len: Context window size used to truncate the prompt.
        stream_interval: An intermediate event is produced every
            ``stream_interval`` decoding steps.
        judge_sent_end: When true and generation stopped on an incomplete
            sentence, the stop token is replaced and decoding continues.
            With stopping disabled in this variant the branch is not
            expected to trigger.

    Yields:
        Dicts with the keys ``"text"``, ``"logprobs"``, ``"usage"``
        (``prompt_tokens``, ``completion_tokens``, ``total_tokens``) and
        ``"finish_reason"``. ``completion_tokens`` is the zero-based index of
        the current decoding step. Intermediate events carry
        ``finish_reason=None``; the last event carries ``"length"`` or
        ``"stop"``. When ``max_new_tokens <= 0`` only the final event is
        produced, with empty text and zero completion tokens.

    Raises:
        NotImplementedError: If ``logprobs`` is requested for an
            encoder-decoder model.
        ValueError: If a stop value is neither a string nor an iterable.
    """
    if hasattr(model, "device"):
        device = model.device
    # Normalise to the bare device type so that both plain strings ("mps",
    # "cuda:0") and torch.device objects match the backend checks below.
    device_type = getattr(device, "type", None) or str(device).split(":", 1)[0]

    # Read parameters
    prompt = params["prompt"]
    len_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", -1))  # -1 means disable
    max_new_tokens = int(params.get("max_new_tokens", 256))
    logprobs = params.get("logprobs", None)  # FIXME: Support logprobs>1.
    echo = bool(params.get("echo", True))

    # Don't stop generating until max_new_tokens is reached: the caller's
    # "stop" / "stop_token_ids" settings are intentionally ignored (and the
    # caller's list is left untouched).
    stop_str = None
    stop_token_ids = []

    logits_processor = prepare_logits_processor(
        temperature, repetition_penalty, top_p, top_k
    )
    input_ids = tokenizer(prompt).input_ids

    if model.config.is_encoder_decoder:
        max_src_len = context_len
    else:  # truncate
        max_src_len = context_len - max_new_tokens - 1

    # NOTE(review): if max_src_len is <= 0 this slice no longer keeps the
    # tail of the prompt; callers are expected to pass a large enough
    # context_len.
    input_ids = input_ids[-max_src_len:]
    output_ids = list(input_ids)
    input_echo_len = len(input_ids)

    if model.config.is_encoder_decoder:
        if logprobs is not None:  # FIXME: Support logprobs for encoder-decoder models.
            raise NotImplementedError
        encoder_output = model.encoder(
            input_ids=torch.as_tensor([input_ids], device=device)
        )[0]
        start_ids = torch.as_tensor(
            [[model.generation_config.decoder_start_token_id]],
            dtype=torch.int64,
            device=device,
        )
    else:
        start_ids = torch.as_tensor([input_ids], device=device)

    past_key_values = out = None
    token_logprobs = [None]  # The first token has no logprobs.
    sent_interrupt = False
    finish_reason = None

    # Defaults used by the final event when the loop body never runs
    # (max_new_tokens <= 0).
    i = 0
    stopped = False
    output = ""
    ret_logprobs = None

    try:
        for i in range(max_new_tokens):
            if i == 0:  # prefill
                if model.config.is_encoder_decoder:
                    out = model.decoder(
                        input_ids=start_ids,
                        encoder_hidden_states=encoder_output,
                        use_cache=True,
                    )
                    logits = model.lm_head(out[0])
                else:
                    out = model(input_ids=start_ids, use_cache=True)
                    logits = out.logits
                past_key_values = out.past_key_values

                if logprobs is not None:
                    # Prefill logprobs for the prompt.
                    shift_input_ids = start_ids[..., 1:].contiguous()
                    shift_logits = logits[..., :-1, :].contiguous()
                    shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
                    for label_id, logit in zip(
                        shift_input_ids[0].tolist(), shift_logits[0]
                    ):
                        token_logprobs.append(logit[label_id])
            else:  # decoding
                # After a sentence interrupt the whole sequence is fed again
                # without the cache; otherwise only the last token is fed.
                if sent_interrupt:
                    step_ids, step_cache = output_ids, None
                else:
                    step_ids, step_cache = [token], past_key_values
                step_input = torch.as_tensor([step_ids], device=device)
                sent_interrupt = False

                if model.config.is_encoder_decoder:
                    out = model.decoder(
                        input_ids=step_input,
                        encoder_hidden_states=encoder_output,
                        use_cache=True,
                        past_key_values=step_cache,
                    )
                    logits = model.lm_head(out[0])
                else:
                    out = model(
                        input_ids=step_input,
                        use_cache=True,
                        past_key_values=step_cache,
                    )
                    logits = out.logits
                past_key_values = out.past_key_values

            if logits_processor:
                if repetition_penalty > 1.0:
                    tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
                else:
                    tmp_output_ids = None
                last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
            else:
                last_token_logits = logits[0, -1, :]

            if device_type == "mps":
                # Switch to CPU to avoid some bugs in the mps backend.
                last_token_logits = last_token_logits.float().to("cpu")

            # Two candidates are drawn so that a second choice is available
            # for the sentence-completion patch below.
            if temperature < 1e-5 or top_p < 1e-8:  # greedy
                _, indices = torch.topk(last_token_logits, 2)
            else:
                probs = torch.softmax(last_token_logits, dim=-1)
                indices = torch.multinomial(probs, num_samples=2)
            tokens = [int(candidate) for candidate in indices.tolist()]
            token = tokens[0]
            output_ids.append(token)
            if logprobs is not None:
                # Cannot use last_token_logits because logprobs is based on raw logits.
                token_logprobs.append(
                    torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
                )

            stopped = token in stop_token_ids

            # Yield the output tokens
            if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
                if echo:
                    tmp_output_ids = output_ids
                    rfind_start = len_prompt
                else:
                    tmp_output_ids = output_ids[input_echo_len:]
                    rfind_start = 0

                output = tokenizer.decode(
                    tmp_output_ids,
                    skip_special_tokens=True,
                    spaces_between_special_tokens=False,
                    clean_up_tokenization_spaces=True,
                )
                ret_logprobs = None
                if logprobs is not None:
                    ret_logprobs = {
                        "text_offset": [],
                        "tokens": [
                            tokenizer.decode(token_id)
                            for token_id in (
                                output_ids if echo else output_ids[input_echo_len:]
                            )
                        ],
                        "token_logprobs": token_logprobs
                        if echo
                        else token_logprobs[input_echo_len:],
                        "top_logprobs": [{}]
                        * len(token_logprobs if echo else token_logprobs[input_echo_len:]),
                    }
                    # Compute text_offset
                    curr_pos = 0
                    for text in ret_logprobs["tokens"]:
                        ret_logprobs["text_offset"].append(curr_pos)
                        curr_pos += len(text)

                # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way
                if judge_sent_end and stopped and not is_sentence_complete(output):
                    if len(tokens) > 1:
                        token = tokens[1]
                        output_ids[-1] = token
                    else:
                        output_ids.pop()
                    stopped = False
                    sent_interrupt = True

                partially_stopped = False
                if stop_str:
                    if isinstance(stop_str, str):
                        pos = output.rfind(stop_str, rfind_start)
                        if pos != -1:
                            output = output[:pos]
                            stopped = True
                        else:
                            partially_stopped = is_partial_stop(output, stop_str)
                    elif isinstance(stop_str, Iterable):
                        for each_stop in stop_str:
                            pos = output.rfind(each_stop, rfind_start)
                            if pos != -1:
                                output = output[:pos]
                                stopped = True
                                break
                            else:
                                partially_stopped = is_partial_stop(output, each_stop)
                                if partially_stopped:
                                    break
                    else:
                        raise ValueError("Invalid stop field type.")

                # Prevent yielding partial stop sequence
                if not partially_stopped:
                    yield {
                        "text": output,
                        "logprobs": ret_logprobs,
                        "usage": {
                            "prompt_tokens": input_echo_len,
                            "completion_tokens": i,
                            "total_tokens": input_echo_len + i,
                        },
                        "finish_reason": None,
                    }

            if stopped:
                break

        # Finish stream event, which contains finish reason
        else:
            finish_reason = "length"

        if stopped:
            finish_reason = "stop"

        yield {
            "text": output,
            "logprobs": ret_logprobs,
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": i,
                "total_tokens": input_echo_len + i,
            },
            "finish_reason": finish_reason,
        }
    finally:
        # Clean up even when the consumer stops iterating early or a step
        # raises, so cached tensors are not kept alive.
        del past_key_values, out
        gc.collect()
        torch.cuda.empty_cache()
        if device_type == "xpu":
            torch.xpu.empty_cache()
        if device_type == "npu":
            torch.npu.empty_cache()
|