diff --git a/pilot/model/guanaco_llm.py b/pilot/model/guanaco_llm.py
index c6e91ee6f..75ae5b795 100644
--- a/pilot/model/guanaco_llm.py
+++ b/pilot/model/guanaco_llm.py
@@ -4,7 +4,7 @@ from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCri
 from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
 
 
 def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
-    """Fork from fastchat: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
+    """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
     stop = params.get("stop", "###")
     messages = params["prompt"].split(stop)
@@ -17,11 +17,6 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
                 messages[i + 1].split(ROLE_ASSISTANT + ":")[1],
             )
         )
-
-
-    text = + "".join(["".join([f"### USER: {item[0]}\n",f"### Assistant: {item[1]}\n",])for item in hist[:-1]])
-    text += "".join(["".join([f"### USER: {hist[-1][0]}\n",f"### Assistant: {hist[-1][1]}\n",])])
-
     query = messages[-2].split(ROLE_USER + ":")[1]
     print("Query Message: ", query)