diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 0c8abac4f..13fbf2708 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -125,10 +125,16 @@ class BaseChat(ABC): current_prompt = self.prompt_template.format(**input_values) self.current_message.add_system_message(current_prompt) + llm_messages = self.generate_llm_messages() + if not CFG.NEW_SERVER_MODE: + # Not in new-server mode: convert the messages (List[ModelMessage]) to a list of dicts + # to fix the "Object of type ModelMessage is not JSON serializable" error when passing the payload to request.post + llm_messages = list(map(lambda m: m.dict(), llm_messages)) + payload = { "model": self.llm_model, "prompt": self.generate_llm_text(), - "messages": self.generate_llm_messages(), + "messages": llm_messages, "temperature": float(self.prompt_template.temperature), "max_new_tokens": int(self.prompt_template.max_new_tokens), "stop": self.prompt_template.sep, diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index 0ca8f97da..d47bc6cc8 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -28,6 +28,15 @@ class BaseChatAdpter: messages = params.get("messages") # Some model scontext to dbgpt server model_context = {"prompt_echo_len_char": -1} + + if messages: + # Convert incoming dict messages back into ModelMessage instances + messages = [ + m if isinstance(m, ModelMessage) else ModelMessage(**m) + for m in messages + ] + params["messages"] = messages + if not conv or not messages: # Nothing to do print(