mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-08 11:47:44 +00:00)

commit ab40a68c8f  (parent 98a7bf93f6)

    model server fix message model
@@ -10,33 +10,6 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     print(params)
     stop = params.get("stop", "###")
     prompt = params["prompt"]
-    messages = prompt.split(stop)
-    #
-    # # Add history conversation
-    # hist = []
-    # once_conversation = []
-    # for message in messages[:-1]:
-    #     if len(message) <= 0:
-    #         continue
-    #
-    #     if "human:" in message:
-    #         once_conversation.append(f"""###system:{message.split("human:")[1]} """ )
-    #     elif "system:" in message:
-    #         once_conversation.append(f"""###system:{message.split("system:")[1]} """)
-    #     elif "ai:" in message:
-    #         once_conversation.append(f"""###system:{message.split("ai:")[1]} """)
-    #         last_conversation = copy.deepcopy(once_conversation)
-    #         hist.append("".join(last_conversation))
-    #         once_conversation = []
-    #     else:
-    #         once_conversation.append(f"""###system:{message} """)
-    #
-    #
-    #
-    #
-    #
-    # query = "".join(hist)
-
     query = prompt
     print("Query Message: ", query)
 
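The removed block was dead code: a commented-out pass that rebuilt chat history by splitting the prompt on the stop token and re-tagging each turn, after which the commit simply forwards the raw prompt as the query. A minimal sketch (plain Python, not from the commit) of what that splitting would have produced:

    stop = "###"
    prompt = "###human: Hello ###ai: Hi there ###human: What is DB-GPT?"
    messages = prompt.split(stop)
    # -> ['', 'human: Hello ', 'ai: Hi there ', 'human: What is DB-GPT?']
    print([m for m in messages if m])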
@@ -66,8 +39,8 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     )
 
 
-    t1 = Thread(target=model.generate, kwargs=generate_kwargs)
-    t1.start()
+    # t1 = Thread(target=model.generate, kwargs=generate_kwargs)
+    # t1.start()
 
     generator = model.generate(**generate_kwargs)
     for output in generator:
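The two commented-out lines match the usual Hugging Face streaming pattern, where model.generate runs on a worker thread and a streamer yields decoded text on the caller's thread; the commit instead assumes model.generate itself returns an iterable of partial outputs and iterates it directly. A hedged, self-contained sketch of the threaded variant (the model name is a stand-in, not from the commit):

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")   # stand-in model
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("Hello", return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generate_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=32)

    t1 = Thread(target=model.generate, kwargs=generate_kwargs)
    t1.start()
    for new_text in streamer:          # partial text arrives as it is decoded
        print(new_text, end="", flush=True)
    t1.join()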
@@ -52,9 +52,10 @@ class BaseOutputParser(ABC):
         """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
         """
         if data["error_code"] == 0:
-            if CFG.LLM_MODEL in ["vicuna-13b", "guanaco"]:
-                output = data["text"][skip_echo_len:].strip()
+            if "vicuna" in CFG.LLM_MODEL:
+                output = data["text"][skip_echo_len + 11:].strip()
+            elif "guanaco" in CFG.LLM_MODEL:
+                output = data["text"][skip_echo_len + 14:].replace("<s>", "").strip()
             else:
                 output = data["text"].strip()
 
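The echo-stripping offset now differs per model: 11 extra characters for vicuna, 14 for guanaco (plus removal of the "<s>" token); read together with the stream_call hunk below, the constant that used to be folded into skip_echo_len now lives here. Illustrative sketch only; the marker string is an assumption, not taken from the commit:

    # Suppose the server echoes the prompt followed by an 11-character marker
    # before the actual answer; slicing at skip_echo_len + 11 drops both.
    prompt = "what is DB-GPT?"
    marker = "###Answer: "            # hypothetical marker, len(marker) == 11
    text = prompt + marker + "a database-focused GPT project"
    skip_echo_len = len(prompt)
    print(text[skip_echo_len + 11:].strip())   # -> "a database-focused GPT project"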
@@ -129,7 +129,7 @@ class BaseChat(ABC):
     def stream_call(self):
         payload = self.__call_base()
 
-        self.skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 11
+        self.skip_echo_len = len(payload.get('prompt').replace("</s>", " "))
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
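This hunk is the other half of the parser change above: the flat + 11 is dropped from skip_echo_len in the chat layer, so the length covers only the cleaned prompt, and any model-specific offset is applied once, in BaseOutputParser. A minimal sketch of the bookkeeping as it stands after the commit:

    prompt = "hello</s>world"
    skip_echo_len = len(prompt.replace("</s>", " "))   # len("hello world") == 11
    print(skip_echo_len)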