diff --git a/pilot/model/llm_out/guanaco_llm.py b/pilot/model/llm_out/guanaco_llm.py index 5b24e69ec..9b8008702 100644 --- a/pilot/model/llm_out/guanaco_llm.py +++ b/pilot/model/llm_out/guanaco_llm.py @@ -64,6 +64,9 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048): print(params) stop = params.get("stop", "###") prompt = params["prompt"] + max_new_tokens = params.get("max_new_tokens", 512) + temperature = params.get("temperature", 1.0) + query = prompt print("Query Message: ", query) @@ -82,7 +85,7 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048): self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs ) -> bool: for stop_id in stop_token_ids: - if input_ids[0][-1] == stop_id: + if input_ids[-1][-1] == stop_id: return True return False @@ -90,8 +93,8 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048): generate_kwargs = dict( input_ids=input_ids, - max_new_tokens=512, - temperature=1.0, + max_new_tokens=max_new_tokens, + temperature=temperature, do_sample=True, top_k=1, streamer=streamer, @@ -100,9 +103,9 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048): ) - generator = model.generate(**generate_kwargs) + model.generate(**generate_kwargs) + out = "" for new_text in streamer: out += new_text - yield new_text - return out \ No newline at end of file + yield out \ No newline at end of file diff --git a/pilot/model/llm_out/proxy_llm.py b/pilot/model/llm_out/proxy_llm.py index 92887cfc6..68512ec3c 100644 --- a/pilot/model/llm_out/proxy_llm.py +++ b/pilot/model/llm_out/proxy_llm.py @@ -68,15 +68,11 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048) "max_tokens": params.get("max_new_tokens"), } - print(payloads) - print(headers) res = requests.post( CFG.proxy_server_url, headers=headers, json=payloads, stream=True ) text = "" - print("====================================res================") - 
print(res) for line in res.iter_lines(): if line: decoded_line = line.decode("utf-8") diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index bb2d0b2b2..d1dee2e37 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -56,8 +56,12 @@ class BaseOutputParser(ABC): # output = data["text"][skip_echo_len + 11:].strip() output = data["text"][skip_echo_len:].strip() elif "guanaco" in CFG.LLM_MODEL: - # output = data["text"][skip_echo_len + 14:].replace("", "").strip() - output = data["text"][skip_echo_len + 2:].replace("", "").strip() + + # NO stream output + # output = data["text"][skip_echo_len + 2:].replace("", "").strip() + + # stream out output + output = data["text"][11:].replace("", "").strip() else: output = data["text"].strip()