refactor: Refactor proxy LLM (#1064)

This commit is contained in:
Fangyin Cheng
2024-01-14 21:01:37 +08:00
committed by GitHub
parent a035433170
commit 22bfd01c4b
95 changed files with 2049 additions and 1294 deletions

View File

@@ -197,15 +197,24 @@ class ModelMessage(BaseModel):
return result
@staticmethod
def to_openai_messages(
messages: List["ModelMessage"], convert_to_compatible_format: bool = False
def to_common_messages(
messages: List["ModelMessage"],
convert_to_compatible_format: bool = False,
support_system_role: bool = True,
) -> List[Dict[str, str]]:
"""Convert to OpenAI message format and
hugggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
"""Convert to common message format(e.g. OpenAI message format) and
huggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
Args:
messages (List["ModelMessage"]): The model messages
convert_to_compatible_format (bool): Whether to convert to compatible format
support_system_role (bool): Whether to support system role
Returns:
List[Dict[str, str]]: The common messages
Raises:
ValueError: If the message role is not supported
"""
history = []
# Add history conversation
@@ -213,6 +222,8 @@ class ModelMessage(BaseModel):
if message.role == ModelMessageRoleType.HUMAN:
history.append({"role": "user", "content": message.content})
elif message.role == ModelMessageRoleType.SYSTEM:
if not support_system_role:
raise ValueError("Current model not support system role")
history.append({"role": "system", "content": message.content})
elif message.role == ModelMessageRoleType.AI:
history.append({"role": "assistant", "content": message.content})
@@ -250,6 +261,18 @@ class ModelMessage(BaseModel):
return str_msg
@staticmethod
def messages_to_string(messages: List["ModelMessage"]) -> str:
"""Convert messages to str
Args:
messages (List[ModelMessage]): The messages
Returns:
str: The str messages
"""
return _messages_to_str(messages)
_SingleRoundMessage = List[BaseMessage]
_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[BaseMessage]]
@@ -264,7 +287,7 @@ def _messages_to_dict(messages: List[BaseMessage]) -> List[Dict]:
def _messages_to_str(
messages: List[BaseMessage],
messages: List[Union[BaseMessage, ModelMessage]],
human_prefix: str = "Human",
ai_prefix: str = "AI",
system_prefix: str = "System",
@@ -272,7 +295,7 @@ def _messages_to_str(
"""Convert messages to str
Args:
messages (List[BaseMessage]): The messages
messages (List[Union[BaseMessage, ModelMessage]]): The messages
human_prefix (str): The human prefix
ai_prefix (str): The ai prefix
system_prefix (str): The system prefix
@@ -291,6 +314,8 @@ def _messages_to_str(
role = system_prefix
elif isinstance(message, ViewMessage):
pass
elif isinstance(message, ModelMessage):
role = message.role
else:
raise ValueError(f"Got unsupported message type: {message}")
if role: