feat(awel): New MessageConverter and more AWEL operators (#1039)
@@ -1,14 +1,21 @@
import collections
import copy
import logging
import time
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from typing import Any, AsyncIterator, Dict, List, Optional, Union

from cachetools import TTLCache

from dbgpt._private.pydantic import BaseModel
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.util import BaseParameters
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.model_utils import GPUInfo

logger = logging.getLogger(__name__)


@dataclass
@PublicAPI(stability="beta")
@@ -223,6 +230,29 @@ class ModelRequest:
            raise ValueError("The messages are not a single user message")
        return messages[0]

    @staticmethod
    def build_request(
        model: str,
        messages: List[ModelMessage],
        context: Union[ModelRequestContext, Dict[str, Any], BaseModel],
        stream: Optional[bool] = False,
        **kwargs,
    ):
        context_dict = None
        if isinstance(context, dict):
            context_dict = context
        elif isinstance(context, BaseModel):
            context_dict = context.dict()
        if context_dict and "stream" not in context_dict:
            context_dict["stream"] = stream
            context = ModelRequestContext(**context_dict)
        return ModelRequest(
            model=model,
            messages=messages,
            context=context,
            **kwargs,
        )
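
    # A minimal usage sketch for `build_request`; the model name and the
    # "user_name" context field are hypothetical, only "stream" is confirmed
    # by the code above:
    #
    #   >>> from dbgpt.core.interface.llm import ModelRequest
    #   >>> from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
    #   >>> request = ModelRequest.build_request(
    ...     model="vicuna-13b-v1.5",  # hypothetical model name
    #   ...     messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi")],
    #   ...     context={"user_name": "tom"},  # hypothetical context field
    #   ...     stream=True,
    #   ... )
    #   >>> # The dict context is upgraded to a ModelRequestContext with stream=True
    #   >>> assert request.context.stream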

    @staticmethod
    def _build(model: str, prompt: str, **kwargs):
        return ModelRequest(
@@ -271,6 +301,43 @@ class ModelRequest:
        return ModelMessage.to_openai_messages(messages)


@dataclass
class ModelExtraMedata(BaseParameters):
    """A class to represent the extra metadata of an LLM."""

    prompt_roles: Optional[List[str]] = field(
        default_factory=lambda: [
            ModelMessageRoleType.SYSTEM,
            ModelMessageRoleType.HUMAN,
            ModelMessageRoleType.AI,
        ],
        metadata={"help": "The roles of the prompt"},
    )

    prompt_sep: Optional[str] = field(
        default="\n",
        metadata={"help": "The separator of the prompt between multiple rounds"},
    )

    # You can see the chat template in your model repo's tokenizer config,
    # typically in tokenizer_config.json
    prompt_chat_template: Optional[str] = field(
        default=None,
        metadata={
            "help": "The chat template, see: https://huggingface.co/docs/transformers/main/en/chat_templating"
        },
    )

    @property
    def support_system_message(self) -> bool:
        """Whether the model supports system messages.

        Returns:
            bool: Whether the model supports system messages.
        """
        return ModelMessageRoleType.SYSTEM in self.prompt_roles
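
    # Sketch: a model whose prompt template defines no system role reports
    # that system messages are unsupported (roles and fields as defined above):
    #
    #   >>> ext = ModelExtraMedata(
    #   ...     prompt_roles=[ModelMessageRoleType.HUMAN, ModelMessageRoleType.AI]
    #   ... )
    #   >>> assert not ext.support_system_message
    #   >>> ext.prompt_sep
    #   '\n'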


@dataclass
@PublicAPI(stability="beta")
class ModelMetadata(BaseParameters):
@@ -295,18 +362,294 @@ class ModelMetadata(BaseParameters):
        default_factory=dict,
        metadata={"help": "Model metadata"},
    )
    ext_metadata: Optional[ModelExtraMedata] = field(
        default_factory=ModelExtraMedata,
        metadata={"help": "Model extra metadata"},
    )

    @classmethod
    def from_dict(
        cls, data: dict, ignore_extra_fields: bool = False
    ) -> "ModelMetadata":
        if "ext_metadata" in data:
            data["ext_metadata"] = ModelExtraMedata(**data["ext_metadata"])
        return cls(**data)
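
    # Sketch of `from_dict` turning a nested ext_metadata dict into a
    # ModelExtraMedata instance (the model name is illustrative):
    #
    #   >>> metadata = ModelMetadata.from_dict(
    #   ...     {"model": "vicuna-13b-v1.5", "ext_metadata": {"prompt_sep": "###"}}
    #   ... )
    #   >>> assert metadata.ext_metadata.prompt_sep == "###"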


class MessageConverter(ABC):
    """An abstract class for message converters.

    Different LLMs may have different message formats, so this class is used to
    convert messages to the format of the target LLM.

    Examples:

        >>> from typing import List
        >>> from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
        >>> from dbgpt.core.interface.llm import MessageConverter, ModelMetadata
        >>> class RemoveSystemMessageConverter(MessageConverter):
        ...     def convert(
        ...         self,
        ...         messages: List[ModelMessage],
        ...         model_metadata: Optional[ModelMetadata] = None,
        ...     ) -> List[ModelMessage]:
        ...         # Convert the messages, merging system messages into the last user message.
        ...         system_message = None
        ...         other_messages = []
        ...         sep = "\\n"
        ...         for message in messages:
        ...             if message.role == ModelMessageRoleType.SYSTEM:
        ...                 system_message = message
        ...             else:
        ...                 other_messages.append(message)
        ...         if system_message and other_messages:
        ...             other_messages[-1].content = (
        ...                 system_message.content + sep + other_messages[-1].content
        ...             )
        ...         return other_messages
        ...
        >>> messages = [
        ...     ModelMessage(
        ...         role=ModelMessageRoleType.SYSTEM,
        ...         content="You are a helpful assistant",
        ...     ),
        ...     ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are you"),
        ... ]
        >>> converter = RemoveSystemMessageConverter()
        >>> converted_messages = converter.convert(messages, None)
        >>> assert converted_messages == [
        ...     ModelMessage(
        ...         role=ModelMessageRoleType.HUMAN,
        ...         content="You are a helpful assistant\\nWho are you",
        ...     ),
        ... ]
    """

    @abstractmethod
    def convert(
        self,
        messages: List[ModelMessage],
        model_metadata: Optional[ModelMetadata] = None,
    ) -> List[ModelMessage]:
        """Convert the messages.

        Args:
            messages(List[ModelMessage]): The messages.
            model_metadata(ModelMetadata): The model metadata.

        Returns:
            List[ModelMessage]: The converted messages.
        """


class DefaultMessageConverter(MessageConverter):
    """The default message converter."""

    def __init__(self, prompt_sep: Optional[str] = None):
        self._prompt_sep = prompt_sep

    def convert(
        self,
        messages: List[ModelMessage],
        model_metadata: Optional[ModelMetadata] = None,
    ) -> List[ModelMessage]:
        """Convert the messages.

        There are three steps to convert the messages:

        1. Just keep the system, human and AI messages

        2. Move the last user's message to the end of the list

        3. Convert the messages to have no system message if the model does not
           support system messages

        Args:
            messages(List[ModelMessage]): The messages.
            model_metadata(ModelMetadata): The model metadata.

        Returns:
            List[ModelMessage]: The converted messages.
        """
        # 1. Just keep the system, human and AI messages
        messages = list(filter(lambda m: m.pass_to_model, messages))
        # 2. Move the last user's message to the end of the list
        messages = self.move_last_user_message_to_end(messages)

        if not model_metadata or not model_metadata.ext_metadata:
            logger.warning("No model metadata, skip system message conversion")
            return messages
        if not model_metadata.ext_metadata.support_system_message:
            # 3. The model does not support system messages, merge them into
            # the last user message
            return self.convert_to_no_system_message(messages, model_metadata)
        return messages
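
    # End-to-end sketch of `convert` for a model without a system role (the
    # model name is illustrative, and the messages are assumed to pass the
    # `pass_to_model` filter):
    #
    #   >>> ext = ModelExtraMedata(
    #   ...     prompt_roles=[ModelMessageRoleType.HUMAN, ModelMessageRoleType.AI]
    #   ... )
    #   >>> metadata = ModelMetadata(model="no-system-model", ext_metadata=ext)
    #   >>> converter = DefaultMessageConverter()
    #   >>> converted = converter.convert(
    #   ...     [
    #   ...         ModelMessage(role=ModelMessageRoleType.SYSTEM, content="Be brief"),
    #   ...         ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
    #   ...     ],
    #   ...     metadata,
    #   ... )
    #   >>> assert converted[-1].content == "Be brief\nHi"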

    def convert_to_no_system_message(
        self,
        messages: List[ModelMessage],
        model_metadata: Optional[ModelMetadata] = None,
    ) -> List[ModelMessage]:
        """Convert the messages to have no system message.

        Examples:
            >>> # Merge the system messages into the last user message.
            >>> from typing import List
            >>> from dbgpt.core.interface.message import (
            ...     ModelMessage,
            ...     ModelMessageRoleType,
            ... )
            >>> from dbgpt.core.interface.llm import (
            ...     DefaultMessageConverter,
            ...     ModelMetadata,
            ... )
            >>> messages = [
            ...     ModelMessage(
            ...         role=ModelMessageRoleType.SYSTEM,
            ...         content="You are a helpful assistant",
            ...     ),
            ...     ModelMessage(
            ...         role=ModelMessageRoleType.HUMAN, content="Who are you"
            ...     ),
            ... ]
            >>> converter = DefaultMessageConverter()
            >>> model_metadata = ModelMetadata(model="test")
            >>> converted_messages = converter.convert_to_no_system_message(
            ...     messages, model_metadata
            ... )
            >>> assert converted_messages == [
            ...     ModelMessage(
            ...         role=ModelMessageRoleType.HUMAN,
            ...         content="You are a helpful assistant\\nWho are you",
            ...     ),
            ... ]
        """
        if not model_metadata or not model_metadata.ext_metadata:
            logger.warning("No model metadata, skip message conversion")
            return messages
        ext_metadata = model_metadata.ext_metadata
        system_messages = []
        result_messages = []
        for message in messages:
            if message.role == ModelMessageRoleType.SYSTEM:
                # Collect the system messages; they will be merged into the
                # last user message below
                system_messages.append(message)
            elif message.role in [
                ModelMessageRoleType.HUMAN,
                ModelMessageRoleType.AI,
            ]:
                result_messages.append(message)
        prompt_sep = self._prompt_sep or ext_metadata.prompt_sep or "\n"
        system_message_str = None
        if len(system_messages) > 1:
            logger.warning("More than one system message found, joining them")
            system_message_str = prompt_sep.join([m.content for m in system_messages])
        elif len(system_messages) == 1:
            system_message_str = system_messages[0].content

        if system_message_str and result_messages:
            # The model does not support system messages, so merge them into
            # the last user message
            result_messages[-1].content = (
                system_message_str + prompt_sep + result_messages[-1].content
            )
        return result_messages

    def move_last_user_message_to_end(
        self, messages: List[ModelMessage]
    ) -> List[ModelMessage]:
        """Move the last user message to the end of the list.

        Examples:

            >>> from typing import List
            >>> from dbgpt.core.interface.message import (
            ...     ModelMessage,
            ...     ModelMessageRoleType,
            ... )
            >>> from dbgpt.core.interface.llm import DefaultMessageConverter
            >>> messages = [
            ...     ModelMessage(
            ...         role=ModelMessageRoleType.SYSTEM,
            ...         content="You are a helpful assistant",
            ...     ),
            ...     ModelMessage(
            ...         role=ModelMessageRoleType.HUMAN, content="Who are you"
            ...     ),
            ...     ModelMessage(role=ModelMessageRoleType.AI, content="I'm a robot"),
            ...     ModelMessage(
            ...         role=ModelMessageRoleType.HUMAN, content="What's your name"
            ...     ),
            ...     ModelMessage(
            ...         role=ModelMessageRoleType.SYSTEM,
            ...         content="You are a helpful assistant",
            ...     ),
            ... ]
            >>> converter = DefaultMessageConverter()
            >>> converted_messages = converter.move_last_user_message_to_end(messages)
            >>> assert converted_messages == [
            ...     ModelMessage(
            ...         role=ModelMessageRoleType.SYSTEM,
            ...         content="You are a helpful assistant",
            ...     ),
            ...     ModelMessage(
            ...         role=ModelMessageRoleType.HUMAN, content="Who are you"
            ...     ),
            ...     ModelMessage(role=ModelMessageRoleType.AI, content="I'm a robot"),
            ...     ModelMessage(
            ...         role=ModelMessageRoleType.SYSTEM,
            ...         content="You are a helpful assistant",
            ...     ),
            ...     ModelMessage(
            ...         role=ModelMessageRoleType.HUMAN, content="What's your name"
            ...     ),
            ... ]

        Args:
            messages(List[ModelMessage]): The messages.

        Returns:
            List[ModelMessage]: The converted messages.
        """
        last_user_input_index = None
        for i in range(len(messages) - 1, -1, -1):
            if messages[i].role == ModelMessageRoleType.HUMAN:
                last_user_input_index = i
                break
        if last_user_input_index is not None:
            last_user_input = messages.pop(last_user_input_index)
            messages.append(last_user_input)
        return messages


@PublicAPI(stability="beta")
class LLMClient(ABC):
    """An abstract class for LLM clients."""

    # Cache the model metadata for 60 seconds
    _MODEL_CACHE_ = TTLCache(maxsize=100, ttl=60)

    @property
    def cache(self) -> collections.abc.MutableMapping:
        """The cache object used to cache the model metadata.

        You can override this property to use your own cache object.

        Returns:
            collections.abc.MutableMapping: The cache object.
        """
        return self._MODEL_CACHE_
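
    # Sketch: since `_MODEL_CACHE_` is shared at the class level, a subclass
    # can override `cache` to use its own store (`MyLLMClient` is a
    # hypothetical subclass):
    #
    #   class MyLLMClient(LLMClient):
    #       def __init__(self):
    #           self._my_cache = TTLCache(maxsize=10, ttl=300)
    #
    #       @property
    #       def cache(self) -> collections.abc.MutableMapping:
    #           return self._my_cache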

    @abstractmethod
    async def generate(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> ModelOutput:
        """Generate a response for a given model request.

        Sometimes, different LLMs may have different message formats, so you can
        use the message converter to convert the messages to the format of the LLM.

        Args:
            request(ModelRequest): The model request.
            message_converter(MessageConverter): The message converter.

        Returns:
            ModelOutput: The model output.
@@ -315,12 +658,18 @@ class LLMClient(ABC):

    @abstractmethod
    async def generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> AsyncIterator[ModelOutput]:
        """Generate a stream of responses for a given model request.

        Sometimes, different LLMs may have different message formats, so you can
        use the message converter to convert the messages to the format of the LLM.

        Args:
            request(ModelRequest): The model request.
            message_converter(MessageConverter): The message converter.

        Returns:
            AsyncIterator[ModelOutput]: The model output stream.
@@ -345,3 +694,65 @@ class LLMClient(ABC):
        Returns:
            int: The number of tokens.
        """

    async def covert_message(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> ModelRequest:
        """Convert the message.

        If no message converter is provided, the original request will be returned.

        Args:
            request(ModelRequest): The model request.
            message_converter(MessageConverter): The message converter.

        Returns:
            ModelRequest: The converted model request.
        """
        if not message_converter:
            return request
        new_request = request.copy()
        model_metadata = await self.get_model_metadata(request.model)
        new_messages = message_converter.convert(request.messages, model_metadata)
        new_request.messages = new_messages
        return new_request
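
    # Sketch: a concrete client would typically convert the request first and
    # then call its backend (`MyLLMClient` and `_call_backend` are
    # hypothetical):
    #
    #   class MyLLMClient(LLMClient):
    #       async def generate(
    #           self,
    #           request: ModelRequest,
    #           message_converter: Optional[MessageConverter] = None,
    #       ) -> ModelOutput:
    #           request = await self.covert_message(request, message_converter)
    #           return await self._call_backend(request)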

    async def cached_models(self) -> List[ModelMetadata]:
        """Get all the models from the cache or the LLM server.

        If the model metadata is not in the cache, it will be fetched from the LLM server.

        Returns:
            List[ModelMetadata]: A list of model metadata.
        """
        key = "____$llm_client_models$____"
        if key not in self.cache:
            models = await self.models()
            self.cache[key] = models
            for model in models:
                model_metadata_key = (
                    f"____$llm_client_models_metadata_{model.model}$____"
                )
                self.cache[model_metadata_key] = model
        return self.cache[key]

    async def get_model_metadata(self, model: str) -> ModelMetadata:
        """Get the model metadata.

        Args:
            model(str): The model name.

        Returns:
            ModelMetadata: The model metadata.

        Raises:
            ValueError: If the model is not found.
        """
        model_metadata_key = f"____$llm_client_models_metadata_{model}$____"
        if model_metadata_key not in self.cache:
            await self.cached_models()
        model_metadata = self.cache.get(model_metadata_key)
        if not model_metadata:
            raise ValueError(f"Model {model} not found")
        return model_metadata
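
    # Sketch: repeated lookups are served from the TTL cache; only the first
    # call (or the first after expiry) reaches the server through `models()`
    # (`client` and the model name are illustrative):
    #
    #   metadata = await client.get_model_metadata("vicuna-13b-v1.5")
    #   assert metadata.ext_metadata is not None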