feature: stream output for guanaco (#154)

This commit is contained in:
csunny 2023-06-04 21:47:21 +08:00
parent ff6cc05e11
commit e8a193ef46
3 changed files with 3 additions and 5 deletions

View File

@ -76,7 +76,7 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
streamer = TextIteratorStreamer( streamer = TextIteratorStreamer(
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
) )
tokenizer.bos_token_id = 1 tokenizer.bos_token_id = 1
stop_token_ids = [0] stop_token_ids = [0]
@ -102,10 +102,9 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
stopping_criteria=StoppingCriteriaList([stop]), stopping_criteria=StoppingCriteriaList([stop]),
) )
model.generate(**generate_kwargs) model.generate(**generate_kwargs)
out = "" out = ""
for new_text in streamer: for new_text in streamer:
out += new_text out += new_text
yield out yield out

View File

@ -56,7 +56,6 @@ class BaseOutputParser(ABC):
# output = data["text"][skip_echo_len + 11:].strip() # output = data["text"][skip_echo_len + 11:].strip()
output = data["text"][skip_echo_len:].strip() output = data["text"][skip_echo_len:].strip()
elif "guanaco" in CFG.LLM_MODEL: elif "guanaco" in CFG.LLM_MODEL:
# NO stream output # NO stream output
# output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip() # output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()

View File

@ -91,7 +91,7 @@ class GuanacoChatAdapter(BaseChatAdpter):
return "guanaco" in model_path return "guanaco" in model_path
def get_generate_stream_func(self): def get_generate_stream_func(self):
from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream
return guanaco_generate_stream return guanaco_generate_stream