model server fix message model

yhjun1026 2023-06-01 19:10:12 +08:00
parent 98a7bf93f6
commit ab40a68c8f
3 changed files with 7 additions and 33 deletions

View File

@@ -10,33 +10,6 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     print(params)
     stop = params.get("stop", "###")
     prompt = params["prompt"]
     messages = prompt.split(stop)
-    #
-    # # Add history conversation
-    # hist = []
-    # once_conversation = []
-    # for message in messages[:-1]:
-    #     if len(message) <= 0:
-    #         continue
-    #
-    #     if "human:" in message:
-    #         once_conversation.append(f"""###system:{message.split("human:")[1]} """)
-    #     elif "system:" in message:
-    #         once_conversation.append(f"""###system:{message.split("system:")[1]} """)
-    #     elif "ai:" in message:
-    #         once_conversation.append(f"""###system:{message.split("ai:")[1]} """)
-    #         last_conversation = copy.deepcopy(once_conversation)
-    #         hist.append("".join(last_conversation))
-    #         once_conversation = []
-    #     else:
-    #         once_conversation.append(f"""###system:{message} """)
-    #
-    #
-    #
-    #
-    #
-    # query = "".join(hist)
     query = prompt
     print("Query Message: ", query)
@@ -66,8 +39,8 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     )
-    t1 = Thread(target=model.generate, kwargs=generate_kwargs)
-    t1.start()
+    # t1 = Thread(target=model.generate, kwargs=generate_kwargs)
+    # t1.start()
     generator = model.generate(**generate_kwargs)
     for output in generator:

View File

@@ -52,9 +52,10 @@ class BaseOutputParser(ABC):
         """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
         """
         if data["error_code"] == 0:
-            if CFG.LLM_MODEL in ["vicuna-13b", "guanaco"]:
-                output = data["text"][skip_echo_len:].strip()
+            if "vicuna" in CFG.LLM_MODEL:
+                output = data["text"][skip_echo_len + 11:].strip()
+            elif "guanaco" in CFG.LLM_MODEL:
+                output = data["text"][skip_echo_len + 14:].replace("<s>", "").strip()
             else:
                 output = data["text"].strip()

View File

@@ -129,7 +129,7 @@ class BaseChat(ABC):
     def stream_call(self):
         payload = self.__call_base()
-        self.skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 11
+        self.skip_echo_len = len(payload.get('prompt').replace("</s>", " "))
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
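This last hunk is the counterpart to the parser change: the vicuna-specific +11 that used to be baked into stream_call is removed, so skip_echo_len is now just the raw echoed-prompt length and each model's offset lives in BaseOutputParser instead. A sketch of the relocated computation; payload here is a stand-in for the dict built by __call_base():

```python
payload = {"prompt": "### human: hello</s>"}

# Before: a vicuna-only fudge factor was hard-coded at this call site.
#   skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
# After: compute the plain prompt length; the output parser adds the
# per-model +11 (vicuna) or +14 (guanaco) when stripping the echo.
skip_echo_len = len(payload.get("prompt").replace("</s>", " "))
print(skip_echo_len)
```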