Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-10-23 01:49:58 +00:00
chore: Add pylint for DB-GPT core lib (#1076)
@@ -1,3 +1,5 @@
+"""The interface for LLM."""
+
 import collections
 import copy
 import logging
@@ -31,7 +33,8 @@ class ModelInferenceMetrics:
     """The timestamp (in milliseconds) when the model inference ends."""
 
     current_time_ms: Optional[int] = None
-    """The current timestamp (in milliseconds) when the model inference return partially output(stream)."""
+    """The current timestamp (in milliseconds) when the model inference return
+    partially output(stream)."""
 
     first_token_time_ms: Optional[int] = None
     """The timestamp (in milliseconds) when the first token is generated."""
@@ -64,6 +67,14 @@ class ModelInferenceMetrics:
     def create_metrics(
         last_metrics: Optional["ModelInferenceMetrics"] = None,
     ) -> "ModelInferenceMetrics":
+        """Create metrics for model inference.
+
+        Args:
+            last_metrics(ModelInferenceMetrics): The last metrics.
+
+        Returns:
+            ModelInferenceMetrics: The metrics for model inference.
+        """
         start_time_ms = last_metrics.start_time_ms if last_metrics else None
         first_token_time_ms = last_metrics.first_token_time_ms if last_metrics else None
         first_completion_time_ms = (
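The `create_metrics` factory carries timing fields over from a previous metrics object, which is how streaming callers keep `start_time_ms` and `first_token_time_ms` stable across partial outputs. A minimal usage sketch, assuming `create_metrics` is callable directly on the class as the signature above suggests:

.. code-block:: python

    # Hypothetical usage sketch; field names come from the diff above, the exact
    # call style (staticmethod vs. classmethod) is an assumption.
    from dbgpt.core.interface.llm import ModelInferenceMetrics

    first = ModelInferenceMetrics.create_metrics()  # start of a new inference
    # After a streamed chunk arrives, build a follow-up snapshot that keeps the
    # original start/first-token timestamps from the previous metrics object:
    snapshot = ModelInferenceMetrics.create_metrics(last_metrics=first)
    print(snapshot.to_dict())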
@@ -100,15 +111,21 @@ class ModelInferenceMetrics:
         )
 
     def to_dict(self) -> Dict:
+        """Convert the model inference metrics to dict."""
         return asdict(self)
 
 
 @dataclass
 @PublicAPI(stability="beta")
 class ModelRequestContext:
-    stream: Optional[bool] = False
+    """A class to represent the context of a LLM model request."""
+
+    stream: bool = False
     """Whether to return a stream of responses."""
 
+    cache_enable: bool = False
+    """Whether to enable the cache for the model inference"""
+
     user_name: Optional[str] = None
     """The user name of the model request."""
@@ -129,8 +146,6 @@ class ModelRequestContext:
 
     request_id: Optional[str] = None
     """The request id of the model inference."""
-    cache_enable: Optional[bool] = False
-    """Whether to enable the cache for the model inference"""
 
 
 @dataclass
@@ -141,27 +156,31 @@ class ModelOutput:
     text: str
     """The generated text."""
     error_code: int
-    """The error code of the model inference. If the model inference is successful, the error code is 0."""
-    model_context: Dict = None
-    finish_reason: str = None
-    usage: Dict[str, Any] = None
+    """The error code of the model inference. If the model inference is successful,
+    the error code is 0."""
+    model_context: Optional[Dict] = None
+    finish_reason: Optional[str] = None
+    usage: Optional[Dict[str, Any]] = None
     metrics: Optional[ModelInferenceMetrics] = None
     """Some metrics for model inference"""
 
     def to_dict(self) -> Dict:
+        """Convert the model output to dict."""
         return asdict(self)
 
 
-_ModelMessageType = Union[ModelMessage, Dict[str, Any]]
+_ModelMessageType = Union[List[ModelMessage], List[Dict[str, Any]]]
 
 
 @dataclass
 @PublicAPI(stability="beta")
 class ModelRequest:
+    """The model request."""
+
     model: str
     """The name of the model."""
 
-    messages: List[_ModelMessageType]
+    messages: _ModelMessageType
     """The input messages."""
 
     temperature: Optional[float] = None
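For reference, a minimal sketch of how the tightened `ModelOutput` annotations read from the caller's side; the values here are illustrative, not taken from the repository:

.. code-block:: python

    from dbgpt.core.interface.llm import ModelOutput

    # A successful, non-streaming result: the now-Optional fields may stay None.
    output = ModelOutput(text="Hello!", error_code=0, finish_reason="stop")
    assert output.error_code == 0
    print(output.to_dict())  # asdict-based dump; None fields are kept here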
@@ -189,28 +208,42 @@ class ModelRequest:
     @property
     def stream(self) -> bool:
         """Whether to return a stream of responses."""
-        return self.context and self.context.stream
+        return bool(self.context and self.context.stream)
 
-    def copy(self):
+    def copy(self) -> "ModelRequest":
+        """Copy the model request.
+
+        Returns:
+            ModelRequest: The copied model request.
+        """
         new_request = copy.deepcopy(self)
         # Transform messages to List[ModelMessage]
-        new_request.messages = list(
-            map(
-                lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m),
-                new_request.messages,
-            )
-        )
+        new_request.messages = new_request.get_messages()
         return new_request
 
     def to_dict(self) -> Dict[str, Any]:
+        """Convert the model request to dict.
+
+        Returns:
+            Dict[str, Any]: The model request in dict.
+        """
         new_reqeust = copy.deepcopy(self)
-        new_reqeust.messages = list(
-            map(lambda m: m if isinstance(m, dict) else m.dict(), new_reqeust.messages)
-        )
+        new_messages = []
+        for message in new_reqeust.messages:
+            if isinstance(message, dict):
+                new_messages.append(message)
+            else:
+                new_messages.append(message.dict())
+        new_reqeust.messages = new_messages
         # Skip None fields
         return {k: v for k, v in asdict(new_reqeust).items() if v is not None}
 
-    def to_trace_metadata(self):
+    def to_trace_metadata(self) -> Dict[str, Any]:
+        """Convert the model request to trace metadata.
+
+        Returns:
+            Dict[str, Any]: The trace metadata.
+        """
         metadata = self.to_dict()
         metadata["prompt"] = self.messages_to_string()
         return metadata
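The net effect of `copy()` and `to_dict()` is a two-way normalization: `copy()` coerces dict messages into `ModelMessage` objects, while `to_dict()` flattens `ModelMessage` objects back into plain dicts and drops None fields. A hedged sketch of that round trip; the model name is illustrative:

.. code-block:: python

    from dbgpt.core.interface.llm import ModelRequest
    from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType

    request = ModelRequest(
        model="vicuna-13b-v1.5",
        messages=[{"role": ModelMessageRoleType.HUMAN, "content": "Who are you"}],
    )
    copied = request.copy()  # messages become a List[ModelMessage]
    assert isinstance(copied.messages[0], ModelMessage)

    payload = request.to_dict()  # messages become plain dicts, None fields dropped
    assert isinstance(payload["messages"][0], dict)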
@@ -218,16 +251,19 @@ class ModelRequest:
     def get_messages(self) -> List[ModelMessage]:
         """Get the messages.
 
-        If the messages is not a list of ModelMessage, it will be converted to a list of ModelMessage.
+        If the messages is not a list of ModelMessage, it will be converted to a list
+        of ModelMessage.
 
         Returns:
             List[ModelMessage]: The messages.
         """
-        return list(
-            map(
-                lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m),
-                self.messages,
-            )
-        )
+        messages = []
+        for message in self.messages:
+            if isinstance(message, dict):
+                messages.append(ModelMessage(**message))
+            else:
+                messages.append(message)
+        return messages
 
     def get_single_user_message(self) -> Optional[ModelMessage]:
         """Get the single user message.
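Because messages may arrive either as `ModelMessage` objects or as raw dicts, `get_messages()` is the one place that normalizes them. A short hedged sketch; the model name is illustrative:

.. code-block:: python

    from dbgpt.core.interface.llm import ModelRequest
    from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType

    # Both request styles normalize to the same List[ModelMessage].
    from_dicts = ModelRequest(
        model="chatgpt_proxyllm",
        messages=[{"role": ModelMessageRoleType.HUMAN, "content": "Hi"}],
    )
    from_objects = ModelRequest(
        model="chatgpt_proxyllm",
        messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi")],
    )
    assert from_dicts.get_messages() == from_objects.get_messages()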
@@ -245,20 +281,35 @@ class ModelRequest:
         model: str,
         messages: List[ModelMessage],
         context: Optional[Union[ModelRequestContext, Dict[str, Any], BaseModel]] = None,
-        stream: Optional[bool] = False,
-        echo: Optional[bool] = False,
+        stream: bool = False,
+        echo: bool = False,
         **kwargs,
     ):
+        """Build a model request.
+
+        Args:
+            model(str): The model name.
+            messages(List[ModelMessage]): The messages.
+            context(Optional[Union[ModelRequestContext, Dict[str, Any], BaseModel]]):
+                The context.
+            stream(bool): Whether to return a stream of responses. Defaults to False.
+            echo(bool): Whether to echo the input messages. Defaults to False.
+            **kwargs: Other arguments.
+        """
         if not context:
             context = ModelRequestContext(stream=stream)
-        context_dict = None
-        if isinstance(context, dict):
-            context_dict = context
-        elif isinstance(context, BaseModel):
-            context_dict = context.dict()
-        if context_dict and "stream" not in context_dict:
-            context_dict["stream"] = stream
-            context = ModelRequestContext(**context_dict)
+        elif not isinstance(context, ModelRequestContext):
+            context_dict = None
+            if isinstance(context, dict):
+                context_dict = context
+            elif isinstance(context, BaseModel):
+                context_dict = context.dict()
+            if context_dict and "stream" not in context_dict:
+                context_dict["stream"] = stream
+            if context_dict:
+                context = ModelRequestContext(**context_dict)
+            else:
+                context = ModelRequestContext(stream=stream)
         return ModelRequest(
             model=model,
             messages=messages,
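The new branch only rebuilds a `ModelRequestContext` when the caller passed a dict or a pydantic `BaseModel`; an existing `ModelRequestContext` is now passed through untouched. A hedged usage sketch, assuming the classmethod shown in this hunk is the public `ModelRequest.build_request` (its `def` line is not visible here) and using an illustrative model name:

.. code-block:: python

    from dbgpt.core.interface.llm import ModelRequest
    from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType

    messages = [ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are you")]

    # No context: a ModelRequestContext(stream=stream) is created internally.
    req1 = ModelRequest.build_request("vicuna-13b-v1.5", messages, stream=True)

    # Dict context without "stream": the stream argument is copied into it.
    req2 = ModelRequest.build_request(
        "vicuna-13b-v1.5", messages, context={"user_name": "alice"}, stream=True
    )
    assert req1.stream and req2.stream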
@@ -292,7 +343,6 @@ class ModelRequest:
             ValueError: If the message role is not supported
 
         Examples:
-
             .. code-block:: python
 
                 from dbgpt.core.interface.message import (
@@ -337,7 +387,7 @@ class ModelRequest:
 class ModelExtraMedata(BaseParameters):
     """A class to represent the extra metadata of a LLM."""
 
-    prompt_roles: Optional[List[str]] = field(
+    prompt_roles: List[str] = field(
         default_factory=lambda: [
             ModelMessageRoleType.SYSTEM,
             ModelMessageRoleType.HUMAN,
@@ -356,7 +406,8 @@ class ModelExtraMedata(BaseParameters):
     prompt_chat_template: Optional[str] = field(
         default=None,
         metadata={
-            "help": "The chat template, see: https://huggingface.co/docs/transformers/main/en/chat_templating"
+            "help": "The chat template, see: "
+            "https://huggingface.co/docs/transformers/main/en/chat_templating"
         },
     )
 
@@ -403,19 +454,19 @@ class ModelMetadata(BaseParameters):
     def from_dict(
         cls, data: dict, ignore_extra_fields: bool = False
     ) -> "ModelMetadata":
         """Create a new model metadata from a dict."""
         if "ext_metadata" in data:
             data["ext_metadata"] = ModelExtraMedata(**data["ext_metadata"])
         return cls(**data)
 
 
 class MessageConverter(ABC):
-    """An abstract class for message converter.
+    r"""An abstract class for message converter.
 
-    Different LLMs may have different message formats, this class is used to convert the messages
-    to the format of the LLM.
+    Different LLMs may have different message formats, this class is used to convert
+    the messages to the format of the LLM.
 
     Examples:
 
         >>> from typing import List
         >>> from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
         >>> from dbgpt.core.interface.llm import MessageConverter, ModelMetadata
@@ -425,7 +476,8 @@ class MessageConverter(ABC):
         ...         messages: List[ModelMessage],
         ...         model_metadata: Optional[ModelMetadata] = None,
         ...     ) -> List[ModelMessage]:
-        ...         # Convert the messages, merge system messages to the last user message.
+        ...         # Convert the messages, merge system messages to the last user
+        ...         # message.
         ...         system_message = None
         ...         other_messages = []
        ...         sep = "\\n"
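For completeness, a small self-contained converter sketch in the same spirit as the doctest above; this exact class is hypothetical and not part of the repository:

.. code-block:: python

    from typing import List, Optional

    from dbgpt.core.interface.llm import MessageConverter, ModelMetadata
    from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType


    class DropSystemMessageConverter(MessageConverter):
        """Hypothetical converter: discard system messages for models that reject them."""

        def convert(
            self,
            messages: List[ModelMessage],
            model_metadata: Optional[ModelMetadata] = None,
        ) -> List[ModelMessage]:
            # Keep only human/AI turns; system prompts are simply dropped.
            return [m for m in messages if m.role != ModelMessageRoleType.SYSTEM]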
@@ -478,6 +530,7 @@ class DefaultMessageConverter(MessageConverter):
     """The default message converter."""
 
     def __init__(self, prompt_sep: Optional[str] = None):
+        """Create a new default message converter."""
         self._prompt_sep = prompt_sep
 
     def convert(
@@ -493,7 +546,8 @@ class DefaultMessageConverter(MessageConverter):
 
         2. Move the last user's message to the end of the list
 
-        3. Convert the messages to no system message if the model does not support system message
+        3. Convert the messages to no system message if the model does not support
+        system message
 
         Args:
             messages(List[ModelMessage]): The messages.
@@ -520,10 +574,11 @@ class DefaultMessageConverter(MessageConverter):
         messages: List[ModelMessage],
         model_metadata: Optional[ModelMetadata] = None,
     ) -> List[ModelMessage]:
-        """Convert the messages to no system message.
+        r"""Convert the messages to no system message.
 
         Examples:
-            >>> # Convert the messages to no system message, just merge system messages to the last user message
+            >>> # Convert the messages to no system message, just merge system messages
+            >>> # to the last user message
             >>> from typing import List
             >>> from dbgpt.core.interface.message import (
             ...     ModelMessage,
@@ -550,7 +605,7 @@ class DefaultMessageConverter(MessageConverter):
             >>> assert converted_messages == [
             ...     ModelMessage(
             ...         role=ModelMessageRoleType.HUMAN,
-            ...         content="You are a helpful assistant\\nWho are you",
+            ...         content="You are a helpful assistant\nWho are you",
             ...     ),
             ... ]
         """
@@ -562,7 +617,8 @@ class DefaultMessageConverter(MessageConverter):
         result_messages = []
         for message in messages:
             if message.role == ModelMessageRoleType.SYSTEM:
-                # Not support system message, append system message to the last user message
+                # Not support system message, append system message to the last user
+                # message
                 system_messages.append(message)
             elif message.role in [
                 ModelMessageRoleType.HUMAN,
@@ -578,7 +634,8 @@ class DefaultMessageConverter(MessageConverter):
             system_message_str = system_messages[0].content
 
         if system_message_str and result_messages:
-            # Not support system messages, merge system messages to the last user message
+            # Not support system messages, merge system messages to the last user
+            # message
             result_messages[-1].content = (
                 system_message_str + prompt_sep + result_messages[-1].content
             )
@@ -587,10 +644,9 @@ class DefaultMessageConverter(MessageConverter):
     def move_last_user_message_to_end(
         self, messages: List[ModelMessage]
     ) -> List[ModelMessage]:
-        """Move the last user message to the end of the list.
+        """Try to move the last user message to the end of the list.
 
         Examples:
-
             >>> from typing import List
             >>> from dbgpt.core.interface.message import (
             ...     ModelMessage,
@@ -660,7 +716,7 @@ class LLMClient(ABC):
 
     @property
     def cache(self) -> collections.abc.MutableMapping:
-        """The cache object to cache the model metadata.
+        """Return the cache object to cache the model metadata.
 
         You can override this property to use your own cache object.
         Returns:
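Since `cache` is just a property returning a `MutableMapping`, a subclass can swap in its own store. A hedged sketch; the other abstract methods of `LLMClient` are omitted here and a real subclass would still have to implement them before it could be instantiated:

.. code-block:: python

    import collections

    from dbgpt.core.interface.llm import LLMClient


    class MyLLMClient(LLMClient):
        """Hypothetical client that keeps model metadata in its own dict."""

        def __init__(self):
            self._model_cache: collections.abc.MutableMapping = {}

        @property
        def cache(self) -> collections.abc.MutableMapping:
            return self._model_cache

        # generate / generate_stream / other abstract methods omitted in this sketch.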
@@ -677,7 +733,8 @@ class LLMClient(ABC):
         """Generate a response for a given model request.
 
         Sometimes, different LLMs may have different message formats,
-        you can use the message converter to convert the messages to the format of the LLM.
+        you can use the message converter to convert the messages to the format of the
+        LLM.
 
         Args:
             request(ModelRequest): The model request.
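Putting the pieces together, the converter is supplied per call rather than being baked into the client. A hedged sketch of the call site; `client` stands for any concrete `LLMClient` implementation (not shown in this hunk), and the keyword name `message_converter` is assumed from the convert hunk further below:

.. code-block:: python

    from dbgpt.core.interface.llm import DefaultMessageConverter, ModelRequest
    from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType


    async def ask(client, model: str, question: str) -> str:
        request = ModelRequest(
            model=model,
            messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content=question)],
        )
        # The converter reshapes messages for models without system-message support.
        output = await client.generate(
            request, message_converter=DefaultMessageConverter()
        )
        return output.text

    # e.g. asyncio.run(ask(client, "vicuna-13b-v1.5", "Who are you"))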
@@ -697,7 +754,8 @@ class LLMClient(ABC):
         """Generate a stream of responses for a given model request.
 
         Sometimes, different LLMs may have different message formats,
-        you can use the message converter to convert the messages to the format of the LLM.
+        you can use the message converter to convert the messages to the format of the
+        LLM.
 
         Args:
             request(ModelRequest): The model request.
@@ -733,6 +791,7 @@ class LLMClient(ABC):
         message_converter: Optional[MessageConverter] = None,
     ) -> ModelRequest:
         """Covert the message.
+
         If no message converter is provided, the original request will be returned.
 
         Args:
@@ -746,14 +805,15 @@ class LLMClient(ABC):
             return request
         new_request = request.copy()
         model_metadata = await self.get_model_metadata(request.model)
-        new_messages = message_converter.convert(request.messages, model_metadata)
+        new_messages = message_converter.convert(request.get_messages(), model_metadata)
         new_request.messages = new_messages
         return new_request
 
     async def cached_models(self) -> List[ModelMetadata]:
         """Get all the models from the cache or the llm server.
 
-        If the model metadata is not in the cache, it will be fetched from the llm server.
+        If the model metadata is not in the cache, it will be fetched from the
+        llm server.
 
         Returns:
             List[ModelMetadata]: A list of model metadata.
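A short hedged sketch of how a caller would lean on this cache; `client` is again any concrete `LLMClient`, and only the return type shown in the docstring above is relied on:

.. code-block:: python

    async def warm_model_cache(client) -> int:
        # First call hits the LLM server; subsequent calls are served from client.cache.
        models = await client.cached_models()
        return len(models)

    # e.g. asyncio.run(warm_model_cache(client)) with a concrete LLMClient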