feat(awel): New MessageConverter and more AWEL operators (#1039)

Fangyin Cheng · 2024-01-08 09:40:05 +08:00 · committed by GitHub
parent 765fb181f6 · commit e8861bd8fa
48 changed files with 2333 additions and 719 deletions


@@ -152,7 +152,7 @@ class LLMModelAdapter(ABC):
         return "\n"

     def transform_model_messages(
-        self, messages: List[ModelMessage]
+        self, messages: List[ModelMessage], convert_to_compatible_format: bool = False
     ) -> List[Dict[str, str]]:
         """Transform the model messages
@@ -174,15 +174,19 @@ class LLMModelAdapter(ABC):
         ]
         Args:
             messages (List[ModelMessage]): The model messages
+            convert_to_compatible_format (bool, optional): Whether to convert to compatible format. Defaults to False.

         Returns:
             List[Dict[str, str]]: The transformed model messages
         """
         logger.info(f"support_system_message: {self.support_system_message}")
-        if not self.support_system_message:
+        if not self.support_system_message and convert_to_compatible_format:
+            # We will not do any transform in the future
             return self._transform_to_no_system_messages(messages)
         else:
-            return ModelMessage.to_openai_messages(messages)
+            return ModelMessage.to_openai_messages(
+                messages, convert_to_compatible_format=convert_to_compatible_format
+            )

     def _transform_to_no_system_messages(
         self, messages: List[ModelMessage]
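
The practical effect of the new flag, as a minimal sketch (the `adapter` instance and the import path are assumptions, not part of this diff):

    # Hypothetical usage; `adapter` is any concrete LLMModelAdapter.
    from dbgpt.core import ModelMessage, ModelMessageRoleType  # import path assumed

    messages = [
        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are helpful."),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
    ]

    # New default: pass messages through to OpenAI-style dicts unchanged.
    adapter.transform_model_messages(messages)

    # Legacy path: only when the model lacks system-message support is the
    # system content folded in via _transform_to_no_system_messages.
    adapter.transform_model_messages(messages, convert_to_compatible_format=True)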
@@ -237,6 +241,7 @@ class LLMModelAdapter(ABC):
         messages: List[ModelMessage],
         tokenizer: Any,
         prompt_template: str = None,
+        convert_to_compatible_format: bool = False,
     ) -> Optional[str]:
         """Get the string prompt from the given parameters and messages
@@ -247,6 +252,7 @@ class LLMModelAdapter(ABC):
             messages (List[ModelMessage]): The model messages
             tokenizer (Any): The tokenizer of model, in huggingface chat model, we can create the prompt by tokenizer
             prompt_template (str, optional): The prompt template. Defaults to None.
+            convert_to_compatible_format (bool, optional): Whether to convert to compatible format. Defaults to False.

         Returns:
             Optional[str]: The string prompt
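
A usage sketch for the widened signature (argument values here are placeholders, not from this commit):

    # Omitting the new flag keeps the previous default (False);
    # model_adaptation below threads it through explicitly.
    prompt = adapter.get_str_prompt(
        params,      # request params dict
        messages,    # List[ModelMessage]
        tokenizer,   # model tokenizer, used by huggingface chat models
        prompt_template=None,
        convert_to_compatible_format=False,
    )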
@@ -262,6 +268,7 @@ class LLMModelAdapter(ABC):
         model_context: Dict,
         prompt_template: str = None,
     ):
+        convert_to_compatible_format = params.get("convert_to_compatible_format")
         conv: ConversationAdapter = self.get_default_conv_template(
             model_name, model_path
         )
@@ -277,6 +284,72 @@ class LLMModelAdapter(ABC):
             return None, None, None

         conv = conv.copy()
+        if convert_to_compatible_format:
+            # In old version, we will convert the messages to compatible format
+            conv = self._set_conv_converted_messages(conv, messages)
+        else:
+            # In new version, we will use the messages directly
+            conv = self._set_conv_messages(conv, messages)
+
+        # Add a blank message for the assistant.
+        conv.append_message(conv.roles[1], None)
+        new_prompt = conv.get_prompt()
+        return new_prompt, conv.stop_str, conv.stop_token_ids
+
+    def _set_conv_messages(
+        self, conv: ConversationAdapter, messages: List[ModelMessage]
+    ) -> ConversationAdapter:
+        """Set the messages to the conversation template
+
+        Args:
+            conv (ConversationAdapter): The conversation template
+            messages (List[ModelMessage]): The model messages
+
+        Returns:
+            ConversationAdapter: The conversation template with messages
+        """
+        system_messages = []
+        for message in messages:
+            if isinstance(message, ModelMessage):
+                role = message.role
+                content = message.content
+            elif isinstance(message, dict):
+                role = message["role"]
+                content = message["content"]
+            else:
+                raise ValueError(f"Invalid message type: {message}")
+
+            if role == ModelMessageRoleType.SYSTEM:
+                system_messages.append(content)
+            elif role == ModelMessageRoleType.HUMAN:
+                conv.append_message(conv.roles[0], content)
+            elif role == ModelMessageRoleType.AI:
+                conv.append_message(conv.roles[1], content)
+            else:
+                raise ValueError(f"Unknown role: {role}")
+        if len(system_messages) > 1:
+            raise ValueError(
+                f"Your system messages have more than one message: {system_messages}"
+            )
+        if system_messages:
+            conv.set_system_message(system_messages[0])
+        return conv
+
+    def _set_conv_converted_messages(
+        self, conv: ConversationAdapter, messages: List[ModelMessage]
+    ) -> ConversationAdapter:
+        """Set the messages to the conversation template
+
+        In the old version, we will convert the messages to compatible format.
+        This method will be deprecated in the future.
+
+        Args:
+            conv (ConversationAdapter): The conversation template
+            messages (List[ModelMessage]): The model messages
+
+        Returns:
+            ConversationAdapter: The conversation template with messages
+        """
         system_messages = []
         user_messages = []
         ai_messages = []
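
The two helpers embody the v2-versus-v1 split: `_set_conv_messages` preserves the original message order and allows at most one system message, while `_set_conv_converted_messages` (continued in the next hunk) first buckets messages by role and then rebuilds the legacy layout. A rough sketch of the direct path, with illustrative values:

    history = [
        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="Be terse."),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Q1"),
        ModelMessage(role=ModelMessageRoleType.AI, content="A1"),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Q2"),
    ]
    conv = adapter._set_conv_messages(conv, history)
    # conv now carries system="Be terse." plus the turns
    # ("user", "Q1"), ("assistant", "A1"), ("user", "Q2"), in order.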
@@ -295,10 +368,8 @@ class LLMModelAdapter(ABC):
                 # Support for multiple system messages
                 system_messages.append(content)
             elif role == ModelMessageRoleType.HUMAN:
-                # conv.append_message(conv.roles[0], content)
                 user_messages.append(content)
             elif role == ModelMessageRoleType.AI:
-                # conv.append_message(conv.roles[1], content)
                 ai_messages.append(content)
             else:
                 raise ValueError(f"Unknown role: {role}")
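
One behavioral difference worth calling out, sketched with hypothetical inputs: the direct path rejects multiple system messages, while this legacy path collects them for joining later.

    two_systems = [
        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="Rule A."),
        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="Rule B."),
    ]
    adapter._set_conv_messages(conv, two_systems)            # raises ValueError
    adapter._set_conv_converted_messages(conv, two_systems)  # accepted; joined below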
@@ -320,10 +391,7 @@ class LLMModelAdapter(ABC):
         # TODO join all system messages may not be a good idea
         conv.set_system_message("".join(can_use_systems))

-        # Add a blank message for the assistant.
-        conv.append_message(conv.roles[1], None)
-        new_prompt = conv.get_prompt()
-        return new_prompt, conv.stop_str, conv.stop_token_ids
+        return conv

     def model_adaptation(
         self,
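
Why the TODO is warranted: `"".join` concatenates the collected system messages with no separator, so independent instructions run together:

    can_use_systems = ["You are a SQL expert.", "Answer in English."]
    print("".join(can_use_systems))
    # You are a SQL expert.Answer in English.
    # "\n".join(can_use_systems) would keep them readable.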
@@ -335,6 +403,15 @@ class LLMModelAdapter(ABC):
     ) -> Tuple[Dict, Dict]:
         """Params adaptation"""
         messages = params.get("messages")
+        convert_to_compatible_format = params.get("convert_to_compatible_format")
+        message_version = params.get("version", "v2").lower()
+        logger.info(f"Message version is {message_version}")
+        if convert_to_compatible_format is None:
+            # Support convert messages to compatible format when message version is v1
+            convert_to_compatible_format = message_version == "v1"
+        # Save to params
+        params["convert_to_compatible_format"] = convert_to_compatible_format
+
         # Some model context to dbgpt server
         model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
         if messages:
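
The resolution rule above, restated as a standalone sketch (the helper name is mine, not part of the commit):

    def resolve_compat_flag(params: dict) -> bool:
        """An explicit flag wins; otherwise only legacy "v1" requests convert."""
        flag = params.get("convert_to_compatible_format")
        if flag is None:
            flag = params.get("version", "v2").lower() == "v1"
        params["convert_to_compatible_format"] = flag  # cached for later calls
        return flag

    assert resolve_compat_flag({"version": "v1"}) is True
    assert resolve_compat_flag({}) is False  # "v2" is the default
    assert resolve_compat_flag({"version": "v1", "convert_to_compatible_format": False}) is False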
@@ -345,7 +422,9 @@ class LLMModelAdapter(ABC):
             ]
             params["messages"] = messages

-        new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template)
+        new_prompt = self.get_str_prompt(
+            params, messages, tokenizer, prompt_template, convert_to_compatible_format
+        )
         conv_stop_str, conv_stop_token_ids = None, None
         if not new_prompt:
             (