mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-08 11:47:44 +00:00
Model server: fix message model
This commit is contained in:
parent
98a7bf93f6
commit
ab40a68c8f
@ -10,33 +10,6 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
|
||||
print(params)
|
||||
stop = params.get("stop", "###")
|
||||
prompt = params["prompt"]
|
||||
messages = prompt.split(stop)
|
||||
#
|
||||
# # Add history conversation
|
||||
# hist = []
|
||||
# once_conversation = []
|
||||
# for message in messages[:-1]:
|
||||
# if len(message) <= 0:
|
||||
# continue
|
||||
#
|
||||
# if "human:" in message:
|
||||
# once_conversation.append(f"""###system:{message.split("human:")[1]} """ )
|
||||
# elif "system:" in message:
|
||||
# once_conversation.append(f"""###system:{message.split("system:")[1]} """)
|
||||
# elif "ai:" in message:
|
||||
# once_conversation.append(f"""###system:{message.split("ai:")[1]} """)
|
||||
# last_conversation = copy.deepcopy(once_conversation)
|
||||
# hist.append("".join(last_conversation))
|
||||
# once_conversation = []
|
||||
# else:
|
||||
# once_conversation.append(f"""###system:{message} """)
|
||||
#
|
||||
#
|
||||
#
|
||||
#
|
||||
#
|
||||
# query = "".join(hist)
|
||||
|
||||
query = prompt
|
||||
print("Query Message: ", query)
|
||||
|
||||
@ -66,8 +39,8 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
|
||||
)
|
||||
|
||||
|
||||
t1 = Thread(target=model.generate, kwargs=generate_kwargs)
|
||||
t1.start()
|
||||
# t1 = Thread(target=model.generate, kwargs=generate_kwargs)
|
||||
# t1.start()
|
||||
|
||||
generator = model.generate(**generate_kwargs)
|
||||
for output in generator:
|
||||
|
@ -52,9 +52,10 @@ class BaseOutputParser(ABC):
|
||||
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
|
||||
"""
|
||||
if data["error_code"] == 0:
|
||||
if CFG.LLM_MODEL in ["vicuna-13b", "guanaco"]:
|
||||
|
||||
output = data["text"][skip_echo_len:].strip()
|
||||
if "vicuna" in CFG.LLM_MODEL:
|
||||
output = data["text"][skip_echo_len + 11:].strip()
|
||||
elif "guanaco" in CFG.LLM_MODEL:
|
||||
output = data["text"][skip_echo_len + 14:].replace("<s>", "").strip()
|
||||
else:
|
||||
output = data["text"].strip()
|
||||
|
||||
|
@ -129,7 +129,7 @@ class BaseChat(ABC):
|
||||
def stream_call(self):
|
||||
payload = self.__call_base()
|
||||
|
||||
self.skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 11
|
||||
self.skip_echo_len = len(payload.get('prompt').replace("</s>", " "))
|
||||
logger.info(f"Requert: \n{payload}")
|
||||
ai_response_text = ""
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user