model server fix message model

yhjun1026 2023-06-01 19:10:12 +08:00
parent 98a7bf93f6
commit ab40a68c8f
3 changed files with 7 additions and 33 deletions

View File

@@ -10,33 +10,6 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     print(params)
     stop = params.get("stop", "###")
     prompt = params["prompt"]
-    messages = prompt.split(stop)
-    #
-    # # Add history conversation
-    # hist = []
-    # once_conversation = []
-    # for message in messages[:-1]:
-    #     if len(message) <= 0:
-    #         continue
-    #
-    #     if "human:" in message:
-    #         once_conversation.append(f"""###system:{message.split("human:")[1]} """ )
-    #     elif "system:" in message:
-    #         once_conversation.append(f"""###system:{message.split("system:")[1]} """)
-    #     elif "ai:" in message:
-    #         once_conversation.append(f"""###system:{message.split("ai:")[1]} """)
-    #         last_conversation = copy.deepcopy(once_conversation)
-    #         hist.append("".join(last_conversation))
-    #         once_conversation = []
-    #     else:
-    #         once_conversation.append(f"""###system:{message} """)
-    #
-    #
-    #
-    #
-    #
-    # query = "".join(hist)
     query = prompt
     print("Query Message: ", query)
@@ -66,8 +39,8 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     )
-    t1 = Thread(target=model.generate, kwargs=generate_kwargs)
-    t1.start()
+    # t1 = Thread(target=model.generate, kwargs=generate_kwargs)
+    # t1.start()
     generator = model.generate(**generate_kwargs)
     for output in generator:
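
Context for the second hunk: the Thread that this commit comments out matches the standard Hugging Face transformers streaming pattern, where model.generate() runs on a worker thread and a TextIteratorStreamer yields decoded chunks to the caller. A minimal sketch of that pattern, assuming a stock transformers causal LM (the checkpoint name and prompt are illustrative, not taken from this repo):

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    # Illustrative checkpoint; the server loads its model elsewhere.
    tokenizer = AutoTokenizer.from_pretrained("timdettmers/guanaco-33b-merged")
    model = AutoModelForCausalLM.from_pretrained("timdettmers/guanaco-33b-merged")

    inputs = tokenizer("### human: hello ### ai:", return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generate_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=256)

    # generate() blocks until it finishes, so it runs on a worker thread
    # while the main thread consumes text chunks from the streamer.
    t1 = Thread(target=model.generate, kwargs=generate_kwargs)
    t1.start()
    for chunk in streamer:
        print(chunk, end="", flush=True)
    t1.join()

After this commit the server instead iterates directly over the return value of model.generate(**generate_kwargs), which assumes generate yields outputs synchronously rather than through a background thread.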

View File

@@ -52,9 +52,10 @@ class BaseOutputParser(ABC):
         """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
         """
         if data["error_code"] == 0:
-            if CFG.LLM_MODEL in ["vicuna-13b", "guanaco"]:
-                output = data["text"][skip_echo_len:].strip()
+            if "vicuna" in CFG.LLM_MODEL:
+                output = data["text"][skip_echo_len + 11:].strip()
+            elif "guanaco" in CFG.LLM_MODEL:
+                output = data["text"][skip_echo_len + 14:].replace("<s>", "").strip()
             else:
                 output = data["text"].strip()
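
The parser now keys on substrings of CFG.LLM_MODEL and applies a per-model offset on top of skip_echo_len: +11 for vicuna, and +14 for guanaco together with stripping its <s> tokens. The TODO above already points at an adapter-mode rewrite; a hypothetical sketch of that shape (registry and function names invented for illustration, not the repo's API):

    from typing import Callable, Dict

    # Each model family owns its echo-stripping rule instead of an
    # inline if/elif chain on CFG.LLM_MODEL.
    OUTPUT_ADAPTERS: Dict[str, Callable[[str, int], str]] = {
        "vicuna": lambda text, skip: text[skip + 11:].strip(),
        "guanaco": lambda text, skip: text[skip + 14:].replace("<s>", "").strip(),
    }

    def parse_model_output(model_name: str, text: str, skip_echo_len: int) -> str:
        for key, adapter in OUTPUT_ADAPTERS.items():
            if key in model_name:
                return adapter(text, skip_echo_len)
        return text.strip()  # fallback mirrors the else branch above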

View File

@@ -129,7 +129,7 @@ class BaseChat(ABC):
     def stream_call(self):
         payload = self.__call_base()
-        self.skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 11
+        self.skip_echo_len = len(payload.get('prompt').replace("</s>", " "))
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
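
With this change skip_echo_len measures only the whitespace-normalized prompt; the fixed + 11 that stream_call used to add is gone because the model-specific offsets now live in BaseOutputParser (previous file). A minimal sketch of the resulting division of labor, with an invented prompt:

    prompt = "### human: hi</s>### ai:"

    # BaseChat.stream_call: raw prompt length only, no + 11 here any more.
    skip_echo_len = len(prompt.replace("</s>", " "))

    # BaseOutputParser then adds the per-model offset at parse time, e.g.
    # skip_echo_len + 11 for vicuna or skip_echo_len + 14 for guanaco.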