From 2a2382a9028c74f36330fb58fb82be281f771d46 Mon Sep 17 00:00:00 2001
From: csunny
Date: Sat, 13 May 2023 22:50:37 +0800
Subject: [PATCH] update

---
 pilot/model/inference.py  | 36 +++++++-----------------------------
 pilot/server/llmserver.py | 15 ++++++++++-----
 2 files changed, 17 insertions(+), 34 deletions(-)

diff --git a/pilot/model/inference.py b/pilot/model/inference.py
index 109fe7e98..a677c0339 100644
--- a/pilot/model/inference.py
+++ b/pilot/model/inference.py
@@ -71,23 +71,12 @@ def generate_stream(model, tokenizer, params, device,
 @torch.inference_mode()
 def generate_output(model, tokenizer, params, device, context_len=4096, stream_interval=2):
     """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """
+
     prompt = params["prompt"]
     l_prompt = len(prompt)
     temperature = float(params.get("temperature", 1.0))
     max_new_tokens = int(params.get("max_new_tokens", 2048))
-    stop_parameter = params.get("stop", None)
-
-    if stop_parameter == tokenizer.eos_token:
-        stop_parameter = None
-    stop_strings = []
-    if isinstance(stop_parameter, str):
-        stop_strings.append(stop_parameter)
-    elif isinstance(stop_parameter, list):
-        stop_strings = stop_parameter
-    elif stop_parameter is None:
-        pass
-    else:
-        raise TypeError("Stop parameter must be string or list of strings.")
+    stop_str = params.get("stop", None)
 
     input_ids = tokenizer(prompt).input_ids
 
@@ -132,28 +121,17 @@ def generate_output(model, tokenizer, params, device, context_len=4096, stream_i
             stopped = False
 
-        output = tokenizer.decode(output_ids, skip_special_tokens=True)
-        print("Partial output:", output)
-        for stop_str in stop_strings:
-            # print(f"Looking for '{stop_str}' in '{output[:l_prompt]}'#END")
-            pos = output.rfind(stop_str)
+        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
+            output = tokenizer.decode(output_ids, skip_special_tokens=True)
+            pos = output.rfind(stop_str, l_prompt)
             if pos != -1:
-                # print("Found stop str: ", output)
                 output = output[:pos]
-                # print("Trimmed output: ", output)
                 stopped = True
-                stop_word = stop_str
-                break
-            else:
-                pass
-                # print("Not found")
+            return output
 
+        if stopped:
             break
 
-    del past_key_values
-    if pos != -1:
-        return output[:pos]
-    return output
 
 
 @torch.inference_mode()
 def generate_output_ex(model, tokenizer, params, device, context_len=2048, stream_interval=2):
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 2920495c5..ba3df420d 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -94,18 +94,23 @@ async def api_generate_stream(request: Request):
 
 @app.post("/generate")
 def generate(prompt_request: PromptRequest):
-    print(prompt_request)
     params = {
         "prompt": prompt_request.prompt,
         "temperature": prompt_request.temperature,
        "max_new_tokens": prompt_request.max_new_tokens,
         "stop": prompt_request.stop
     }
 
-    print("Receive prompt: ", params["prompt"])
-    output = generate_output(model, tokenizer, params, DEVICE)
-    print("Output: ", output)
-    return {"response": output}
+    response = []
+    output = generate_stream_gate(params)
+    for o in output:
+        print(o)
+        response.append(o)
+
+    rsp = "".join(response)
+    print("rsp:",rsp)
+    return {"response": rsp}
+
 
 @app.post("/embedding")
 def embeddings(prompt_request: EmbeddingRequest):
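
Note: after this patch, POST /generate collects the chunks yielded by generate_stream_gate and returns them joined as a single "response" string. The following is a minimal client-side sketch of calling that endpoint; the host/port (localhost:8000), the helper name call_generate, and the sample parameter values are assumptions for illustration, not part of the patch. The request fields mirror the PromptRequest usage shown in llmserver.py.

import requests

def call_generate(prompt: str) -> str:
    # Field names follow the PromptRequest fields used in the patched handler:
    # prompt, temperature, max_new_tokens, stop. Values here are illustrative.
    payload = {
        "prompt": prompt,
        "temperature": 0.7,
        "max_new_tokens": 512,
        "stop": "###",
    }
    # The patched /generate endpoint returns {"response": "<joined chunks>"}.
    resp = requests.post("http://localhost:8000/generate", json=payload)
    resp.raise_for_status()
    return resp.json()["response"]

if __name__ == "__main__":
    print(call_generate("Hello, what can you do?"))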