chore: Add pylint for DB-GPT core lib (#1076)

Fangyin Cheng authored 2024-01-16 17:36:26 +08:00 (committed by GitHub)
parent 3a54d1ef9a
commit 40c853575a
79 changed files with 2213 additions and 839 deletions

dbgpt/core/interface/llm.py

@@ -1,3 +1,5 @@
"""The interface for LLM."""
import collections
import copy
import logging
@@ -31,7 +33,8 @@ class ModelInferenceMetrics:
"""The timestamp (in milliseconds) when the model inference ends."""
current_time_ms: Optional[int] = None
"""The current timestamp (in milliseconds) when the model inference return partially output(stream)."""
"""The current timestamp (in milliseconds) when the model inference return
partially output(stream)."""
first_token_time_ms: Optional[int] = None
"""The timestamp (in milliseconds) when the first token is generated."""
@@ -64,6 +67,14 @@ class ModelInferenceMetrics:
def create_metrics(
last_metrics: Optional["ModelInferenceMetrics"] = None,
) -> "ModelInferenceMetrics":
"""Create metrics for model inference.
Args:
last_metrics(ModelInferenceMetrics): The last metrics.
Returns:
ModelInferenceMetrics: The metrics for model inference.
"""
start_time_ms = last_metrics.start_time_ms if last_metrics else None
first_token_time_ms = last_metrics.first_token_time_ms if last_metrics else None
first_completion_time_ms = (
@@ -100,15 +111,21 @@ class ModelInferenceMetrics:
)
def to_dict(self) -> Dict:
"""Convert the model inference metrics to dict."""
return asdict(self)
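For orientation, a minimal usage sketch of the metrics chaining above. It assumes create_metrics is exposed as a static method (the signature without self suggests so) and that a fresh object fills unset timestamps with the current time, which the elided body implies:

from dbgpt.core.interface.llm import ModelInferenceMetrics

first = ModelInferenceMetrics.create_metrics()  # fresh metrics, no last_metrics
# Reuse the previous metrics on the next stream chunk: start_time_ms is carried over
later = ModelInferenceMetrics.create_metrics(last_metrics=first)
assert later.start_time_ms == first.start_time_ms
print(later.to_dict())  # plain dict via dataclasses.asdict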
@dataclass
@PublicAPI(stability="beta")
class ModelRequestContext:
-stream: Optional[bool] = False
+"""A class to represent the context of an LLM model request."""
+stream: bool = False
"""Whether to return a stream of responses."""
+cache_enable: bool = False
+"""Whether to enable the cache for the model inference"""
user_name: Optional[str] = None
"""The user name of the model request."""
@@ -129,8 +146,6 @@ class ModelRequestContext:
request_id: Optional[str] = None
"""The request id of the model inference."""
-cache_enable: Optional[bool] = False
-"""Whether to enable the cache for the model inference"""
@dataclass
@@ -141,27 +156,31 @@ class ModelOutput:
text: str
"""The generated text."""
error_code: int
"""The error code of the model inference. If the model inference is successful, the error code is 0."""
model_context: Dict = None
finish_reason: str = None
usage: Dict[str, Any] = None
"""The error code of the model inference. If the model inference is successful,
the error code is 0."""
model_context: Optional[Dict] = None
finish_reason: Optional[str] = None
usage: Optional[Dict[str, Any]] = None
metrics: Optional[ModelInferenceMetrics] = None
"""Some metrics for model inference"""
def to_dict(self) -> Dict:
"""Convert the model output to dict."""
return asdict(self)
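A short sketch of the error_code convention stated above (0 means success; the values are illustrative):

from dbgpt.core.interface.llm import ModelOutput

ok = ModelOutput(text="Hello!", error_code=0)  # 0: inference succeeded
failed = ModelOutput(text="", error_code=1, finish_reason="error")
print(ok.to_dict())  # plain dict, e.g. {'text': 'Hello!', 'error_code': 0, ...}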
-_ModelMessageType = Union[ModelMessage, Dict[str, Any]]
+_ModelMessageType = Union[List[ModelMessage], List[Dict[str, Any]]]
@dataclass
@PublicAPI(stability="beta")
class ModelRequest:
"""The model request."""
model: str
"""The name of the model."""
-messages: List[_ModelMessageType]
+messages: _ModelMessageType
"""The input messages."""
temperature: Optional[float] = None
@@ -189,28 +208,42 @@ class ModelRequest:
@property
def stream(self) -> bool:
"""Whether to return a stream of responses."""
-return self.context and self.context.stream
+return bool(self.context and self.context.stream)
-def copy(self):
+def copy(self) -> "ModelRequest":
+    """Copy the model request.
+
+    Returns:
+        ModelRequest: The copied model request.
+    """
new_request = copy.deepcopy(self)
-# Transform messages to List[ModelMessage]
-new_request.messages = list(
-    map(
-        lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m),
-        new_request.messages,
-    )
-)
+new_request.messages = new_request.get_messages()
return new_request
def to_dict(self) -> Dict[str, Any]:
"""Convert the model request to dict.
Returns:
Dict[str, Any]: The model request in dict.
"""
new_request = copy.deepcopy(self)
-new_request.messages = list(
-    map(lambda m: m if isinstance(m, dict) else m.dict(), new_request.messages)
-)
+new_messages = []
+for message in new_request.messages:
+    if isinstance(message, dict):
+        new_messages.append(message)
+    else:
+        new_messages.append(message.dict())
+new_request.messages = new_messages
# Skip None fields
return {k: v for k, v in asdict(new_request).items() if v is not None}
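To illustrate the skip-None behavior, a minimal sketch (the model name and message are placeholders; the "human" role string follows the doctests below):

from dbgpt.core.interface.llm import ModelRequest

request = ModelRequest(model="my-model", messages=[{"role": "human", "content": "Hi"}])
data = request.to_dict()
assert "temperature" not in data  # None fields are dropped
assert data["messages"][0]["content"] == "Hi"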
-def to_trace_metadata(self):
+def to_trace_metadata(self) -> Dict[str, Any]:
+    """Convert the model request to trace metadata.
+
+    Returns:
+        Dict[str, Any]: The trace metadata.
+    """
metadata = self.to_dict()
metadata["prompt"] = self.messages_to_string()
return metadata
@@ -218,16 +251,19 @@ class ModelRequest:
def get_messages(self) -> List[ModelMessage]:
"""Get the messages.
-If the messages is not a list of ModelMessage, it will be converted to a list of ModelMessage.
+If the messages are not a list of ModelMessage, they will be converted to a
+list of ModelMessage.
Returns:
List[ModelMessage]: The messages.
"""
-return list(
-    map(
-        lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m),
-        self.messages,
-    )
-)
+messages = []
+for message in self.messages:
+    if isinstance(message, dict):
+        messages.append(ModelMessage(**message))
+    else:
+        messages.append(message)
+return messages
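The effect of the rewritten loop, sketched (assuming ModelMessage accepts role/content keywords, as in the doctests below):

from dbgpt.core.interface.llm import ModelRequest
from dbgpt.core.interface.message import ModelMessage

request = ModelRequest(model="my-model", messages=[{"role": "human", "content": "Who are you?"}])
messages = request.get_messages()
assert isinstance(messages[0], ModelMessage)  # plain dicts are normalized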
def get_single_user_message(self) -> Optional[ModelMessage]:
"""Get the single user message.
@@ -245,20 +281,35 @@ class ModelRequest:
model: str,
messages: List[ModelMessage],
context: Optional[Union[ModelRequestContext, Dict[str, Any], BaseModel]] = None,
-stream: Optional[bool] = False,
-echo: Optional[bool] = False,
+stream: bool = False,
+echo: bool = False,
**kwargs,
):
"""Build a model request.
Args:
model(str): The model name.
messages(List[ModelMessage]): The messages.
context(Optional[Union[ModelRequestContext, Dict[str, Any], BaseModel]]):
The context.
stream(bool): Whether to return a stream of responses. Defaults to False.
echo(bool): Whether to echo the input messages. Defaults to False.
**kwargs: Other arguments.
"""
if not context:
context = ModelRequestContext(stream=stream)
-context_dict = None
-if isinstance(context, dict):
-    context_dict = context
-elif isinstance(context, BaseModel):
-    context_dict = context.dict()
-if context_dict and "stream" not in context_dict:
-    context_dict["stream"] = stream
-    context = ModelRequestContext(**context_dict)
+elif not isinstance(context, ModelRequestContext):
+    context_dict = None
+    if isinstance(context, dict):
+        context_dict = context
+    elif isinstance(context, BaseModel):
+        context_dict = context.dict()
+    if context_dict and "stream" not in context_dict:
+        context_dict["stream"] = stream
+    if context_dict:
+        context = ModelRequestContext(**context_dict)
+    else:
+        context = ModelRequestContext(stream=stream)
return ModelRequest(
model=model,
messages=messages,
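A hedged usage sketch of the branching above, assuming this is the ModelRequest.build_request entry point its docstring names: a plain dict context is accepted, and the stream keyword fills in a missing "stream" key (the model name is a placeholder):

from dbgpt.core.interface.llm import ModelRequest
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType

request = ModelRequest.build_request(
    "my-model",
    messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")],
    context={"user_name": "alice"},  # no "stream" key: filled from the keyword
    stream=True,
)
assert request.stream  # True, via the ModelRequestContext built from the dict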
@@ -292,7 +343,6 @@ class ModelRequest:
ValueError: If the message role is not supported
Examples:
.. code-block:: python
from dbgpt.core.interface.message import (
@@ -337,7 +387,7 @@ class ModelRequest:
class ModelExtraMedata(BaseParameters):
"""A class to represent the extra metadata of a LLM."""
-prompt_roles: Optional[List[str]] = field(
+prompt_roles: List[str] = field(
default_factory=lambda: [
ModelMessageRoleType.SYSTEM,
ModelMessageRoleType.HUMAN,
@@ -356,7 +406,8 @@ class ModelExtraMedata(BaseParameters):
prompt_chat_template: Optional[str] = field(
default=None,
metadata={
"help": "The chat template, see: https://huggingface.co/docs/transformers/main/en/chat_templating"
"help": "The chat template, see: "
"https://huggingface.co/docs/transformers/main/en/chat_templating"
},
)
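For reference, a hedged sketch of setting this field; the template string is an illustrative Jinja snippet in the Hugging Face style, not one shipped with DB-GPT:

from dbgpt.core.interface.llm import ModelExtraMedata

ext = ModelExtraMedata(
    prompt_chat_template="{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}"
)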
@@ -403,19 +454,19 @@ class ModelMetadata(BaseParameters):
def from_dict(
cls, data: dict, ignore_extra_fields: bool = False
) -> "ModelMetadata":
"""Create a new model metadata from a dict."""
if "ext_metadata" in data:
data["ext_metadata"] = ModelExtraMedata(**data["ext_metadata"])
return cls(**data)
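A sketch of the ext_metadata promotion (the model field name is an assumption; it is not visible in this excerpt):

from dbgpt.core.interface.llm import ModelExtraMedata, ModelMetadata

meta = ModelMetadata.from_dict(
    {"model": "my-model", "ext_metadata": {"prompt_chat_template": None}}
)
assert isinstance(meta.ext_metadata, ModelExtraMedata)  # dict promoted to dataclass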
class MessageConverter(ABC):
"""An abstract class for message converter.
r"""An abstract class for message converter.
Different LLMs may have different message formats, this class is used to convert the messages
to the format of the LLM.
Different LLMs may have different message formats, this class is used to convert
the messages to the format of the LLM.
Examples:
>>> from typing import List
>>> from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
>>> from dbgpt.core.interface.llm import MessageConverter, ModelMetadata
@@ -425,7 +476,8 @@ class MessageConverter(ABC):
... messages: List[ModelMessage],
... model_metadata: Optional[ModelMetadata] = None,
... ) -> List[ModelMessage]:
-... # Convert the messages, merge system messages to the last user message.
+... # Convert the messages, merge system messages to the last user
+... # message.
... system_message = None
... other_messages = []
... sep = "\\n"
@@ -478,6 +530,7 @@ class DefaultMessageConverter(MessageConverter):
"""The default message converter."""
def __init__(self, prompt_sep: Optional[str] = None):
"""Create a new default message converter."""
self._prompt_sep = prompt_sep
def convert(
@@ -493,7 +546,8 @@ class DefaultMessageConverter(MessageConverter):
2. Move the last user's message to the end of the list
-3. Convert the messages to no system message if the model does not support system message
+3. Convert the messages to no system message if the model does not support
+system message
Args:
messages(List[ModelMessage]): The messages.
@@ -520,10 +574,11 @@ class DefaultMessageConverter(MessageConverter):
messages: List[ModelMessage],
model_metadata: Optional[ModelMetadata] = None,
) -> List[ModelMessage]:
"""Convert the messages to no system message.
r"""Convert the messages to no system message.
Examples:
->>> # Convert the messages to no system message, just merge system messages to the last user message
+>>> # Convert the messages to no system message, just merge system messages
+>>> # to the last user message
>>> from typing import List
>>> from dbgpt.core.interface.message import (
... ModelMessage,
@@ -550,7 +605,7 @@ class DefaultMessageConverter(MessageConverter):
>>> assert converted_messages == [
... ModelMessage(
... role=ModelMessageRoleType.HUMAN,
... content="You are a helpful assistant\\nWho are you",
... content="You are a helpful assistant\nWho are you",
... ),
... ]
"""
@@ -562,7 +617,8 @@ class DefaultMessageConverter(MessageConverter):
result_messages = []
for message in messages:
if message.role == ModelMessageRoleType.SYSTEM:
-# Not support system message, append system message to the last user message
+# Model does not support system messages; append them to the last user
+# message
system_messages.append(message)
elif message.role in [
ModelMessageRoleType.HUMAN,
@@ -578,7 +634,8 @@ class DefaultMessageConverter(MessageConverter):
system_message_str = system_messages[0].content
if system_message_str and result_messages:
-# Not support system messages, merge system messages to the last user message
+# Model does not support system messages; merge them into the last user
+# message
result_messages[-1].content = (
system_message_str + prompt_sep + result_messages[-1].content
)
@@ -587,10 +644,9 @@ class DefaultMessageConverter(MessageConverter):
def move_last_user_message_to_end(
self, messages: List[ModelMessage]
) -> List[ModelMessage]:
"""Move the last user message to the end of the list.
"""Try to move the last user message to the end of the list.
Examples:
>>> from typing import List
>>> from dbgpt.core.interface.message import (
... ModelMessage,
@@ -660,7 +716,7 @@ class LLMClient(ABC):
@property
def cache(self) -> collections.abc.MutableMapping:
"""The cache object to cache the model metadata.
"""Return the cache object to cache the model metadata.
You can override this property to use your own cache object.
Returns:
@@ -677,7 +733,8 @@ class LLMClient(ABC):
"""Generate a response for a given model request.
Sometimes, different LLMs may have different message formats,
-you can use the message converter to convert the messages to the format of the LLM.
+you can use the message converter to convert the messages to the format of the
+LLM.
Args:
request(ModelRequest): The model request.
@@ -697,7 +754,8 @@ class LLMClient(ABC):
"""Generate a stream of responses for a given model request.
Sometimes, different LLMs may have different message formats,
-you can use the message converter to convert the messages to the format of the LLM.
+you can use the message converter to convert the messages to the format of the
+LLM.
Args:
request(ModelRequest): The model request.
@@ -733,6 +791,7 @@ class LLMClient(ABC):
message_converter: Optional[MessageConverter] = None,
) -> ModelRequest:
"""Covert the message.
If no message converter is provided, the original request will be returned.
Args:
@@ -746,14 +805,15 @@ class LLMClient(ABC):
return request
new_request = request.copy()
model_metadata = await self.get_model_metadata(request.model)
-new_messages = message_converter.convert(request.messages, model_metadata)
+new_messages = message_converter.convert(request.get_messages(), model_metadata)
new_request.messages = new_messages
return new_request
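A sketch of the conversion path after this fix: get_messages() guarantees the converter receives List[ModelMessage] even when the request was built from plain dicts. The converter below is a toy, and llm_client stands for any concrete LLMClient whose generate accepts a converter, as the docstring above describes:

from typing import List, Optional

from dbgpt.core.interface.llm import MessageConverter, ModelMetadata
from dbgpt.core.interface.message import ModelMessage


class UpperCaseConverter(MessageConverter):
    """Toy converter: upper-case the content of every message."""

    def convert(
        self,
        messages: List[ModelMessage],
        model_metadata: Optional[ModelMetadata] = None,
    ) -> List[ModelMessage]:
        return [ModelMessage(role=m.role, content=m.content.upper()) for m in messages]


# output = await llm_client.generate(request, UpperCaseConverter())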
async def cached_models(self) -> List[ModelMetadata]:
"""Get all the models from the cache or the llm server.
-If the model metadata is not in the cache, it will be fetched from the llm server.
+If the model metadata is not in the cache, it will be fetched from the
+llm server.
Returns:
List[ModelMetadata]: A list of model metadata.
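As the cache property docstring invites, a subclass can supply its own MutableMapping. A minimal sketch; the abstract methods of LLMClient are elided, so this is illustrative rather than runnable on its own:

import collections.abc

from dbgpt.core.interface.llm import LLMClient


class MyLLMClient(LLMClient):
    """Sketch: back the model metadata cache with a plain dict."""

    def __init__(self):
        self._model_cache: collections.abc.MutableMapping = {}

    @property
    def cache(self) -> collections.abc.MutableMapping:
        # Any MutableMapping works here, e.g. an LRU or TTL cache.
        return self._model_cache

    # ... implement the abstract methods (generate, generate_stream, ...) here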