mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-23 10:20:01 +00:00
refactor: Refactor proxy LLM (#1064)
This commit is contained in:
@@ -127,6 +127,11 @@ class ModelRequestContext:
|
||||
extra: Optional[Dict[str, Any]] = field(default_factory=dict)
|
||||
"""The extra information of the model inference."""
|
||||
|
||||
request_id: Optional[str] = None
|
||||
"""The request id of the model inference."""
|
||||
cache_enable: Optional[bool] = False
|
||||
"""Whether to enable the cache for the model inference"""
|
||||
|
||||
|
||||
@dataclass
|
||||
@PublicAPI(stability="beta")
|
||||
@@ -171,7 +176,7 @@ class ModelRequest:
|
||||
"""The stop token ids of the model inference."""
|
||||
context_len: Optional[int] = None
|
||||
"""The context length of the model inference."""
|
||||
echo: Optional[bool] = True
|
||||
echo: Optional[bool] = False
|
||||
"""Whether to echo the input messages."""
|
||||
span_id: Optional[str] = None
|
||||
"""The span id of the model inference."""
|
||||
@@ -203,7 +208,12 @@ class ModelRequest:
|
||||
map(lambda m: m if isinstance(m, dict) else m.dict(), new_reqeust.messages)
|
||||
)
|
||||
# Skip None fields
|
||||
return {k: v for k, v in asdict(new_reqeust).items() if v}
|
||||
return {k: v for k, v in asdict(new_reqeust).items() if v is not None}
|
||||
|
||||
def to_trace_metadata(self):
|
||||
metadata = self.to_dict()
|
||||
metadata["prompt"] = self.messages_to_string()
|
||||
return metadata
|
||||
|
||||
def get_messages(self) -> List[ModelMessage]:
|
||||
"""Get the messages.
|
||||
@@ -234,10 +244,13 @@ class ModelRequest:
|
||||
def build_request(
|
||||
model: str,
|
||||
messages: List[ModelMessage],
|
||||
context: Union[ModelRequestContext, Dict[str, Any], BaseModel],
|
||||
context: Optional[Union[ModelRequestContext, Dict[str, Any], BaseModel]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
echo: Optional[bool] = False,
|
||||
**kwargs,
|
||||
):
|
||||
if not context:
|
||||
context = ModelRequestContext(stream=stream)
|
||||
context_dict = None
|
||||
if isinstance(context, dict):
|
||||
context_dict = context
|
||||
@@ -250,6 +263,7 @@ class ModelRequest:
|
||||
model=model,
|
||||
messages=messages,
|
||||
context=context,
|
||||
echo=echo,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -261,14 +275,22 @@ class ModelRequest:
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def to_openai_messages(self) -> List[Dict[str, Any]]:
|
||||
"""Convert the messages to the format of OpenAI API.
|
||||
def to_common_messages(
|
||||
self, support_system_role: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Convert the messages to the common format(like OpenAI API).
|
||||
|
||||
This function will move the last user message to the end of the list.
|
||||
|
||||
Args:
|
||||
support_system_role (bool): Whether to support the system role.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: The messages in the format of OpenAI API.
|
||||
|
||||
Raises:
|
||||
ValueError: If the message role is not supported
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
@@ -298,7 +320,17 @@ class ModelRequest:
|
||||
m if isinstance(m, ModelMessage) else ModelMessage(**m)
|
||||
for m in self.messages
|
||||
]
|
||||
return ModelMessage.to_openai_messages(messages)
|
||||
return ModelMessage.to_common_messages(
|
||||
messages, support_system_role=support_system_role
|
||||
)
|
||||
|
||||
def messages_to_string(self) -> str:
|
||||
"""Convert the messages to string.
|
||||
|
||||
Returns:
|
||||
str: The messages in string format.
|
||||
"""
|
||||
return ModelMessage.messages_to_string(self.get_messages())
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -478,7 +510,7 @@ class DefaultMessageConverter(MessageConverter):
|
||||
if not model_metadata or not model_metadata.ext_metadata:
|
||||
logger.warning("No model metadata, skip message system message conversion")
|
||||
return messages
|
||||
if model_metadata.ext_metadata.support_system_message:
|
||||
if not model_metadata.ext_metadata.support_system_message:
|
||||
# 3. Convert the messages to no system message
|
||||
return self.convert_to_no_system_message(messages, model_metadata)
|
||||
return messages
|
||||
|
Reference in New Issue
Block a user