refactor: Refactor proxy LLM (#1064)

Author: Fangyin Cheng
Date: 2024-01-14 21:01:37 +08:00
Committed by: GitHub
Parent: a035433170
Commit: 22bfd01c4b
95 changed files with 2049 additions and 1294 deletions


@@ -127,6 +127,11 @@ class ModelRequestContext:
extra: Optional[Dict[str, Any]] = field(default_factory=dict)
"""The extra information of the model inference."""
request_id: Optional[str] = None
"""The request id of the model inference."""
cache_enable: Optional[bool] = False
"""Whether to enable the cache for the model inference"""
@dataclass
@PublicAPI(stability="beta")
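
The new ModelRequestContext field lets callers opt into response caching per request. A minimal construction sketch; the import path and the request id value are assumptions, not part of the diff:

from dbgpt.core.interface.llm import ModelRequestContext  # assumed module path

# Per-request cache opt-in; cache_enable defaults to False when omitted.
ctx = ModelRequestContext(stream=False, cache_enable=True, request_id="req-001")
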
@@ -171,7 +176,7 @@ class ModelRequest:
"""The stop token ids of the model inference."""
context_len: Optional[int] = None
"""The context length of the model inference."""
echo: Optional[bool] = True
echo: Optional[bool] = False
"""Whether to echo the input messages."""
span_id: Optional[str] = None
"""The span id of the model inference."""
@@ -203,7 +208,12 @@ class ModelRequest:
map(lambda m: m if isinstance(m, dict) else m.dict(), new_reqeust.messages)
)
# Skip None fields
return {k: v for k, v in asdict(new_reqeust).items() if v}
return {k: v for k, v in asdict(new_reqeust).items() if v is not None}
def to_trace_metadata(self):
metadata = self.to_dict()
metadata["prompt"] = self.messages_to_string()
return metadata
def get_messages(self) -> List[ModelMessage]:
"""Get the messages.
@@ -234,10 +244,13 @@ class ModelRequest:
def build_request(
model: str,
messages: List[ModelMessage],
context: Union[ModelRequestContext, Dict[str, Any], BaseModel],
context: Optional[Union[ModelRequestContext, Dict[str, Any], BaseModel]] = None,
stream: Optional[bool] = False,
echo: Optional[bool] = False,
**kwargs,
):
if not context:
context = ModelRequestContext(stream=stream)
context_dict = None
if isinstance(context, dict):
context_dict = context
@@ -250,6 +263,7 @@ class ModelRequest:
model=model,
messages=messages,
context=context,
echo=echo,
**kwargs,
)
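
build_request now works without an explicit context (a default ModelRequestContext(stream=stream) is created internally) and threads echo through to the request. A usage sketch; the import paths and the model name are assumptions:

from dbgpt.core.interface.llm import ModelRequest  # assumed module path
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType  # assumed

request = ModelRequest.build_request(
    model="chatgpt_proxyllm",  # hypothetical model name
    messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")],
    stream=False,
    echo=False,  # the new default, shown explicitly here
)
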
@@ -261,14 +275,22 @@ class ModelRequest:
**kwargs,
)
def to_openai_messages(self) -> List[Dict[str, Any]]:
"""Convert the messages to the format of OpenAI API.
def to_common_messages(
self, support_system_role: bool = True
) -> List[Dict[str, Any]]:
"""Convert the messages to the common format(like OpenAI API).
This function will move last user message to the end of the list.
Args:
support_system_role (bool): Whether to support system role
Returns:
List[Dict[str, Any]]: The messages in the format of OpenAI API.
Raises:
ValueError: If the message role is not supported
Examples:
.. code-block:: python
@@ -298,7 +320,17 @@ class ModelRequest:
m if isinstance(m, ModelMessage) else ModelMessage(**m)
for m in self.messages
]
return ModelMessage.to_openai_messages(messages)
return ModelMessage.to_common_messages(
messages, support_system_role=support_system_role
)
def messages_to_string(self) -> str:
"""Convert the messages to string.
Returns:
str: The messages in string format.
"""
return ModelMessage.messages_to_string(self.get_messages())
@dataclass
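
Taken together, ModelRequest now exposes the renamed to_common_messages (the old to_openai_messages behavior when support_system_role=True) plus a flat-string rendering that to_trace_metadata uses for its "prompt" field. A sketch continuing the request built above:

# OpenAI-style message list; raises ValueError on system messages
# when support_system_role=False (see the message-module hunk below).
payload = request.to_common_messages()

# Flat string form of the conversation, as used by to_trace_metadata().
prompt_str = request.messages_to_string()
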
@@ -478,7 +510,7 @@ class DefaultMessageConverter(MessageConverter):
if not model_metadata or not model_metadata.ext_metadata:
logger.warning("No model metadata, skip message system message conversion")
return messages
if model_metadata.ext_metadata.support_system_message:
if not model_metadata.ext_metadata.support_system_message:
# 3. Convert the messages to no system message
return self.convert_to_no_system_message(messages, model_metadata)
return messages
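
Note the guard is negated here: previously convert_to_no_system_message ran precisely when the model did support system messages, i.e. for exactly the wrong models. A self-contained restatement of the corrected predicate, using stand-in types rather than the project's:

from dataclasses import dataclass

@dataclass
class _ExtMetadata:  # stand-in for the real ext_metadata, illustration only
    support_system_message: bool

def needs_conversion(ext: _ExtMetadata) -> bool:
    # Fixed check: fold system messages away only when unsupported.
    return not ext.support_system_message

assert needs_conversion(_ExtMetadata(support_system_message=False)) is True
assert needs_conversion(_ExtMetadata(support_system_message=True)) is False
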


@@ -197,15 +197,24 @@ class ModelMessage(BaseModel):
return result
@staticmethod
def to_openai_messages(
messages: List["ModelMessage"], convert_to_compatible_format: bool = False
def to_common_messages(
messages: List["ModelMessage"],
convert_to_compatible_format: bool = False,
support_system_role: bool = True,
) -> List[Dict[str, str]]:
"""Convert to OpenAI message format and
hugggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
"""Convert to common message format(e.g. OpenAI message format) and
huggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
Args:
messages (List["ModelMessage"]): The model messages
convert_to_compatible_format (bool): Whether to convert to compatible format
support_system_role (bool): Whether to support system role
Returns:
List[Dict[str, str]]: The common messages
Raises:
ValueError: If the message role is not supported
"""
history = []
# Add history conversation
@@ -213,6 +222,8 @@ class ModelMessage(BaseModel):
if message.role == ModelMessageRoleType.HUMAN:
history.append({"role": "user", "content": message.content})
elif message.role == ModelMessageRoleType.SYSTEM:
if not support_system_role:
raise ValueError("Current model not support system role")
history.append({"role": "system", "content": message.content})
elif message.role == ModelMessageRoleType.AI:
history.append({"role": "assistant", "content": message.content})
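
When a model cannot accept a system role, conversion now fails fast instead of silently sending an unsupported message upstream. A sketch of the new failure mode, assuming the imports shown earlier:

msgs = [
    ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are helpful."),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
]
try:
    ModelMessage.to_common_messages(msgs, support_system_role=False)
except ValueError as e:
    print(e)  # Current model not support system role
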
@@ -250,6 +261,18 @@ class ModelMessage(BaseModel):
return str_msg
@staticmethod
def messages_to_string(messages: List["ModelMessage"]) -> str:
"""Convert messages to str
Args:
messages (List[ModelMessage]): The messages
Returns:
str: The str messages
"""
return _messages_to_str(messages)
_SingleRoundMessage = List[BaseMessage]
_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[BaseMessage]]
@@ -264,7 +287,7 @@ def _messages_to_dict(messages: List[BaseMessage]) -> List[Dict]:
def _messages_to_str(
messages: List[BaseMessage],
messages: List[Union[BaseMessage, ModelMessage]],
human_prefix: str = "Human",
ai_prefix: str = "AI",
system_prefix: str = "System",
@@ -272,7 +295,7 @@ def _messages_to_str(
"""Convert messages to str
Args:
messages (List[BaseMessage]): The messages
messages (List[Union[BaseMessage, ModelMessage]]): The messages
human_prefix (str): The human prefix
ai_prefix (str): The ai prefix
system_prefix (str): The system prefix
@@ -291,6 +314,8 @@ def _messages_to_str(
role = system_prefix
elif isinstance(message, ViewMessage):
pass
elif isinstance(message, ModelMessage):
role = message.role
else:
raise ValueError(f"Got unsupported message type: {message}")
if role:
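
_messages_to_str now accepts ModelMessage instances as well, taking message.role directly as the prefix; this is what backs the messages_to_string helpers above. A sketch; the exact joined output format is not shown in this hunk, so the expected output is an assumption:

msgs = [
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
    ModelMessage(role=ModelMessageRoleType.AI, content="Hello!"),
]
print(ModelMessage.messages_to_string(msgs))
# Presumably renders one "<role>: <content>" line per message:
# human: Hi
# ai: Hello!
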


@@ -44,7 +44,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
"""
model_context = data.get("model_context")
has_echo = True
has_echo = False
if model_context and "prompt_echo_len_char" in model_context:
prompt_echo_len_char = int(model_context.get("prompt_echo_len_char", -1))
has_echo = bool(model_context.get("echo", False))
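
The parser default now matches the request default (echo=False), so prompt text is only stripped from model output when the context explicitly says it was echoed. A minimal sketch of the flag computation, assuming model_context arrives as a plain dict; the actual stripping logic is outside this hunk:

model_context = {"prompt_echo_len_char": 12, "echo": False}
has_echo = False  # new default: assume no echo unless the context says so
if model_context and "prompt_echo_len_char" in model_context:
    has_echo = bool(model_context.get("echo", False))
assert has_echo is False  # output left intact; no prompt prefix to remove
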


@@ -421,13 +421,13 @@ def test_parse_model_messages_multiple_system_messages():
def test_to_openai_messages(
human_model_message, ai_model_message, system_model_message
):
none_messages = ModelMessage.to_openai_messages([])
none_messages = ModelMessage.to_common_messages([])
assert none_messages == []
single_messages = ModelMessage.to_openai_messages([human_model_message])
single_messages = ModelMessage.to_common_messages([human_model_message])
assert single_messages == [{"role": "user", "content": human_model_message.content}]
normal_messages = ModelMessage.to_openai_messages(
normal_messages = ModelMessage.to_common_messages(
[
system_model_message,
human_model_message,
@@ -446,7 +446,7 @@ def test_to_openai_messages(
def test_to_openai_messages_convert_to_compatible_format(
human_model_message, ai_model_message, system_model_message
):
shuffle_messages = ModelMessage.to_openai_messages(
shuffle_messages = ModelMessage.to_common_messages(
[
system_model_message,
human_model_message,