mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-06 03:20:41 +00:00
feat(awel): New MessageConverter and more AWEL operators (#1039)
This commit is contained in:
@@ -152,7 +152,7 @@ class LLMModelAdapter(ABC):
|
||||
return "\n"
|
||||
|
||||
def transform_model_messages(
|
||||
self, messages: List[ModelMessage]
|
||||
self, messages: List[ModelMessage], convert_to_compatible_format: bool = False
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Transform the model messages
|
||||
|
||||
@@ -174,15 +174,19 @@ class LLMModelAdapter(ABC):
|
||||
]
|
||||
Args:
|
||||
messages (List[ModelMessage]): The model messages
|
||||
convert_to_compatible_format (bool, optional): Whether to convert to compatible format. Defaults to False.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, str]]: The transformed model messages
|
||||
"""
|
||||
logger.info(f"support_system_message: {self.support_system_message}")
|
||||
if not self.support_system_message:
|
||||
if not self.support_system_message and convert_to_compatible_format:
|
||||
# We will not do any transform in the future
|
||||
return self._transform_to_no_system_messages(messages)
|
||||
else:
|
||||
return ModelMessage.to_openai_messages(messages)
|
||||
return ModelMessage.to_openai_messages(
|
||||
messages, convert_to_compatible_format=convert_to_compatible_format
|
||||
)
|
||||
|
||||
def _transform_to_no_system_messages(
|
||||
self, messages: List[ModelMessage]
|
||||
@@ -237,6 +241,7 @@ class LLMModelAdapter(ABC):
|
||||
messages: List[ModelMessage],
|
||||
tokenizer: Any,
|
||||
prompt_template: str = None,
|
||||
convert_to_compatible_format: bool = False,
|
||||
) -> Optional[str]:
|
||||
"""Get the string prompt from the given parameters and messages
|
||||
|
||||
@@ -247,6 +252,7 @@ class LLMModelAdapter(ABC):
|
||||
messages (List[ModelMessage]): The model messages
|
||||
tokenizer (Any): The tokenizer of model, in huggingface chat model, we can create the prompt by tokenizer
|
||||
prompt_template (str, optional): The prompt template. Defaults to None.
|
||||
convert_to_compatible_format (bool, optional): Whether to convert to compatible format. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The string prompt
|
||||
@@ -262,6 +268,7 @@ class LLMModelAdapter(ABC):
|
||||
model_context: Dict,
|
||||
prompt_template: str = None,
|
||||
):
|
||||
convert_to_compatible_format = params.get("convert_to_compatible_format")
|
||||
conv: ConversationAdapter = self.get_default_conv_template(
|
||||
model_name, model_path
|
||||
)
|
||||
@@ -277,6 +284,72 @@ class LLMModelAdapter(ABC):
|
||||
return None, None, None
|
||||
|
||||
conv = conv.copy()
|
||||
if convert_to_compatible_format:
|
||||
# In old version, we will convert the messages to compatible format
|
||||
conv = self._set_conv_converted_messages(conv, messages)
|
||||
else:
|
||||
# In new version, we will use the messages directly
|
||||
conv = self._set_conv_messages(conv, messages)
|
||||
|
||||
# Add a blank message for the assistant.
|
||||
conv.append_message(conv.roles[1], None)
|
||||
new_prompt = conv.get_prompt()
|
||||
return new_prompt, conv.stop_str, conv.stop_token_ids
|
||||
|
||||
def _set_conv_messages(
    self, conv: ConversationAdapter, messages: List[ModelMessage]
) -> ConversationAdapter:
    """Load the messages into the conversation template without conversion

    Human messages are appended under ``conv.roles[0]`` and AI messages under
    ``conv.roles[1]``, in the order they appear. System message contents are
    collected while iterating; when exactly one was found it is installed
    through ``conv.set_system_message``.

    Args:
        conv (ConversationAdapter): The conversation template
        messages (List[ModelMessage]): The model messages. Each item may be
            a ``ModelMessage`` or a dict with ``"role"`` and ``"content"``
            keys.

    Returns:
        ConversationAdapter: The same ``conv`` object with the messages set

    Raises:
        ValueError: If an item is neither a ``ModelMessage`` nor a dict, if
            its role is not system/human/ai, or if more than one system
            message is present.
    """
    system_messages = []
    for message in messages:
        # Normalise both accepted item shapes into a (role, text) pair.
        if isinstance(message, ModelMessage):
            role, text = message.role, message.content
        elif isinstance(message, dict):
            role, text = message["role"], message["content"]
        else:
            raise ValueError(f"Invalid message type: {message}")

        # System content is only collected here; it is applied after the loop.
        if role == ModelMessageRoleType.SYSTEM:
            system_messages.append(text)
            continue

        if role == ModelMessageRoleType.HUMAN:
            speaker = conv.roles[0]
        elif role == ModelMessageRoleType.AI:
            speaker = conv.roles[1]
        else:
            raise ValueError(f"Unknown role: {role}")
        conv.append_message(speaker, text)

    system_count = len(system_messages)
    if system_count > 1:
        raise ValueError(
            f"Your system messages have more than one message: {system_messages}"
        )
    if system_count == 1:
        conv.set_system_message(system_messages[0])
    return conv
|
||||
|
||||
def _set_conv_converted_messages(
|
||||
self, conv: ConversationAdapter, messages: List[ModelMessage]
|
||||
) -> ConversationAdapter:
|
||||
"""Set the messages to the conversation template
|
||||
|
||||
In the old version, we will convert the messages to compatible format.
|
||||
This method will be deprecated in the future.
|
||||
|
||||
Args:
|
||||
conv (ConversationAdapter): The conversation template
|
||||
messages (List[ModelMessage]): The model messages
|
||||
|
||||
Returns:
|
||||
ConversationAdapter: The conversation template with messages
|
||||
"""
|
||||
system_messages = []
|
||||
user_messages = []
|
||||
ai_messages = []
|
||||
@@ -295,10 +368,8 @@ class LLMModelAdapter(ABC):
|
||||
# Support for multiple system messages
|
||||
system_messages.append(content)
|
||||
elif role == ModelMessageRoleType.HUMAN:
|
||||
# conv.append_message(conv.roles[0], content)
|
||||
user_messages.append(content)
|
||||
elif role == ModelMessageRoleType.AI:
|
||||
# conv.append_message(conv.roles[1], content)
|
||||
ai_messages.append(content)
|
||||
else:
|
||||
raise ValueError(f"Unknown role: {role}")
|
||||
@@ -320,10 +391,7 @@ class LLMModelAdapter(ABC):
|
||||
|
||||
# TODO join all system messages may not be a good idea
|
||||
conv.set_system_message("".join(can_use_systems))
|
||||
# Add a blank message for the assistant.
|
||||
conv.append_message(conv.roles[1], None)
|
||||
new_prompt = conv.get_prompt()
|
||||
return new_prompt, conv.stop_str, conv.stop_token_ids
|
||||
return conv
|
||||
|
||||
def model_adaptation(
|
||||
self,
|
||||
@@ -335,6 +403,15 @@ class LLMModelAdapter(ABC):
|
||||
) -> Tuple[Dict, Dict]:
|
||||
"""Params adaptation"""
|
||||
messages = params.get("messages")
|
||||
convert_to_compatible_format = params.get("convert_to_compatible_format")
|
||||
message_version = params.get("version", "v2").lower()
|
||||
logger.info(f"Message version is {message_version}")
|
||||
if convert_to_compatible_format is None:
|
||||
# Support convert messages to compatible format when message version is v1
|
||||
convert_to_compatible_format = message_version == "v1"
|
||||
# Save to params
|
||||
params["convert_to_compatible_format"] = convert_to_compatible_format
|
||||
|
||||
# Some model context to dbgpt server
|
||||
model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
|
||||
if messages:
|
||||
@@ -345,7 +422,9 @@ class LLMModelAdapter(ABC):
|
||||
]
|
||||
params["messages"] = messages
|
||||
|
||||
new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template)
|
||||
new_prompt = self.get_str_prompt(
|
||||
params, messages, tokenizer, prompt_template, convert_to_compatible_format
|
||||
)
|
||||
conv_stop_str, conv_stop_token_ids = None, None
|
||||
if not new_prompt:
|
||||
(
|
||||
|
Reference in New Issue
Block a user