mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-01 08:11:45 +00:00)

commit 2a2382a902 (parent 8afee1070e)
commit message: update
@@ -71,23 +71,12 @@ def generate_stream(model, tokenizer, params, device,

@torch.inference_mode()
def generate_output(model, tokenizer, params, device, context_len=4096, stream_interval=2):
    """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py"""

    prompt = params["prompt"]
    l_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 2048))
    stop_parameter = params.get("stop", None)

    if stop_parameter == tokenizer.eos_token:
        stop_parameter = None
    stop_strings = []
    if isinstance(stop_parameter, str):
        stop_strings.append(stop_parameter)
    elif isinstance(stop_parameter, list):
        stop_strings = stop_parameter
    elif stop_parameter is None:
        pass
    else:
        raise TypeError("Stop parameter must be string or list of strings.")
    stop_str = params.get("stop", None)

    input_ids = tokenizer(prompt).input_ids
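The hunk above normalizes the request's "stop" parameter into a list of stop strings before generation begins. A minimal sketch of that normalization as a standalone helper is shown below; the helper name normalize_stop and the example calls are illustrative assumptions, not part of the commit.

from typing import List, Optional, Union

def normalize_stop(stop_parameter: Union[str, List[str], None],
                   eos_token: Optional[str]) -> List[str]:
    # Treat the tokenizer's EOS token as "no explicit stop string",
    # mirroring the check against tokenizer.eos_token in the hunk above.
    if stop_parameter == eos_token:
        stop_parameter = None

    stop_strings: List[str] = []
    if isinstance(stop_parameter, str):
        stop_strings.append(stop_parameter)
    elif isinstance(stop_parameter, list):
        stop_strings = stop_parameter
    elif stop_parameter is None:
        pass
    else:
        raise TypeError("Stop parameter must be string or list of strings.")
    return stop_strings

# Example: a single string becomes a one-element list.
assert normalize_stop("###", "</s>") == ["###"]
# Example: passing the EOS token itself yields no extra stop strings.
assert normalize_stop("</s>", "</s>") == []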
@@ -132,28 +121,17 @@ def generate_output(model, tokenizer, params, device, context_len=4096, stream_i

            stopped = False

        output = tokenizer.decode(output_ids, skip_special_tokens=True)
        print("Partial output:", output)
        for stop_str in stop_strings:
            # print(f"Looking for '{stop_str}' in '{output[:l_prompt]}'#END")
            pos = output.rfind(stop_str)
        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            output = tokenizer.decode(output_ids, skip_special_tokens=True)
            pos = output.rfind(stop_str, l_prompt)
            if pos != -1:
                # print("Found stop str: ", output)
                output = output[:pos]
                # print("Trimmed output: ", output)
                stopped = True
                stop_word = stop_str
                break
            else:
                pass
                # print("Not found")
    return output

        if stopped:
            break

    del past_key_values
    if pos != -1:
        return output[:pos]
    return output


@torch.inference_mode()
def generate_output_ex(model, tokenizer, params, device, context_len=2048, stream_interval=2):
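The trimming step in this hunk periodically decodes the accumulated token ids, looks for the stop string past the prompt (str.rfind with a start offset), and truncates the text there. A minimal sketch of that step on plain strings follows; the prompt, output, and stop string values are made-up examples, not taken from the commit.

# Sketch of stop-string trimming, assuming the decoded text starts with the prompt.
prompt = "Q: name three colors\nA:"
l_prompt = len(prompt)
stop_str = "\nQ:"

# Stand-in for tokenizer.decode(output_ids, skip_special_tokens=True).
output = prompt + " red, green, blue\nQ: name three animals"

# Search only past the prompt so a stop string inside the prompt
# itself does not truncate the answer.
pos = output.rfind(stop_str, l_prompt)
if pos != -1:
    output = output[:pos]      # keep everything before the stop string
    stopped = True             # signals the generation loop to break

print(output)  # -> "Q: name three colors\nA: red, green, blue"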
@@ -94,18 +94,23 @@ async def api_generate_stream(request: Request):


@app.post("/generate")
def generate(prompt_request: PromptRequest):
    print(prompt_request)
    params = {
        "prompt": prompt_request.prompt,
        "temperature": prompt_request.temperature,
        "max_new_tokens": prompt_request.max_new_tokens,
        "stop": prompt_request.stop
    }
    print("Receive prompt: ", params["prompt"])
    output = generate_output(model, tokenizer, params, DEVICE)
    print("Output: ", output)
    return {"response": output}

    response = []
    output = generate_stream_gate(params)
    for o in output:
        print(o)
        response.append(o)

    rsp = "".join(response)
    print("rsp:", rsp)
    return {"response": rsp}


@app.post("/embedding")
def embeddings(prompt_request: EmbeddingRequest):
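With this change, /generate builds a params dict from the incoming PromptRequest and returns the whole completion in one JSON response instead of joining streamed chunks. A minimal client-side sketch follows; the host, port, and concrete field values are assumptions for illustration, not taken from the commit.

import requests

# Assumed server address; adjust to wherever this FastAPI app is actually served.
URL = "http://127.0.0.1:8000/generate"

payload = {
    "prompt": "What is DB-GPT?",   # illustrative prompt
    "temperature": 0.7,
    "max_new_tokens": 256,
    "stop": "###",                 # optional stop string
}

resp = requests.post(URL, json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["response"])     # the generated text returned by generate_output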