DB-GPT/dbgpt/core/interface/llm.py
"""The interface for LLM."""
import collections
import copy
import logging
import time
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from cachetools import TTLCache
from dbgpt._private.pydantic import BaseModel, model_to_dict
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.util import BaseParameters
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.model_utils import GPUInfo
logger = logging.getLogger(__name__)
@dataclass
@PublicAPI(stability="beta")
class ModelInferenceMetrics:
"""A class to represent metrics for assessing the inference performance of a LLM."""
collect_index: Optional[int] = 0
start_time_ms: Optional[int] = None
"""The timestamp (in milliseconds) when the model inference starts."""
end_time_ms: Optional[int] = None
"""The timestamp (in milliseconds) when the model inference ends."""
current_time_ms: Optional[int] = None
"""The current timestamp (in milliseconds) when the model inference return
partially output(stream)."""
first_token_time_ms: Optional[int] = None
"""The timestamp (in milliseconds) when the first token is generated."""
first_completion_time_ms: Optional[int] = None
"""The timestamp (in milliseconds) when the first completion is generated."""
first_completion_tokens: Optional[int] = None
"""The number of tokens when the first completion is generated."""
prompt_tokens: Optional[int] = None
"""The number of tokens in the input prompt."""
completion_tokens: Optional[int] = None
"""The number of tokens in the generated completion."""
total_tokens: Optional[int] = None
"""The total number of tokens (prompt plus completion)."""
speed_per_second: Optional[float] = None
"""The average number of tokens generated per second."""
current_gpu_infos: Optional[List[GPUInfo]] = None
"""Current gpu information, all devices"""
avg_gpu_infos: Optional[List[GPUInfo]] = None
"""Average memory usage across all collection points"""
@staticmethod
def create_metrics(
last_metrics: Optional["ModelInferenceMetrics"] = None,
) -> "ModelInferenceMetrics":
"""Create metrics for model inference.
Args:
last_metrics(ModelInferenceMetrics): The last metrics.
Returns:
ModelInferenceMetrics: The metrics for model inference.
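Examples:
    A minimal sketch of how metrics can be carried across a streaming
    loop; ``stream_chunks`` is a hypothetical iterator of output chunks:
    .. code-block:: python

        # First collection point: starts the clock
        metrics = ModelInferenceMetrics.create_metrics()
        for chunk in stream_chunks:
            # Later collection points inherit counters from the last metrics
            metrics = ModelInferenceMetrics.create_metrics(last_metrics=metrics)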
"""
start_time_ms = last_metrics.start_time_ms if last_metrics else None
first_token_time_ms = last_metrics.first_token_time_ms if last_metrics else None
first_completion_time_ms = (
last_metrics.first_completion_time_ms if last_metrics else None
)
first_completion_tokens = (
last_metrics.first_completion_tokens if last_metrics else None
)
prompt_tokens = last_metrics.prompt_tokens if last_metrics else None
completion_tokens = last_metrics.completion_tokens if last_metrics else None
total_tokens = last_metrics.total_tokens if last_metrics else None
speed_per_second = last_metrics.speed_per_second if last_metrics else None
current_gpu_infos = last_metrics.current_gpu_infos if last_metrics else None
avg_gpu_infos = last_metrics.avg_gpu_infos if last_metrics else None
if not start_time_ms:
start_time_ms = time.time_ns() // 1_000_000
current_time_ms = time.time_ns() // 1_000_000
end_time_ms = current_time_ms
return ModelInferenceMetrics(
start_time_ms=start_time_ms,
end_time_ms=end_time_ms,
current_time_ms=current_time_ms,
first_token_time_ms=first_token_time_ms,
first_completion_time_ms=first_completion_time_ms,
first_completion_tokens=first_completion_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
speed_per_second=speed_per_second,
current_gpu_infos=current_gpu_infos,
avg_gpu_infos=avg_gpu_infos,
)
def to_dict(self) -> Dict:
"""Convert the model inference metrics to dict."""
return asdict(self)
@dataclass
@PublicAPI(stability="beta")
class ModelRequestContext:
"""A class to represent the context of a LLM model request."""
stream: bool = False
"""Whether to return a stream of responses."""
cache_enable: bool = False
"""Whether to enable the cache for the model inference"""
user_name: Optional[str] = None
"""The user name of the model request."""
sys_code: Optional[str] = None
"""The system code of the model request."""
conv_uid: Optional[str] = None
"""The conversation id of the model inference."""
span_id: Optional[str] = None
"""The span id of the model inference."""
chat_mode: Optional[str] = None
"""The chat mode of the model inference."""
chat_param: Optional[str] = None
"""The chat param of chat mode"""
extra: Optional[Dict[str, Any]] = field(default_factory=dict)
"""The extra information of the model inference."""
request_id: Optional[str] = None
"""The request id of the model inference."""
@dataclass
@PublicAPI(stability="beta")
class ModelOutput:
"""A class to represent the output of a LLM.""" ""
text: str
"""The generated text."""
error_code: int
"""The error code of the model inference. If the model inference is successful,
the error code is 0."""
incremental: bool = False
model_context: Optional[Dict] = None
finish_reason: Optional[str] = None
usage: Optional[Dict[str, Any]] = None
metrics: Optional[ModelInferenceMetrics] = None
"""Some metrics for model inference"""
def to_dict(self) -> Dict:
"""Convert the model output to dict."""
return asdict(self)
@property
def success(self) -> bool:
"""Check if the model inference is successful."""
return self.error_code == 0
_ModelMessageType = Union[List[ModelMessage], List[Dict[str, Any]]]
@dataclass
@PublicAPI(stability="beta")
class ModelRequest:
"""The model request."""
model: str
"""The name of the model."""
messages: _ModelMessageType
"""The input messages."""
temperature: Optional[float] = None
"""The temperature of the model inference."""
top_p: Optional[float] = None
"""The top p of the model inference."""
max_new_tokens: Optional[int] = None
"""The maximum number of tokens to generate."""
stop: Optional[str] = None
"""The stop condition of the model inference."""
stop_token_ids: Optional[List[int]] = None
"""The stop token ids of the model inference."""
context_len: Optional[int] = None
"""The context length of the model inference."""
echo: Optional[bool] = False
"""Whether to echo the input messages."""
span_id: Optional[str] = None
"""The span id of the model inference."""
context: Optional[ModelRequestContext] = field(
default_factory=lambda: ModelRequestContext()
)
"""The context of the model inference."""
@property
def stream(self) -> bool:
"""Whether to return a stream of responses."""
return bool(self.context and self.context.stream)
def copy(self) -> "ModelRequest":
"""Copy the model request.
Returns:
ModelRequest: The copied model request.
"""
new_request = copy.deepcopy(self)
# Transform messages to List[ModelMessage]
new_request.messages = new_request.get_messages()
return new_request
def to_dict(self) -> Dict[str, Any]:
"""Convert the model request to dict.
Returns:
Dict[str, Any]: The model request in dict.
"""
new_request = copy.deepcopy(self)
new_messages = []
for message in new_request.messages:
if isinstance(message, dict):
new_messages.append(message)
else:
new_messages.append(message.dict())
new_request.messages = new_messages
# Skip None fields
return {k: v for k, v in asdict(new_request).items() if v is not None}
def to_trace_metadata(self) -> Dict[str, Any]:
"""Convert the model request to trace metadata.
Returns:
Dict[str, Any]: The trace metadata.
"""
metadata = self.to_dict()
metadata["prompt"] = self.messages_to_string()
return metadata
def get_messages(self) -> List[ModelMessage]:
"""Get the messages.
If the messages are not a list of ModelMessage, they will be converted to a
list of ModelMessage.
Returns:
List[ModelMessage]: The messages.
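Examples:
    A small sketch showing that dict messages are normalized to
    ``ModelMessage`` objects (the model name is illustrative):
    .. code-block:: python

        request = ModelRequest(
            model="test-model",
            messages=[{"role": "human", "content": "Hello"}],
        )
        messages = request.get_messages()
        assert isinstance(messages[0], ModelMessage)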
"""
messages = []
for message in self.messages:
if isinstance(message, dict):
messages.append(ModelMessage(**message))
else:
messages.append(message)
return messages
def get_single_user_message(self) -> Optional[ModelMessage]:
"""Get the single user message.
Returns:
Optional[ModelMessage]: The single user message.
"""
messages = self.get_messages()
if len(messages) != 1 or messages[0].role != ModelMessageRoleType.HUMAN:
raise ValueError("The request does not contain a single user message")
return messages[0]
@staticmethod
def build_request(
model: str,
messages: List[ModelMessage],
context: Optional[Union[ModelRequestContext, Dict[str, Any], BaseModel]] = None,
stream: bool = False,
echo: bool = False,
**kwargs,
):
"""Build a model request.
Args:
model(str): The model name.
messages(List[ModelMessage]): The messages.
context(Optional[Union[ModelRequestContext, Dict[str, Any], BaseModel]]):
The context.
stream(bool): Whether to return a stream of responses. Defaults to False.
echo(bool): Whether to echo the input messages. Defaults to False.
**kwargs: Other arguments.
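Examples:
    A minimal sketch of building a streaming request (the model name and
    content are illustrative):
    .. code-block:: python

        messages = [
            ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")
        ]
        request = ModelRequest.build_request(
            "test-model", messages, stream=True, temperature=0.7
        )
        assert request.stream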
"""
if not context:
context = ModelRequestContext(stream=stream)
elif not isinstance(context, ModelRequestContext):
context_dict = None
if isinstance(context, dict):
context_dict = context
elif isinstance(context, BaseModel):
context_dict = model_to_dict(context)
if context_dict and "stream" not in context_dict:
context_dict["stream"] = stream
if context_dict:
context = ModelRequestContext(**context_dict)
else:
context = ModelRequestContext(stream=stream)
return ModelRequest(
model=model,
messages=messages,
context=context,
echo=echo,
**kwargs,
)
@staticmethod
def _build(model: str, prompt: str, **kwargs):
return ModelRequest(
model=model,
messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content=prompt)],
**kwargs,
)
def to_common_messages(
self, support_system_role: bool = True
) -> List[Dict[str, Any]]:
"""Convert the messages to the common format(like OpenAI API).
This function will move last user message to the end of the list.
Args:
support_system_role (bool): Whether to support system role
Returns:
List[Dict[str, Any]]: The messages in the format of OpenAI API.
Raises:
ValueError: If the message role is not supported
Examples:
.. code-block:: python
from dbgpt.core.interface.message import (
ModelMessage,
ModelMessageRoleType,
)
messages = [
ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
ModelMessage(
role=ModelMessageRoleType.AI, content="Hi, I'm a robot."
),
ModelMessage(
role=ModelMessageRoleType.HUMAN, content="Who are your"
),
]
request = ModelRequest(model="test", messages=messages)
openai_messages = request.to_common_messages()
assert openai_messages == [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hi, I'm a robot."},
{"role": "user", "content": "Who are your"},
]
"""
messages = [
m if isinstance(m, ModelMessage) else ModelMessage(**m)
for m in self.messages
]
return ModelMessage.to_common_messages(
messages, support_system_role=support_system_role
)
def messages_to_string(self) -> str:
"""Convert the messages to string.
Returns:
str: The messages in string format.
"""
return ModelMessage.messages_to_string(self.get_messages())
@dataclass
class ModelExtraMedata(BaseParameters):
"""A class to represent the extra metadata of a LLM."""
prompt_roles: List[str] = field(
default_factory=lambda: [
ModelMessageRoleType.SYSTEM,
ModelMessageRoleType.HUMAN,
ModelMessageRoleType.AI,
],
metadata={"help": "The roles of the prompt"},
)
prompt_sep: Optional[str] = field(
default="\n",
metadata={"help": "The separator of the prompt between multiple rounds"},
)
# You can see the chat template in your model repo tokenizer config,
# typically in the tokenizer_config.json
prompt_chat_template: Optional[str] = field(
default=None,
metadata={
"help": "The chat template, see: "
"https://huggingface.co/docs/transformers/main/en/chat_templating"
},
)
@property
def support_system_message(self) -> bool:
"""Whether the model supports system message.
Returns:
bool: Whether the model supports system message.
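Examples:
    A small sketch (role strings follow ``ModelMessageRoleType`` values):
    .. code-block:: python

        # A hypothetical model whose prompt only accepts human/AI roles
        ext_metadata = ModelExtraMedata(prompt_roles=["human", "ai"])
        assert not ext_metadata.support_system_message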
"""
return ModelMessageRoleType.SYSTEM in self.prompt_roles
@dataclass
@PublicAPI(stability="beta")
class ModelMetadata(BaseParameters):
"""A class to represent a LLM model."""
model: str = field(
metadata={"help": "Model name"},
)
context_length: Optional[int] = field(
default=4096,
metadata={"help": "Context length of model"},
)
chat_model: Optional[bool] = field(
default=True,
metadata={"help": "Whether the model is a chat model"},
)
is_function_calling_model: Optional[bool] = field(
default=False,
metadata={"help": "Whether the model is a function calling model"},
)
metadata: Optional[Dict[str, Any]] = field(
default_factory=dict,
metadata={"help": "Model metadata"},
)
ext_metadata: Optional[ModelExtraMedata] = field(
default_factory=ModelExtraMedata,
metadata={"help": "Model extra metadata"},
)
@classmethod
def from_dict(
cls, data: dict, ignore_extra_fields: bool = False
) -> "ModelMetadata":
"""Create a new model metadata from a dict."""
if "ext_metadata" in data:
data["ext_metadata"] = ModelExtraMedata(**data["ext_metadata"])
return cls(**data)
class MessageConverter(ABC):
r"""An abstract class for message converter.
Different LLMs may have different message formats; this class is used to convert
the messages to the format expected by the LLM.
Examples:
>>> from typing import List
>>> from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
>>> from dbgpt.core.interface.llm import MessageConverter, ModelMetadata
>>> class RemoveSystemMessageConverter(MessageConverter):
... def convert(
... self,
... messages: List[ModelMessage],
... model_metadata: Optional[ModelMetadata] = None,
... ) -> List[ModelMessage]:
... # Convert the messages, merge system messages to the last user
... # message.
... system_message = None
... other_messages = []
... sep = "\\n"
... for message in messages:
... if message.role == ModelMessageRoleType.SYSTEM:
... system_message = message
... else:
... other_messages.append(message)
... if system_message and other_messages:
... other_messages[-1].content = (
... system_message.content + sep + other_messages[-1].content
... )
... return other_messages
...
>>> messages = [
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are you"),
... ]
>>> converter = RemoveSystemMessageConverter()
>>> converted_messages = converter.convert(messages, None)
>>> assert converted_messages == [
... ModelMessage(
... role=ModelMessageRoleType.HUMAN,
... content="You are a helpful assistant\\nWho are you",
... ),
... ]
"""
@abstractmethod
def convert(
self,
messages: List[ModelMessage],
model_metadata: Optional[ModelMetadata] = None,
) -> List[ModelMessage]:
"""Convert the messages.
Args:
messages(List[ModelMessage]): The messages.
model_metadata(ModelMetadata): The model metadata.
Returns:
List[ModelMessage]: The converted messages.
"""
class DefaultMessageConverter(MessageConverter):
"""The default message converter."""
def __init__(self, prompt_sep: Optional[str] = None):
"""Create a new default message converter."""
self._prompt_sep = prompt_sep
def convert(
self,
messages: List[ModelMessage],
model_metadata: Optional[ModelMetadata] = None,
) -> List[ModelMessage]:
"""Convert the messages.
There are three steps to convert the messages:
1. Just keep system, human and AI messages
2. Move the last user's message to the end of the list
3. Convert the messages to no system message if the model does not support
system message
Args:
messages(List[ModelMessage]): The messages.
model_metadata(ModelMetadata): The model metadata.
Returns:
List[ModelMessage]: The converted messages.
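Examples:
    A minimal sketch (the model name is illustrative; whether system
    messages are merged depends on the model's extra metadata):
    .. code-block:: python

        converter = DefaultMessageConverter()
        messages = [
            ModelMessage(role=ModelMessageRoleType.SYSTEM, content="Be brief"),
            ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
        ]
        converted = converter.convert(messages, ModelMetadata(model="test"))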
"""
# 1. Just keep system, human and AI messages
messages = list(filter(lambda m: m.pass_to_model, messages))
# 2. Move the last user's message to the end of the list
messages = self.move_last_user_message_to_end(messages)
if not model_metadata or not model_metadata.ext_metadata:
logger.warning("No model metadata, skip message system message conversion")
return messages
if not model_metadata.ext_metadata.support_system_message:
# 3. Convert the messages to no system message
return self.convert_to_no_system_message(messages, model_metadata)
return messages
def convert_to_no_system_message(
self,
messages: List[ModelMessage],
model_metadata: Optional[ModelMetadata] = None,
) -> List[ModelMessage]:
r"""Convert the messages to no system message.
Examples:
>>> # Convert the messages to no system message, just merge system messages
>>> # to the last user message
>>> from typing import List
>>> from dbgpt.core.interface.message import (
... ModelMessage,
... ModelMessageRoleType,
... )
>>> from dbgpt.core.interface.llm import (
... DefaultMessageConverter,
... ModelMetadata,
... )
>>> messages = [
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="Who are you"
... ),
... ]
>>> converter = DefaultMessageConverter()
>>> model_metadata = ModelMetadata(model="test")
>>> converted_messages = converter.convert_to_no_system_message(
... messages, model_metadata
... )
>>> assert converted_messages == [
... ModelMessage(
... role=ModelMessageRoleType.HUMAN,
... content="You are a helpful assistant\nWho are you",
... ),
... ]
"""
if not model_metadata or not model_metadata.ext_metadata:
logger.warning("No model metadata, skip message conversion")
return messages
ext_metadata = model_metadata.ext_metadata
system_messages = []
result_messages = []
for message in messages:
if message.role == ModelMessageRoleType.SYSTEM:
# System messages are not supported directly; collect them to merge into
# the last user message
system_messages.append(message)
elif message.role in [
ModelMessageRoleType.HUMAN,
ModelMessageRoleType.AI,
]:
result_messages.append(message)
prompt_sep = self._prompt_sep or ext_metadata.prompt_sep or "\n"
system_message_str = None
if len(system_messages) > 1:
logger.warning("Your system messages have more than one message")
system_message_str = prompt_sep.join([m.content for m in system_messages])
elif len(system_messages) == 1:
system_message_str = system_messages[0].content
if system_message_str and result_messages:
# System messages are not supported; merge them into the last user message
result_messages[-1].content = (
system_message_str + prompt_sep + result_messages[-1].content
)
return result_messages
def move_last_user_message_to_end(
self, messages: List[ModelMessage]
) -> List[ModelMessage]:
"""Try to move the last user message to the end of the list.
Examples:
>>> from typing import List
>>> from dbgpt.core.interface.message import (
... ModelMessage,
... ModelMessageRoleType,
... )
>>> from dbgpt.core.interface.llm import DefaultMessageConverter
>>> messages = [
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="Who are you"
... ),
... ModelMessage(role=ModelMessageRoleType.AI, content="I'm a robot"),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="What's your name"
... ),
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ]
>>> converter = DefaultMessageConverter()
>>> converted_messages = converter.move_last_user_message_to_end(messages)
>>> assert converted_messages == [
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="Who are you"
... ),
... ModelMessage(role=ModelMessageRoleType.AI, content="I'm a robot"),
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="What's your name"
... ),
... ]
Args:
messages(List[ModelMessage]): The messages.
Returns:
List[ModelMessage]: The converted messages.
"""
last_user_input_index = None
for i in range(len(messages) - 1, -1, -1):
if messages[i].role == ModelMessageRoleType.HUMAN:
last_user_input_index = i
break
if last_user_input_index is not None:
last_user_input = messages.pop(last_user_input_index)
messages.append(last_user_input)
return messages
@PublicAPI(stability="beta")
class LLMClient(ABC):
"""An abstract class for LLM client."""
# Cache the model metadata for 60 seconds
_MODEL_CACHE_ = TTLCache(maxsize=100, ttl=60)
@property
def cache(self) -> collections.abc.MutableMapping:
"""Return the cache object to cache the model metadata.
You can override this property to use your own cache object.
Returns:
collections.abc.MutableMapping: The cache object.
"""
return self._MODEL_CACHE_
@abstractmethod
async def generate(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> ModelOutput:
"""Generate a response for a given model request.
Different LLMs may have different message formats; you can use the message
converter to convert the messages to the format expected by the LLM.
Args:
request(ModelRequest): The model request.
message_converter(MessageConverter): The message converter.
Returns:
ModelOutput: The model output.
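Examples:
    A minimal sketch with a hypothetical concrete ``client``; the model
    name and prompt are illustrative:
    .. code-block:: python

        request = ModelRequest.build_request(
            "test-model",
            [ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")],
        )
        output = await client.generate(request)
        if output.success:
            print(output.text)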
"""
@abstractmethod
async def generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> AsyncIterator[ModelOutput]:
"""Generate a stream of responses for a given model request.
Different LLMs may have different message formats; you can use the message
converter to convert the messages to the format expected by the LLM.
Args:
request(ModelRequest): The model request.
message_converter(MessageConverter): The message converter.
Returns:
AsyncIterator[ModelOutput]: The model output stream.
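Examples:
    A minimal sketch of consuming the stream (``client`` is a
    hypothetical concrete implementation):
    .. code-block:: python

        request = ModelRequest.build_request(
            "test-model",
            [ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")],
            stream=True,
        )
        async for output in client.generate_stream(request):
            print(output.text)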
"""
@abstractmethod
async def models(self) -> List[ModelMetadata]:
"""Get all the models.
Returns:
List[ModelMetadata]: A list of model metadata.
"""
@abstractmethod
async def count_token(self, model: str, prompt: str) -> int:
"""Count the number of tokens in a given prompt.
Args:
model(str): The model name.
prompt(str): The prompt.
Returns:
int: The number of tokens.
"""
async def covert_message(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> ModelRequest:
"""Covert the message.
If no message converter is provided, the original request will be returned.
Args:
request(ModelRequest): The model request.
message_converter(MessageConverter): The message converter.
Returns:
ModelRequest: The converted model request.
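Examples:
    A minimal sketch using the default converter (``client`` is a
    hypothetical concrete implementation, ``request`` an existing
    ModelRequest):
    .. code-block:: python

        converter = DefaultMessageConverter()
        new_request = await client.covert_message(request, converter)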
"""
if not message_converter:
return request
new_request = request.copy()
model_metadata = await self.get_model_metadata(request.model)
new_messages = message_converter.convert(request.get_messages(), model_metadata)
new_request.messages = new_messages
return new_request
async def cached_models(self) -> List[ModelMetadata]:
"""Get all the models from the cache or the llm server.
If the model metadata is not in the cache, it will be fetched from the
llm server.
Returns:
List[ModelMetadata]: A list of model metadata.
"""
key = "____$llm_client_models$____"
if key not in self.cache:
models = await self.models()
self.cache[key] = models
for model in models:
model_metadata_key = (
f"____$llm_client_models_metadata_{model.model}$____"
)
self.cache[model_metadata_key] = model
return self.cache[key]
async def get_model_metadata(self, model: str) -> ModelMetadata:
"""Get the model metadata.
Args:
model(str): The model name.
Returns:
ModelMetadata: The model metadata.
Raises:
ValueError: If the model is not found.
"""
model_metadata_key = f"____$llm_client_models_metadata_{model}$____"
if model_metadata_key not in self.cache:
await self.cached_models()
model_metadata = self.cache.get(model_metadata_key)
if not model_metadata:
raise ValueError(f"Model {model} not found")
return model_metadata
def __call__(self, *args, **kwargs) -> ModelOutput:
"""Return the model output.
Call the LLM client to generate the response for the given message.
Please do not use this method in a production environment; it is intended
only for debugging.
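Examples:
    A minimal sketch (``client`` is a hypothetical concrete
    implementation with a default model configured):
    .. code-block:: python

        output = client("Hello, who are you?")
        print(output.text)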
"""
from dbgpt.util import get_or_create_event_loop
messages = kwargs.get("messages")
model = kwargs.get("model")
if messages:
del kwargs["messages"]
model_messages = ModelMessage.from_openai_messages(messages)
else:
model_messages = [ModelMessage.build_human_message(args[0])]
if not model:
if hasattr(self, "default_model"):
model = getattr(self, "default_model")
else:
raise ValueError("The default model is not set")
if "model" in kwargs:
del kwargs["model"]
req = ModelRequest.build_request(model, model_messages, **kwargs)
loop = get_or_create_event_loop()
return loop.run_until_complete(self.generate(req))