Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-05 10:29:36 +00:00)
commit 2a2382a902
parent 8afee1070e

update
@@ -71,23 +71,12 @@ def generate_stream(model, tokenizer, params, device,
 
 @torch.inference_mode()
 def generate_output(model, tokenizer, params, device, context_len=4096, stream_interval=2):
     """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
 
     prompt = params["prompt"]
     l_prompt = len(prompt)
     temperature = float(params.get("temperature", 1.0))
     max_new_tokens = int(params.get("max_new_tokens", 2048))
-    stop_parameter = params.get("stop", None)
-    if stop_parameter == tokenizer.eos_token:
-        stop_parameter = None
-    stop_strings = []
-    if isinstance(stop_parameter, str):
-        stop_strings.append(stop_parameter)
-    elif isinstance(stop_parameter, list):
-        stop_strings = stop_parameter
-    elif stop_parameter is None:
-        pass
-    else:
-        raise TypeError("Stop parameter must be string or list of strings.")
+    stop_str = params.get("stop", None)
 
     input_ids = tokenizer(prompt).input_ids
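Note on the hunk above: the str-or-list "stop" normalization is replaced by FastChat's single stop_str convention, which the function later searches only past the prompt. A minimal, self-contained sketch of that trimming behavior (values are illustrative, not from this commit):

# Illustrative sketch only: the stop string is searched starting at
# l_prompt, so a stop marker inside the prompt never truncates output.
prompt = "USER: reply and finish with ###\nASSISTANT:"
l_prompt = len(prompt)
stop_str = "###"

output = prompt + " DB-GPT answers here ### trailing text"

pos = output.rfind(stop_str, l_prompt)  # search only the generated part
if pos != -1:
    output = output[:pos]

print(output)  # ends with " DB-GPT answers here "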
@@ -132,28 +121,17 @@ def generate_output(model, tokenizer, params, device, context_len=4096, stream_interval=2):
 
         stopped = False
 
         if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
             output = tokenizer.decode(output_ids, skip_special_tokens=True)
-            print("Partial output:", output)
-            for stop_str in stop_strings:
-                # print(f"Looking for '{stop_str}' in '{output[:l_prompt]}'#END")
-                pos = output.rfind(stop_str)
-                if pos != -1:
-                    # print("Found stop str: ", output)
-                    output = output[:pos]
-                    # print("Trimmed output: ", output)
-                    stopped = True
-                    stop_word = stop_str
-                    break
-                else:
-                    pass
-                    # print("Not found")
+            pos = output.rfind(stop_str, l_prompt)
+            if pos != -1:
+                output = output[:pos]
+                stopped = True
+                return output
         if stopped:
             break
 
     del past_key_values
-    if pos != -1:
-        return output[:pos]
-    return output
 
 
 @torch.inference_mode()
 def generate_output_ex(model, tokenizer, params, device, context_len=2048, stream_interval=2):
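Note: after this hunk the stop check only runs on iterations that decode, and the function returns from inside the loop once the stop string is found (the trailing returns are removed). A rough, standalone sketch of which steps hit the decode-and-check branch, with illustrative values:

# Rough sketch (illustrative, not repository code): decoding the full
# sequence is O(len(output_ids)), so the check runs every
# stream_interval steps plus the final step rather than on every step.
max_new_tokens = 10
stream_interval = 2

decode_steps = [
    i for i in range(max_new_tokens)
    if i % stream_interval == 0 or i == max_new_tokens - 1
]
print(decode_steps)  # [0, 2, 4, 6, 8, 9]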
@@ -94,17 +94,22 @@ async def api_generate_stream(request: Request):
 
 
 @app.post("/generate")
 def generate(prompt_request: PromptRequest):
-    print(prompt_request)
     params = {
         "prompt": prompt_request.prompt,
         "temperature": prompt_request.temperature,
         "max_new_tokens": prompt_request.max_new_tokens,
         "stop": prompt_request.stop
     }
-    print("Receive prompt: ", params["prompt"])
-    output = generate_output(model, tokenizer, params, DEVICE)
-    print("Output: ", output)
-    return {"response": output}
+
+    response = []
+    output = generate_stream_gate(params)
+    for o in output:
+        print(o)
+        response.append(o)
+
+    rsp = "".join(response)
+    print("rsp:",rsp)
+    return {"response": rsp}
 
 
 @app.post("/embedding")
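For reference, a hypothetical client call against the reworked endpoint (host, port, and request values are assumptions, not part of this commit); the handler now drains generate_stream_gate and joins the chunks before responding:

import requests  # assumes the requests package is installed

resp = requests.post(
    "http://127.0.0.1:8000/generate",  # assumed server address
    json={
        "prompt": "What is DB-GPT?",
        "temperature": 0.7,
        "max_new_tokens": 256,
        "stop": "###",
    },
)
print(resp.json()["response"])  # the joined stream chunks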