diff --git a/pilot/model/llm_out/guanaco_llm.py b/pilot/model/llm_out/guanaco_llm.py
index 9b8008702..1a2d1ae8b 100644
--- a/pilot/model/llm_out/guanaco_llm.py
+++ b/pilot/model/llm_out/guanaco_llm.py
@@ -76,7 +76,7 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
     streamer = TextIteratorStreamer(
         tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
     )
-    
+
    tokenizer.bos_token_id = 1
     stop_token_ids = [0]
 
@@ -102,10 +102,9 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
         stopping_criteria=StoppingCriteriaList([stop]),
     )
 
 
-    model.generate(**generate_kwargs)
 
     out = ""
     for new_text in streamer:
         out += new_text
-        yield out
\ No newline at end of file
+        yield out
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index d1dee2e37..909023f07 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -56,7 +56,6 @@ class BaseOutputParser(ABC):
             # output = data["text"][skip_echo_len + 11:].strip()
             output = data["text"][skip_echo_len:].strip()
         elif "guanaco" in CFG.LLM_MODEL:
-
             # NO stream output
             # output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()
 
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 8db61d09f..4dec22655 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -91,7 +91,7 @@ class GuanacoChatAdapter(BaseChatAdpter):
         return "guanaco" in model_path
 
     def get_generate_stream_func(self):
-        from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream 
+        from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream
 
         return guanaco_generate_stream
 