feat(awel): New MessageConverter and more AWEL operators (#1039)

This commit is contained in:
Fangyin Cheng
2024-01-08 09:40:05 +08:00
committed by GitHub
parent 765fb181f6
commit e8861bd8fa
48 changed files with 2333 additions and 719 deletions


@@ -1,14 +1,21 @@
import collections
import copy
import logging
import time
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from cachetools import TTLCache
from dbgpt._private.pydantic import BaseModel
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.util import BaseParameters
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.model_utils import GPUInfo
logger = logging.getLogger(__name__)
@dataclass
@PublicAPI(stability="beta")
@@ -223,6 +230,29 @@ class ModelRequest:
raise ValueError("The messages is not a single user message")
return messages[0]
@staticmethod
def build_request(
model: str,
messages: List[ModelMessage],
context: Union[ModelRequestContext, Dict[str, Any], BaseModel],
stream: Optional[bool] = False,
**kwargs,
):
context_dict = None
if isinstance(context, dict):
context_dict = context
elif isinstance(context, BaseModel):
context_dict = context.dict()
if context_dict and "stream" not in context_dict:
context_dict["stream"] = stream
context = ModelRequestContext(**context_dict)
return ModelRequest(
model=model,
messages=messages,
context=context,
**kwargs,
)
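# Illustrative sketch, not part of this diff: the model name and the
# "user_name" context key below are assumed placeholders. A plain dict
# context without a "stream" key is normalized into a ModelRequestContext
# with the stream flag injected.
example_messages = [ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")]
example_request = ModelRequest.build_request(
    "my-model",
    example_messages,
    context={"user_name": "alice"},
    stream=False,
)
# example_request.context now carries stream=False alongside user_name.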
@staticmethod
def _build(model: str, prompt: str, **kwargs):
return ModelRequest(
@@ -271,6 +301,43 @@ class ModelRequest:
return ModelMessage.to_openai_messages(messages)
@dataclass
class ModelExtraMedata(BaseParameters):
"""A class to represent the extra metadata of a LLM."""
prompt_roles: Optional[List[str]] = field(
default_factory=lambda: [
ModelMessageRoleType.SYSTEM,
ModelMessageRoleType.HUMAN,
ModelMessageRoleType.AI,
],
metadata={"help": "The roles of the prompt"},
)
prompt_sep: Optional[str] = field(
default="\n",
metadata={"help": "The separator of the prompt between multiple rounds"},
)
# You can see the chat template in your model repo tokenizer config,
# typically in the tokenizer_config.json
prompt_chat_template: Optional[str] = field(
default=None,
metadata={
"help": "The chat template, see: https://huggingface.co/docs/transformers/main/en/chat_templating"
},
)
@property
def support_system_message(self) -> bool:
"""Whether the model supports system message.
Returns:
bool: Whether the model supports system messages.
"""
return ModelMessageRoleType.SYSTEM in self.prompt_roles
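# A minimal sketch of support_system_message: the default prompt_roles
# include SYSTEM, while a roles list that omits it marks the model as not
# supporting system messages.
assert ModelExtraMedata().support_system_message
assert not ModelExtraMedata(
    prompt_roles=[ModelMessageRoleType.HUMAN, ModelMessageRoleType.AI]
).support_system_message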
@dataclass
@PublicAPI(stability="beta")
class ModelMetadata(BaseParameters):
@@ -295,18 +362,294 @@ class ModelMetadata(BaseParameters):
default_factory=dict,
metadata={"help": "Model metadata"},
)
ext_metadata: Optional[ModelExtraMedata] = field(
default_factory=ModelExtraMedata,
metadata={"help": "Model extra metadata"},
)
@classmethod
def from_dict(
cls, data: dict, ignore_extra_fields: bool = False
) -> "ModelMetadata":
if "ext_metadata" in data:
data["ext_metadata"] = ModelExtraMedata(**data["ext_metadata"])
return cls(**data)
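# Sketch of from_dict: a nested ext_metadata dict is promoted to a
# ModelExtraMedata instance (the prompt_sep value is an assumption used
# only for illustration).
example_metadata = ModelMetadata.from_dict(
    {"model": "test", "ext_metadata": {"prompt_sep": "###"}}
)
assert isinstance(example_metadata.ext_metadata, ModelExtraMedata)
assert example_metadata.ext_metadata.prompt_sep == "###"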
class MessageConverter(ABC):
"""An abstract class for message converter.
Different LLMs may use different message formats; this class converts messages
into the format expected by a specific LLM.
Examples:
>>> from typing import List
>>> from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
>>> from dbgpt.core.interface.llm import MessageConverter, ModelMetadata
>>> class RemoveSystemMessageConverter(MessageConverter):
... def convert(
... self,
... messages: List[ModelMessage],
... model_metadata: Optional[ModelMetadata] = None,
... ) -> List[ModelMessage]:
... # Convert the messages, merge system messages to the last user message.
... system_message = None
... other_messages = []
... sep = "\\n"
... for message in messages:
... if message.role == ModelMessageRoleType.SYSTEM:
... system_message = message
... else:
... other_messages.append(message)
... if system_message and other_messages:
... other_messages[-1].content = (
... system_message.content + sep + other_messages[-1].content
... )
... return other_messages
...
>>> messages = [
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are you"),
... ]
>>> converter = RemoveSystemMessageConverter()
>>> converted_messages = converter.convert(messages, None)
>>> assert converted_messages == [
... ModelMessage(
... role=ModelMessageRoleType.HUMAN,
... content="You are a helpful assistant\\nWho are you",
... ),
... ]
"""
@abstractmethod
def convert(
self,
messages: List[ModelMessage],
model_metadata: Optional[ModelMetadata] = None,
) -> List[ModelMessage]:
"""Convert the messages.
Args:
messages(List[ModelMessage]): The messages.
model_metadata(ModelMetadata): The model metadata.
Returns:
List[ModelMessage]: The converted messages.
"""
class DefaultMessageConverter(MessageConverter):
"""The default message converter."""
def __init__(self, prompt_sep: Optional[str] = None):
self._prompt_sep = prompt_sep
def convert(
self,
messages: List[ModelMessage],
model_metadata: Optional[ModelMetadata] = None,
) -> List[ModelMessage]:
"""Convert the messages.
There are three steps to convert the messages:
1. Keep only system, human and AI messages
2. Move the last user message to the end of the list
3. Merge system messages into the last user message if the model does not support system messages
Args:
messages(List[ModelMessage]): The messages.
model_metadata(ModelMetadata): The model metadata.
Returns:
List[ModelMessage]: The converted messages.
"""
# 1. Keep only system, human and AI messages
messages = list(filter(lambda m: m.pass_to_model, messages))
# 2. Move the last user's message to the end of the list
messages = self.move_last_user_message_to_end(messages)
if not model_metadata or not model_metadata.ext_metadata:
logger.warning("No model metadata, skip message system message conversion")
return messages
if not model_metadata.ext_metadata.support_system_message:
# 3. Convert the messages to no system message
return self.convert_to_no_system_message(messages, model_metadata)
return messages
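# Sketch of the full convert() pipeline against a model without
# system-message support (the prompt_roles value below is the assumption
# that triggers step 3):
no_system_metadata = ModelMetadata(
    model="test",
    ext_metadata=ModelExtraMedata(
        prompt_roles=[ModelMessageRoleType.HUMAN, ModelMessageRoleType.AI]
    ),
)
pipeline_messages = [
    ModelMessage(role=ModelMessageRoleType.SYSTEM, content="Be terse"),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
]
converted = DefaultMessageConverter().convert(pipeline_messages, no_system_metadata)
# Expected, assuming both messages pass the pass_to_model filter:
# [ModelMessage(role=ModelMessageRoleType.HUMAN, content="Be terse\nHi")]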
def convert_to_no_system_message(
self,
messages: List[ModelMessage],
model_metadata: Optional[ModelMetadata] = None,
) -> List[ModelMessage]:
"""Convert the messages to no system message.
Examples:
>>> # Merge system messages into the last user message when the model does not support them
>>> from typing import List
>>> from dbgpt.core.interface.message import (
... ModelMessage,
... ModelMessageRoleType,
... )
>>> from dbgpt.core.interface.llm import (
... DefaultMessageConverter,
... ModelMetadata,
... )
>>> messages = [
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="Who are you"
... ),
... ]
>>> converter = DefaultMessageConverter()
>>> model_metadata = ModelMetadata(model="test")
>>> converted_messages = converter.convert_to_no_system_message(
... messages, model_metadata
... )
>>> assert converted_messages == [
... ModelMessage(
... role=ModelMessageRoleType.HUMAN,
... content="You are a helpful assistant\\nWho are you",
... ),
... ]
"""
if not model_metadata or not model_metadata.ext_metadata:
logger.warning("No model metadata, skip message conversion")
return messages
ext_metadata = model_metadata.ext_metadata
system_messages = []
result_messages = []
for message in messages:
if message.role == ModelMessageRoleType.SYSTEM:
# Collect system messages; they will be merged into the last user message
system_messages.append(message)
elif message.role in [
ModelMessageRoleType.HUMAN,
ModelMessageRoleType.AI,
]:
result_messages.append(message)
prompt_sep = self._prompt_sep or ext_metadata.prompt_sep or "\n"
system_message_str = None
if len(system_messages) > 1:
logger.warning("Your system messages have more than one message")
system_message_str = prompt_sep.join([m.content for m in system_messages])
elif len(system_messages) == 1:
system_message_str = system_messages[0].content
if system_message_str and result_messages:
# The model does not support system messages; merge them into the last user message
result_messages[-1].content = (
system_message_str + prompt_sep + result_messages[-1].content
)
return result_messages
def move_last_user_message_to_end(
self, messages: List[ModelMessage]
) -> List[ModelMessage]:
"""Move the last user message to the end of the list.
Examples:
>>> from typing import List
>>> from dbgpt.core.interface.message import (
... ModelMessage,
... ModelMessageRoleType,
... )
>>> from dbgpt.core.interface.llm import DefaultMessageConverter
>>> messages = [
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="Who are you"
... ),
... ModelMessage(role=ModelMessageRoleType.AI, content="I'm a robot"),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="What's your name"
... ),
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ]
>>> converter = DefaultMessageConverter()
>>> converted_messages = converter.move_last_user_message_to_end(messages)
>>> assert converted_messages == [
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="Who are you"
... ),
... ModelMessage(role=ModelMessageRoleType.AI, content="I'm a robot"),
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="What's your name"
... ),
... ]
Args:
messages(List[ModelMessage]): The messages.
Returns:
List[ModelMessage]: The converted messages.
"""
last_user_input_index = None
for i in range(len(messages) - 1, -1, -1):
if messages[i].role == ModelMessageRoleType.HUMAN:
last_user_input_index = i
break
if last_user_input_index is not None:
last_user_input = messages.pop(last_user_input_index)
messages.append(last_user_input)
return messages
@PublicAPI(stability="beta")
class LLMClient(ABC):
"""An abstract class for LLM client."""
# Cache the model metadata for 60 seconds
_MODEL_CACHE_ = TTLCache(maxsize=100, ttl=60)
@property
def cache(self) -> collections.abc.MutableMapping:
"""The cache object to cache the model metadata.
You can override this property to use your own cache object.
Returns:
collections.abc.MutableMapping: The cache object.
"""
return self._MODEL_CACHE_
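# Sketch of overriding the cache property (a hypothetical subclass; the
# abstract generate/generate_stream/models/count_token methods are elided):
class PerInstanceCacheClient(LLMClient):
    def __init__(self):
        self._local_cache = TTLCache(maxsize=10, ttl=5)

    @property
    def cache(self) -> collections.abc.MutableMapping:
        # Metadata cached here is scoped to this instance instead of the
        # class-level _MODEL_CACHE_ shared by all LLMClient subclasses.
        return self._local_cache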
@abstractmethod
async def generate(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> ModelOutput:
"""Generate a response for a given model request.
Different LLMs may use different message formats; pass a message converter
to adapt the messages to the format expected by the target LLM.
Args:
request(ModelRequest): The model request.
message_converter(MessageConverter): The message converter.
Returns:
ModelOutput: The model output.
@@ -315,12 +658,18 @@ class LLMClient(ABC):
@abstractmethod
async def generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> AsyncIterator[ModelOutput]:
"""Generate a stream of responses for a given model request.
Different LLMs may use different message formats; pass a message converter
to adapt the messages to the format expected by the target LLM.
Args:
request(ModelRequest): The model request.
message_converter(MessageConverter): The message converter.
Returns:
AsyncIterator[ModelOutput]: The model output stream.
@@ -345,3 +694,65 @@ class LLMClient(ABC):
Returns:
int: The number of tokens.
"""
async def covert_message(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> ModelRequest:
"""Covert the message.
If no message converter is provided, the original request will be returned.
Args:
request(ModelRequest): The model request.
message_converter(MessageConverter): The message converter.
Returns:
ModelRequest: The converted model request.
"""
if not message_converter:
return request
new_request = request.copy()
model_metadata = await self.get_model_metadata(request.model)
new_messages = message_converter.convert(request.messages, model_metadata)
new_request.messages = new_messages
return new_request
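# Sketch of applying the hook from a coroutine; `client` stands for any
# concrete LLMClient implementation and is an assumption of this example.
async def send_with_conversion(client: LLMClient, request: ModelRequest):
    converted = await client.covert_message(request, DefaultMessageConverter())
    # The original request is left untouched: covert_message works on a copy.
    return await client.generate(converted)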
async def cached_models(self) -> List[ModelMetadata]:
"""Get all the models from the cache or the llm server.
If the model metadata is not in the cache, it will be fetched from the llm server.
Returns:
List[ModelMetadata]: A list of model metadata.
"""
key = "____$llm_client_models$____"
if key not in self.cache:
models = await self.models()
self.cache[key] = models
for model in models:
model_metadata_key = (
f"____$llm_client_models_metadata_{model.model}$____"
)
self.cache[model_metadata_key] = model
return self.cache[key]
async def get_model_metadata(self, model: str) -> ModelMetadata:
"""Get the model metadata.
Args:
model(str): The model name.
Returns:
ModelMetadata: The model metadata.
Raises:
ValueError: If the model is not found.
"""
model_metadata_key = f"____$llm_client_models_metadata_{model}$____"
if model_metadata_key not in self.cache:
await self.cached_models()
model_metadata = self.cache.get(model_metadata_key)
if not model_metadata:
raise ValueError(f"Model {model} not found")
return model_metadata
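# Sketch of the caching behavior from a coroutine (`client` is any concrete
# LLMClient): the first lookup fills the TTL cache via cached_models(), and
# repeat lookups within the 60-second TTL never hit the LLM server.
async def lookup_twice(client: LLMClient):
    first = await client.get_model_metadata("my-model")  # fetch and cache
    second = await client.get_model_metadata("my-model")  # served from cache
    return first, second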