diff --git a/pilot/conversation.py b/pilot/conversation.py
index 7f526fb89..0470bc720 100644
--- a/pilot/conversation.py
+++ b/pilot/conversation.py
@@ -15,6 +15,9 @@ DB_SETTINGS = {
     "port": CFG.LOCAL_DB_PORT
 }
 
+ROLE_USER = "USER"
+ROLE_ASSISTANT = "Assistant"
+
 class SeparatorStyle(Enum):
     SINGLE = auto()
     TWO = auto()
diff --git a/pilot/model/chatglm_llm.py b/pilot/model/chatglm_llm.py
index b0a3c8296..0f8b74efa 100644
--- a/pilot/model/chatglm_llm.py
+++ b/pilot/model/chatglm_llm.py
@@ -3,6 +3,8 @@
 
 import torch
 
+from pilot.conversation import ROLE_USER, ROLE_ASSISTANT
+
 
 @torch.inference_mode()
 def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048, stream_interval=2):
@@ -30,9 +32,9 @@ def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048,
 
     # Add history chat to hist for model.
     for i in range(1, len(messages) - 2, 2):
-        hist.append((messages[i].split(":")[1], messages[i+1].split(":")[1]))
+        hist.append((messages[i].split(ROLE_USER + ":")[1], messages[i+1].split(ROLE_ASSISTANT + ":")[1]))
 
-    query = messages[-2].split(":")[1]
+    query = messages[-2].split(ROLE_USER + ":")[1]
     print("Query Message: ", query)
     output = ""
     i = 0