fix: guanaco stream out (#154)
This commit is contained in: parent d1a6222cee, commit 5252d8ed50
@@ -34,7 +34,7 @@ LLM_MODEL_CONFIG = {
     "chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"),
     "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
     "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
-    "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"
+    "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "proxyllm": "proxyllm",
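Context note (not part of the diff): LLM_MODEL_CONFIG maps a model name to its on-disk checkpoint path, so a missing closing parenthesis in this dict breaks the config module at import time. Below is a minimal sketch of how such a mapping is typically consumed; MODEL_PATH, the entries shown, and resolve_model_path are illustrative placeholders, not DB-GPT's actual loader code.

import os

# Illustrative stand-ins; the real values live in DB-GPT's model config module.
MODEL_PATH = "./models"
LLM_MODEL_CONFIG = {
    "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
    "proxyllm": "proxyllm",
}


def resolve_model_path(model_name: str) -> str:
    """Look up the local checkpoint path registered for a model name."""
    try:
        return LLM_MODEL_CONFIG[model_name]
    except KeyError:
        raise ValueError(f"unknown model name: {model_name}") from None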
@@ -1,55 +0,0 @@
-import torch
-from threading import Thread
-from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
-
-
-
-def guanaco_stream_generate_output(model, tokenizer, params, device, context_len=2048):
-    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
-    tokenizer.bos_token_id = 1
-    print(params)
-    stop = params.get("stop", "###")
-    prompt = params["prompt"]
-    query = prompt
-    print("Query Message: ", query)
-
-    input_ids = tokenizer(query, return_tensors="pt").input_ids
-    input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(
-        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
-    )
-
-    tokenizer.bos_token_id = 1
-    stop_token_ids = [0]
-
-    class StopOnTokens(StoppingCriteria):
-        def __call__(
-            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
-        ) -> bool:
-            for stop_id in stop_token_ids:
-                if input_ids[0][-1] == stop_id:
-                    return True
-            return False
-
-    stop = StopOnTokens()
-
-    generate_kwargs = dict(
-        input_ids=input_ids,
-        max_new_tokens=512,
-        temperature=1.0,
-        do_sample=True,
-        top_k=1,
-        streamer=streamer,
-        repetition_penalty=1.7,
-        stopping_criteria=StoppingCriteriaList([stop]),
-    )
-
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    out = ""
-    for new_text in streamer:
-        out += new_text
-        yield new_text
-    return out
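The deleted helper above follows the standard transformers streaming pattern: model.generate() runs on a worker thread and pushes decoded tokens into a TextIteratorStreamer, which the caller iterates to receive text incrementally. The following is a minimal, self-contained sketch of that pattern; the function name, sampling parameters, and the example checkpoint are placeholders rather than the project's defaults.

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


def stream_generate(model, tokenizer, prompt: str, max_new_tokens: int = 256):
    """Yield decoded text chunks as the model produces them."""
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    # The streamer receives token ids from generate() and decodes them lazily.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )

    # generate() blocks until completion, so run it on a worker thread while
    # the caller consumes the stream.
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            streamer=streamer,
        ),
    )
    thread.start()

    for text in streamer:
        yield text
    thread.join()


# Example usage (any causal LM checkpoint works; "gpt2" is just a small stand-in):
# tokenizer = AutoTokenizer.from_pretrained("gpt2")
# model = AutoModelForCausalLM.from_pretrained("gpt2")
# for chunk in stream_generate(model, tokenizer, "Hello"):
#     print(chunk, end="", flush=True)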
@@ -59,17 +59,6 @@ class ChatGLMChatAdapter(BaseChatAdpter):
 
         return chatglm_generate_stream
 
 
-class GuanacoChatAdapter(BaseChatAdpter):
-    """Model chat adapter for Guanaco"""
-
-    def match(self, model_path: str):
-        return "guanaco" in model_path
-
-    def get_generate_stream_func(self):
-        from pilot.model.llm_out.guanaco_stream_llm import guanaco_stream_generate_output
-
-        return guanaco_generate_output
-
 class CodeT5ChatAdapter(BaseChatAdpter):
     """Model chat adapter for CodeT5"""
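One likely reason this adapter was dropped: its get_generate_stream_func() imports guanaco_stream_generate_output but returns guanaco_generate_output, a name never bound in that scope, so selecting a Guanaco model would fail at runtime. Below is a sketch of how such an adapter is conventionally wired, with the returned callable matching the import; the module path in the import is an assumption for illustration, not necessarily where the project's working implementation lives.

class GuanacoChatAdapter(BaseChatAdpter):
    """Model chat adapter for Guanaco (illustrative sketch only)."""

    def match(self, model_path: str):
        # Select this adapter when the configured model path mentions guanaco.
        return "guanaco" in model_path

    def get_generate_stream_func(self):
        # Return exactly the callable that was imported; the removed code
        # imported one name and returned a different, undefined one.
        # NOTE: the module path below is assumed for illustration.
        from pilot.model.llm_out.guanaco_llm import guanaco_generate_output

        return guanaco_generate_output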