mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-12 13:42:23 +00:00
feature: stream output for guanaco (#154)
This commit is contained in:
parent
ff6cc05e11
commit
e8a193ef46
@ -76,7 +76,7 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
|
||||
streamer = TextIteratorStreamer(
|
||||
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
|
||||
)
|
||||
|
||||
|
||||
tokenizer.bos_token_id = 1
|
||||
stop_token_ids = [0]
|
||||
|
||||
@ -102,10 +102,9 @@ def guanaco_generate_stream(model, tokenizer, params, device, context_len=2048):
|
||||
stopping_criteria=StoppingCriteriaList([stop]),
|
||||
)
|
||||
|
||||
|
||||
model.generate(**generate_kwargs)
|
||||
|
||||
out = ""
|
||||
for new_text in streamer:
|
||||
out += new_text
|
||||
yield out
|
||||
yield out
|
||||
|
@ -56,7 +56,6 @@ class BaseOutputParser(ABC):
|
||||
# output = data["text"][skip_echo_len + 11:].strip()
|
||||
output = data["text"][skip_echo_len:].strip()
|
||||
elif "guanaco" in CFG.LLM_MODEL:
|
||||
|
||||
# NO stream output
|
||||
# output = data["text"][skip_echo_len + 2:].replace("<s>", "").strip()
|
||||
|
||||
|
@ -91,7 +91,7 @@ class GuanacoChatAdapter(BaseChatAdpter):
|
||||
return "guanaco" in model_path
|
||||
|
||||
def get_generate_stream_func(self):
|
||||
from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream
|
||||
from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream
|
||||
|
||||
return guanaco_generate_stream
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user