diff --git a/pilot/model/llm_out/guanaco_llm.py b/pilot/model/llm_out/guanaco_llm.py index 6c209b565..ad322c556 100644 --- a/pilot/model/llm_out/guanaco_llm.py +++ b/pilot/model/llm_out/guanaco_llm.py @@ -10,33 +10,6 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048): print(params) stop = params.get("stop", "###") prompt = params["prompt"] - messages = prompt.split(stop) - # - # # Add history conversation - # hist = [] - # once_conversation = [] - # for message in messages[:-1]: - # if len(message) <= 0: - # continue - # - # if "human:" in message: - # once_conversation.append(f"""###system:{message.split("human:")[1]} """ ) - # elif "system:" in message: - # once_conversation.append(f"""###system:{message.split("system:")[1]} """) - # elif "ai:" in message: - # once_conversation.append(f"""###system:{message.split("ai:")[1]} """) - # last_conversation = copy.deepcopy(once_conversation) - # hist.append("".join(last_conversation)) - # once_conversation = [] - # else: - # once_conversation.append(f"""###system:{message} """) - # - # - # - # - # - # query = "".join(hist) - query = prompt print("Query Message: ", query) @@ -66,8 +39,8 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048): ) - t1 = Thread(target=model.generate, kwargs=generate_kwargs) - t1.start() + # t1 = Thread(target=model.generate, kwargs=generate_kwargs) + # t1.start() generator = model.generate(**generate_kwargs) for output in generator: diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 663b87a7d..605dc742d 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -52,9 +52,10 @@ class BaseOutputParser(ABC): """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode. """ if data["error_code"] == 0: - if CFG.LLM_MODEL in ["vicuna-13b", "guanaco"]: - - output = data["text"][skip_echo_len:].strip() + if "vicuna" in CFG.LLM_MODEL: + output = data["text"][skip_echo_len + 11:].strip() + elif "guanaco" in CFG.LLM_MODEL: + output = data["text"][skip_echo_len + 14:].replace("", "").strip() else: output = data["text"].strip() diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 8c8dba501..7690efd43 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -129,7 +129,7 @@ class BaseChat(ABC): def stream_call(self): payload = self.__call_base() - self.skip_echo_len = len(payload.get('prompt').replace("", " ")) + 11 + self.skip_echo_len = len(payload.get('prompt').replace("", " ")) logger.info(f"Requert: \n{payload}") ai_response_text = "" try: