csunny 2023-05-13 22:50:37 +08:00
parent 8afee1070e
commit 2a2382a902
2 changed files with 17 additions and 34 deletions

View File

@@ -71,23 +71,12 @@ def generate_stream(model, tokenizer, params, device,
@torch.inference_mode()
def generate_output(model, tokenizer, params, device, context_len=4096, stream_interval=2):
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
prompt = params["prompt"]
l_prompt = len(prompt)
temperature = float(params.get("temperature", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 2048))
stop_parameter = params.get("stop", None)
if stop_parameter == tokenizer.eos_token:
stop_parameter = None
stop_strings = []
if isinstance(stop_parameter, str):
stop_strings.append(stop_parameter)
elif isinstance(stop_parameter, list):
stop_strings = stop_parameter
elif stop_parameter is None:
pass
else:
raise TypeError("Stop parameter must be string or list of strings.")
stop_str = params.get("stop", None)
input_ids = tokenizer(prompt).input_ids
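Note: the block removed in this hunk normalized the "stop" parameter (None, a single string, or a list of strings) into a stop_strings list, while the replacement keeps only a single stop_str. For reference, that normalization can be factored into a small helper; this is a sketch only, and normalize_stop is not a name used in the repository.

def normalize_stop(stop_parameter, eos_token):
    """Turn a stop parameter (None, str, or list of str) into a list of stop strings.

    Sketch of the logic removed in this commit; not part of the repository.
    """
    if stop_parameter == eos_token:
        stop_parameter = None
    if stop_parameter is None:
        return []
    if isinstance(stop_parameter, str):
        return [stop_parameter]
    if isinstance(stop_parameter, list):
        return list(stop_parameter)
    raise TypeError("Stop parameter must be string or list of strings.")

With such a helper, the removed branch chain would collapse to stop_strings = normalize_stop(params.get("stop", None), tokenizer.eos_token).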
@@ -132,28 +121,17 @@ def generate_output(model, tokenizer, params, device, context_len=4096, stream_i
            stopped = False
        output = tokenizer.decode(output_ids, skip_special_tokens=True)
        print("Partial output:", output)
        for stop_str in stop_strings:
            # print(f"Looking for '{stop_str}' in '{output[:l_prompt]}'#END")
            pos = output.rfind(stop_str)
        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            output = tokenizer.decode(output_ids, skip_special_tokens=True)
            pos = output.rfind(stop_str, l_prompt)
            if pos != -1:
                # print("Found stop str: ", output)
                output = output[:pos]
                # print("Trimmed output: ", output)
                stopped = True
                stop_word = stop_str
                break
            else:
                pass
                # print("Not found")
        return output
        if stopped:
            break
    del past_key_values
    if pos != -1:
        return output[:pos]
    return output
@torch.inference_mode()
def generate_output_ex(model, tokenizer, params, device, context_len=2048, stream_interval=2):
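In the new code path the decoded output apparently still contains the prompt text, so output.rfind(stop_str, l_prompt) starts the search after the prompt and the result is trimmed at the stop string once one is found. A minimal sketch of that truncation step follows; the helper name is illustrative and not taken from the repository.

def truncate_at_stop(output, stop_str, l_prompt):
    """Trim output at stop_str, searching only past the prompt prefix.

    Mirrors the output.rfind(stop_str, l_prompt) check above; this helper
    is a sketch and not part of the repository.
    Returns (possibly trimmed text, whether a stop string was hit).
    """
    if not stop_str:
        return output, False
    pos = output.rfind(stop_str, l_prompt)
    if pos == -1:
        return output, False
    return output[:pos], True

Inside the generation loop this check would run on every stream_interval step, and the loop breaks as soon as the second return value is True.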

View File

@@ -94,18 +94,23 @@ async def api_generate_stream(request: Request):
@app.post("/generate")
def generate(prompt_request: PromptRequest):
    print(prompt_request)
    params = {
        "prompt": prompt_request.prompt,
        "temperature": prompt_request.temperature,
        "max_new_tokens": prompt_request.max_new_tokens,
        "stop": prompt_request.stop
    }
    print("Receive prompt: ", params["prompt"])
    output = generate_output(model, tokenizer, params, DEVICE)
    print("Output: ", output)
    return {"response": output}
    response = []
    output = generate_stream_gate(params)
    for o in output:
        print(o)
        response.append(o)
    rsp = "".join(response)
    print("rsp:",rsp)
    return {"response": rsp}
@app.post("/embedding")
def embeddings(prompt_request: EmbeddingRequest):
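The reworked /generate handler now drains the streaming generator (generate_stream_gate) server-side and returns the joined chunks as one JSON response, so callers still receive a single reply. A minimal client sketch, assuming the server runs locally on port 8000 (the URL and the field values are assumptions; the field names match what the handler reads from PromptRequest above):

import requests

# Hypothetical local URL; adjust host/port to wherever the server is started.
URL = "http://127.0.0.1:8000/generate"

payload = {
    "prompt": "What is DB-GPT?",   # free-form prompt text
    "temperature": 0.7,            # sampling temperature
    "max_new_tokens": 256,         # generation budget
    "stop": "###",                 # stop string passed through to the generation loop
}

resp = requests.post(URL, json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["response"])     # the server joins the streamed chunks into one string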