feat(awel): New MessageConverter and more AWEL operators (#1039)

This commit is contained in:
Fangyin Cheng
2024-01-08 09:40:05 +08:00
committed by GitHub
parent 765fb181f6
commit e8861bd8fa
48 changed files with 2333 additions and 719 deletions

View File

@@ -1,14 +1,21 @@
import collections
import copy
import logging
import time
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from cachetools import TTLCache
from dbgpt._private.pydantic import BaseModel
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.util import BaseParameters
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.model_utils import GPUInfo
logger = logging.getLogger(__name__)
@dataclass
@PublicAPI(stability="beta")
@@ -223,6 +230,29 @@ class ModelRequest:
raise ValueError("The messages is not a single user message")
return messages[0]
@staticmethod
def build_request(
model: str,
messages: List[ModelMessage],
context: Union[ModelRequestContext, Dict[str, Any], BaseModel],
stream: Optional[bool] = False,
**kwargs,
):
context_dict = None
if isinstance(context, dict):
context_dict = context
elif isinstance(context, BaseModel):
context_dict = context.dict()
if context_dict and "stream" not in context_dict:
context_dict["stream"] = stream
context = ModelRequestContext(**context_dict)
return ModelRequest(
model=model,
messages=messages,
context=context,
**kwargs,
)
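
A minimal usage sketch of build_request based on the signature above; the model name and context values are illustrative, not part of this commit:

messages = [ModelMessage.build_human_message("Hello")]
request = ModelRequest.build_request(
    "example-model",  # illustrative model name
    messages=messages,
    context={"conv_uid": "example-conv-uid"},
    stream=False,
)
# The dict context is wrapped into a ModelRequestContext and "stream" is filled in.
assert request.context.stream is False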
@staticmethod
def _build(model: str, prompt: str, **kwargs):
return ModelRequest(
@@ -271,6 +301,43 @@ class ModelRequest:
return ModelMessage.to_openai_messages(messages)
@dataclass
class ModelExtraMedata(BaseParameters):
"""A class to represent the extra metadata of a LLM."""
prompt_roles: Optional[List[str]] = field(
default_factory=lambda: [
ModelMessageRoleType.SYSTEM,
ModelMessageRoleType.HUMAN,
ModelMessageRoleType.AI,
],
metadata={"help": "The roles of the prompt"},
)
prompt_sep: Optional[str] = field(
default="\n",
metadata={"help": "The separator of the prompt between multiple rounds"},
)
# You can see the chat template in your model repo tokenizer config,
# typically in the tokenizer_config.json
prompt_chat_template: Optional[str] = field(
default=None,
metadata={
"help": "The chat template, see: https://huggingface.co/docs/transformers/main/en/chat_templating"
},
)
@property
def support_system_message(self) -> bool:
"""Whether the model supports system message.
Returns:
bool: Whether the model supports system message.
"""
return ModelMessageRoleType.SYSTEM in self.prompt_roles
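
A short illustrative sketch of the extra metadata for a model whose chat template has no system role (the values are assumptions, not part of this commit):

no_system_meta = ModelExtraMedata(
    prompt_roles=[ModelMessageRoleType.HUMAN, ModelMessageRoleType.AI],
)
assert not no_system_meta.support_system_message
default_meta = ModelExtraMedata()
assert default_meta.support_system_message  # SYSTEM is in the default roles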
@dataclass
@PublicAPI(stability="beta")
class ModelMetadata(BaseParameters):
@@ -295,18 +362,294 @@ class ModelMetadata(BaseParameters):
default_factory=dict,
metadata={"help": "Model metadata"},
)
ext_metadata: Optional[ModelExtraMedata] = field(
default_factory=ModelExtraMedata,
metadata={"help": "Model extra metadata"},
)
@classmethod
def from_dict(
cls, data: dict, ignore_extra_fields: bool = False
) -> "ModelMetadata":
if "ext_metadata" in data:
data["ext_metadata"] = ModelExtraMedata(**data["ext_metadata"])
return cls(**data)
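
An illustrative from_dict sketch; the nested ext_metadata dict is converted into ModelExtraMedata (the model name is hypothetical):

meta = ModelMetadata.from_dict(
    {"model": "example-model", "ext_metadata": {"prompt_sep": "\n"}}
)
assert isinstance(meta.ext_metadata, ModelExtraMedata)
assert meta.ext_metadata.support_system_message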
class MessageConverter(ABC):
"""An abstract class for message converter.
Different LLMs may expect different message formats; this class converts the messages
to the format a given LLM expects.
Examples:
>>> from typing import List
>>> from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
>>> from dbgpt.core.interface.llm import MessageConverter, ModelMetadata
>>> class RemoveSystemMessageConverter(MessageConverter):
... def convert(
... self,
... messages: List[ModelMessage],
... model_metadata: Optional[ModelMetadata] = None,
... ) -> List[ModelMessage]:
... # Convert the messages, merge system messages to the last user message.
... system_message = None
... other_messages = []
... sep = "\\n"
... for message in messages:
... if message.role == ModelMessageRoleType.SYSTEM:
... system_message = message
... else:
... other_messages.append(message)
... if system_message and other_messages:
... other_messages[-1].content = (
... system_message.content + sep + other_messages[-1].content
... )
... return other_messages
...
>>> messages = [
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are you"),
... ]
>>> converter = RemoveSystemMessageConverter()
>>> converted_messages = converter.convert(messages, None)
>>> assert converted_messages == [
... ModelMessage(
... role=ModelMessageRoleType.HUMAN,
... content="You are a helpful assistant\\nWho are you",
... ),
... ]
"""
@abstractmethod
def convert(
self,
messages: List[ModelMessage],
model_metadata: Optional[ModelMetadata] = None,
) -> List[ModelMessage]:
"""Convert the messages.
Args:
messages(List[ModelMessage]): The messages.
model_metadata(ModelMetadata): The model metadata.
Returns:
List[ModelMessage]: The converted messages.
"""
class DefaultMessageConverter(MessageConverter):
"""The default message converter."""
def __init__(self, prompt_sep: Optional[str] = None):
self._prompt_sep = prompt_sep
def convert(
self,
messages: List[ModelMessage],
model_metadata: Optional[ModelMetadata] = None,
) -> List[ModelMessage]:
"""Convert the messages.
There are three steps to convert the messages:
1. Just keep system, human and AI messages
2. Move the last user's message to the end of the list
3. Convert the messages to no system message if the model does not support system message
Args:
messages(List[ModelMessage]): The messages.
model_metadata(ModelMetadata): The model metadata.
Returns:
List[ModelMessage]: The converted messages.
"""
# 1. Just keep system, human and AI messages
messages = list(filter(lambda m: m.pass_to_model, messages))
# 2. Move the last user's message to the end of the list
messages = self.move_last_user_message_to_end(messages)
if not model_metadata or not model_metadata.ext_metadata:
logger.warning("No model metadata, skip system message conversion")
return messages
if not model_metadata.ext_metadata.support_system_message:
# 3. Convert the messages to no system message
return self.convert_to_no_system_message(messages, model_metadata)
return messages
def convert_to_no_system_message(
self,
messages: List[ModelMessage],
model_metadata: Optional[ModelMetadata] = None,
) -> List[ModelMessage]:
"""Convert the messages to no system message.
Examples:
>>> # Convert the messages to no system message, just merge system messages to the last user message
>>> from typing import List
>>> from dbgpt.core.interface.message import (
... ModelMessage,
... ModelMessageRoleType,
... )
>>> from dbgpt.core.interface.llm import (
... DefaultMessageConverter,
... ModelMetadata,
... )
>>> messages = [
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="Who are you"
... ),
... ]
>>> converter = DefaultMessageConverter()
>>> model_metadata = ModelMetadata(model="test")
>>> converted_messages = converter.convert_to_no_system_message(
... messages, model_metadata
... )
>>> assert converted_messages == [
... ModelMessage(
... role=ModelMessageRoleType.HUMAN,
... content="You are a helpful assistant\\nWho are you",
... ),
... ]
"""
if not model_metadata or not model_metadata.ext_metadata:
logger.warning("No model metadata, skip message conversion")
return messages
ext_metadata = model_metadata.ext_metadata
system_messages = []
result_messages = []
for message in messages:
if message.role == ModelMessageRoleType.SYSTEM:
# Collect system messages; they will be merged into the last user message below
system_messages.append(message)
elif message.role in [
ModelMessageRoleType.HUMAN,
ModelMessageRoleType.AI,
]:
result_messages.append(message)
prompt_sep = self._prompt_sep or ext_metadata.prompt_sep or "\n"
system_message_str = None
if len(system_messages) > 1:
logger.warning("Your system messages have more than one message")
system_message_str = prompt_sep.join([m.content for m in system_messages])
elif len(system_messages) == 1:
system_message_str = system_messages[0].content
if system_message_str and result_messages:
# The model does not support system messages; merge them into the last user message
result_messages[-1].content = (
system_message_str + prompt_sep + result_messages[-1].content
)
return result_messages
def move_last_user_message_to_end(
self, messages: List[ModelMessage]
) -> List[ModelMessage]:
"""Move the last user message to the end of the list.
Examples:
>>> from typing import List
>>> from dbgpt.core.interface.message import (
... ModelMessage,
... ModelMessageRoleType,
... )
>>> from dbgpt.core.interface.llm import DefaultMessageConverter
>>> messages = [
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="Who are you"
... ),
... ModelMessage(role=ModelMessageRoleType.AI, content="I'm a robot"),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="What's your name"
... ),
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ]
>>> converter = DefaultMessageConverter()
>>> converted_messages = converter.move_last_user_message_to_end(messages)
>>> assert converted_messages == [
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="Who are you"
... ),
... ModelMessage(role=ModelMessageRoleType.AI, content="I'm a robot"),
... ModelMessage(
... role=ModelMessageRoleType.SYSTEM,
... content="You are a helpful assistant",
... ),
... ModelMessage(
... role=ModelMessageRoleType.HUMAN, content="What's your name"
... ),
... ]
Args:
messages(List[ModelMessage]): The messages.
Returns:
List[ModelMessage]: The converted messages.
"""
last_user_input_index = None
for i in range(len(messages) - 1, -1, -1):
if messages[i].role == ModelMessageRoleType.HUMAN:
last_user_input_index = i
break
if last_user_input_index is not None:
last_user_input = messages.pop(last_user_input_index)
messages.append(last_user_input)
return messages
@PublicAPI(stability="beta")
class LLMClient(ABC):
"""An abstract class for LLM client."""
# Cache the model metadata for 60 seconds
_MODEL_CACHE_ = TTLCache(maxsize=100, ttl=60)
@property
def cache(self) -> collections.abc.MutableMapping:
"""The cache object to cache the model metadata.
You can override this property to use your own cache object.
Returns:
collections.abc.MutableMapping: The cache object.
"""
return self._MODEL_CACHE_
@abstractmethod
async def generate(self, request: ModelRequest) -> ModelOutput:
async def generate(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> ModelOutput:
"""Generate a response for a given model request.
Different LLMs may expect different message formats; a message converter can be used
to convert the messages to the format the target LLM expects.
Args:
request(ModelRequest): The model request.
message_converter(MessageConverter): The message converter.
Returns:
ModelOutput: The model output.
@@ -315,12 +658,18 @@ class LLMClient(ABC):
@abstractmethod
async def generate_stream(
self, request: ModelRequest
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> AsyncIterator[ModelOutput]:
"""Generate a stream of responses for a given model request.
Different LLMs may expect different message formats; a message converter can be used
to convert the messages to the format the target LLM expects.
Args:
request(ModelRequest): The model request.
message_converter(MessageConverter): The message converter.
Returns:
AsyncIterator[ModelOutput]: The model output stream.
@@ -345,3 +694,65 @@ class LLMClient(ABC):
Returns:
int: The number of tokens.
"""
async def covert_message(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> ModelRequest:
"""Covert the message.
If no message converter is provided, the original request will be returned.
Args:
request(ModelRequest): The model request.
message_converter(MessageConverter): The message converter.
Returns:
ModelRequest: The converted model request.
"""
if not message_converter:
return request
new_request = request.copy()
model_metadata = await self.get_model_metadata(request.model)
new_messages = message_converter.convert(request.messages, model_metadata)
new_request.messages = new_messages
return new_request
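
A hedged sketch of how a concrete client might apply the converter inside generate; the class and the _call_backend helper are assumptions for illustration only:

class MyLLMClient(LLMClient):
    async def generate(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> ModelOutput:
        # Normalize the request messages for this model, then call the backend.
        request = await self.covert_message(request, message_converter)
        return await self._call_backend(request)  # assumed backend helper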
async def cached_models(self) -> List[ModelMetadata]:
"""Get all the models from the cache or the llm server.
If the model metadata is not in the cache, it will be fetched from the llm server.
Returns:
List[ModelMetadata]: A list of model metadata.
"""
key = "____$llm_client_models$____"
if key not in self.cache:
models = await self.models()
self.cache[key] = models
for model in models:
model_metadata_key = (
f"____$llm_client_models_metadata_{model.model}$____"
)
self.cache[model_metadata_key] = model
return self.cache[key]
async def get_model_metadata(self, model: str) -> ModelMetadata:
"""Get the model metadata.
Args:
model(str): The model name.
Returns:
ModelMetadata: The model metadata.
Raises:
ValueError: If the model is not found.
"""
model_metadata_key = f"____$llm_client_models_metadata_{model}$____"
if model_metadata_key not in self.cache:
await self.cached_models()
model_metadata = self.cache.get(model_metadata_key)
if not model_metadata:
raise ValueError(f"Model {model} not found")
return model_metadata

View File

@@ -5,7 +5,6 @@ from datetime import datetime
from typing import Callable, Dict, List, Optional, Tuple, Union
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.awel import MapOperator
from dbgpt.core.interface.storage import (
InMemoryStorage,
ResourceIdentifier,
@@ -114,6 +113,50 @@ class ModelMessage(BaseModel):
content: str
round_index: Optional[int] = 0
@property
def pass_to_model(self) -> bool:
"""Whether the message will be passed to the model
The view message will not be passed to the model
Returns:
bool: Whether the message will be passed to the model
"""
return self.role in [
ModelMessageRoleType.SYSTEM,
ModelMessageRoleType.HUMAN,
ModelMessageRoleType.AI,
]
@staticmethod
def from_base_messages(messages: List[BaseMessage]) -> List["ModelMessage"]:
result = []
for message in messages:
content, round_index = message.content, message.round_index
if isinstance(message, HumanMessage):
result.append(
ModelMessage(
role=ModelMessageRoleType.HUMAN,
content=content,
round_index=round_index,
)
)
elif isinstance(message, AIMessage):
result.append(
ModelMessage(
role=ModelMessageRoleType.AI,
content=content,
round_index=round_index,
)
)
elif isinstance(message, SystemMessage):
result.append(
ModelMessage(
role=ModelMessageRoleType.SYSTEM, content=message.content
)
)
return result
@staticmethod
def from_openai_messages(
messages: Union[str, List[Dict[str, str]]]
@@ -142,9 +185,15 @@ class ModelMessage(BaseModel):
return result
@staticmethod
def to_openai_messages(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
def to_openai_messages(
messages: List["ModelMessage"], convert_to_compatible_format: bool = False
) -> List[Dict[str, str]]:
"""Convert to OpenAI message format and
hugggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
Args:
messages (List["ModelMessage"]): The model messages
convert_to_compatible_format (bool): Whether to convert to compatible format
"""
history = []
# Add history conversation
@@ -157,15 +206,16 @@ class ModelMessage(BaseModel):
history.append({"role": "assistant", "content": message.content})
else:
pass
# Move the last user's information to the end
last_user_input_index = None
for i in range(len(history) - 1, -1, -1):
if history[i]["role"] == "user":
last_user_input_index = i
break
if last_user_input_index:
last_user_input = history.pop(last_user_input_index)
history.append(last_user_input)
if convert_to_compatible_format:
# Move the last user's information to the end
last_user_input_index = None
for i in range(len(history) - 1, -1, -1):
if history[i]["role"] == "user":
last_user_input_index = i
break
if last_user_input_index is not None:
last_user_input = history.pop(last_user_input_index)
history.append(last_user_input)
return history
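
An illustrative sketch of the default conversion (convert_to_compatible_format left False), based on the role mapping above and the existing tests:

msgs = [
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
    ModelMessage(role=ModelMessageRoleType.AI, content="Hello!"),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are you"),
]
assert ModelMessage.to_openai_messages(msgs) == [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "Who are you"},
]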
@staticmethod
@@ -189,8 +239,8 @@ class ModelMessage(BaseModel):
return str_msg
_SingleRoundMessage = List[ModelMessage]
_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[ModelMessage]]
_SingleRoundMessage = List[BaseMessage]
_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[BaseMessage]]
def _message_to_dict(message: BaseMessage) -> Dict:
@@ -338,7 +388,8 @@ class OnceConversation:
"""Start a new round of conversation
Example:
>>> conversation = OnceConversation()
>>> conversation = OnceConversation("chat_normal")
>>> # The chat order will be 0, then we start a new round of conversation
>>> assert conversation.chat_order == 0
>>> conversation.start_new_round()
@@ -585,6 +636,28 @@ class OnceConversation:
)
return messages
def get_history_message(
self, include_system_message: bool = False
) -> List[BaseMessage]:
"""Get the history message
Not include the system messages.
Args:
include_system_message (bool): Whether to include the system message
Returns:
List[BaseMessage]: The history messages
"""
messages = []
for message in self.messages:
if message.pass_to_model:
if include_system_message:
messages.append(message)
elif message.type != "system":
messages.append(message)
return messages
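
A rough usage sketch of get_history_message; the add_user_message/add_ai_message calls mirror those made on StorageConversation elsewhere in this diff and are assumed to exist on OnceConversation as well:

conversation = OnceConversation("chat_normal")
conversation.start_new_round()
conversation.add_user_message("Who are you")
conversation.add_ai_message("I'm a robot")
conversation.end_current_round()
history = conversation.get_history_message()
# Human and AI messages from past rounds; system messages are skipped
# unless include_system_message=True is passed.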
class ConversationIdentifier(ResourceIdentifier):
"""Conversation identifier"""

View File

@@ -0,0 +1,114 @@
import dataclasses
from typing import Any, Dict, List, Optional
from dbgpt.core import (
ChatPromptTemplate,
MessageStorageItem,
ModelMessage,
ModelRequest,
StorageConversation,
StorageInterface,
)
from dbgpt.core.awel import (
DAG,
BaseOperator,
InputOperator,
JoinOperator,
MapOperator,
SimpleCallDataInputSource,
)
from dbgpt.core.interface.operator.prompt_operator import HistoryPromptBuilderOperator
from .message_operator import (
BufferedConversationMapperOperator,
ChatHistoryLoadType,
PreChatHistoryLoadOperator,
)
@dataclasses.dataclass
class ChatComposerInput:
"""The composer input."""
prompt_dict: Dict[str, Any]
model_dict: Dict[str, Any]
context: ChatHistoryLoadType
class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequest]):
"""The chat history prompt composer operator.
For simple scenarios, this operator composes the chat history and the prompt in one step.
"""
def __init__(
self,
prompt_template: ChatPromptTemplate,
history_key: str = "chat_history",
last_k_round: int = 2,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs,
):
super().__init__(**kwargs)
self._prompt_template = prompt_template
self._history_key = history_key
self._last_k_round = last_k_round
self._storage = storage
self._message_storage = message_storage
self._sub_compose_dag = self._build_composer_dag()
async def map(self, input_value: ChatComposerInput) -> ModelRequest:
end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
# Sub dag, use the same dag context in the parent dag
return await end_node.call(
call_data={"data": input_value}, dag_ctx=self.current_dag_context
)
def _build_composer_dag(self) -> DAG:
with DAG("dbgpt_awel_chat_history_prompt_composer") as composer_dag:
input_task = InputOperator(input_source=SimpleCallDataInputSource())
# Load and store chat history; InMemoryStorage is used by default.
chat_history_load_task = PreChatHistoryLoadOperator(
storage=self._storage, message_storage=self._message_storage
)
# History transform task; keep only the last `last_k_round` rounds of messages
history_transform_task = BufferedConversationMapperOperator(
last_k_round=self._last_k_round
)
history_prompt_build_task = HistoryPromptBuilderOperator(
prompt=self._prompt_template, history_key=self._history_key
)
model_request_build_task = JoinOperator(self._build_model_request)
# Build composer dag
(
input_task
>> MapOperator(lambda x: x.context)
>> chat_history_load_task
>> history_transform_task
>> history_prompt_build_task
)
(
input_task
>> MapOperator(lambda x: x.prompt_dict)
>> history_prompt_build_task
)
history_prompt_build_task >> model_request_build_task
(
input_task
>> MapOperator(lambda x: x.model_dict)
>> model_request_build_task
)
return composer_dag
def _build_model_request(
self, messages: List[ModelMessage], model_dict: Dict[str, Any]
) -> ModelRequest:
return ModelRequest.build_request(messages=messages, **model_dict)
async def after_dag_end(self):
# Should call after_dag_end() of sub dag
await self._sub_compose_dag._after_dag_end()
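
A hedged end-to-end sketch of the composer; the prompt, model name, and context values are illustrative, import paths follow the definitions in this commit, and storage falls back to the in-memory default:

import asyncio

from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate, SystemPromptTemplate
from dbgpt.core.awel import DAG
from dbgpt.core.interface.prompt import MessagesPlaceholder

prompt = ChatPromptTemplate(
    messages=[
        SystemPromptTemplate.from_template("You are a helpful assistant."),
        MessagesPlaceholder(variable_name="chat_history"),
        HumanPromptTemplate.from_template("{user_input}"),
    ]
)
with DAG("chat_composer_example_dag"):
    composer = ChatHistoryPromptComposerOperator(prompt_template=prompt)
composer_input = ChatComposerInput(
    prompt_dict={"user_input": "Who are you"},
    model_dict={"model": "example-model", "context": {"stream": False}},
    context={"conv_uid": "example-conv-uid", "chat_mode": "chat_normal"},
)
model_request = asyncio.run(composer.call(call_data={"data": composer_input}))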

View File

@@ -1,11 +1,12 @@
import dataclasses
from abc import ABC
from typing import Any, AsyncIterator, Dict, Optional, Union
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from dbgpt._private.pydantic import BaseModel
from dbgpt.core.awel import (
BranchFunc,
BranchOperator,
DAGContext,
MapOperator,
StreamifyAbsOperator,
)
@@ -22,20 +23,30 @@ RequestInput = Union[
str,
Dict[str, Any],
BaseModel,
ModelMessage,
List[ModelMessage],
]
class RequestBuildOperator(MapOperator[RequestInput, ModelRequest], ABC):
class RequestBuilderOperator(MapOperator[RequestInput, ModelRequest], ABC):
"""Build the model request from the input value."""
def __init__(self, model: Optional[str] = None, **kwargs):
self._model = model
super().__init__(**kwargs)
async def map(self, input_value: RequestInput) -> ModelRequest:
req_dict = {}
if not input_value:
raise ValueError("input_value is not set")
if isinstance(input_value, str):
req_dict = {"messages": [ModelMessage.build_human_message(input_value)]}
elif isinstance(input_value, dict):
req_dict = input_value
elif isinstance(input_value, ModelMessage):
req_dict = {"messages": [input_value]}
elif isinstance(input_value, list) and isinstance(input_value[0], ModelMessage):
req_dict = {"messages": input_value}
elif dataclasses.is_dataclass(input_value):
req_dict = dataclasses.asdict(input_value)
elif isinstance(input_value, BaseModel):
@@ -76,6 +87,7 @@ class BaseLLM:
"""The abstract operator for a LLM."""
SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name"
SHARE_DATA_KEY_MODEL_OUTPUT = "share_data_key_model_output"
def __init__(self, llm_client: Optional[LLMClient] = None):
self._llm_client = llm_client
@@ -87,8 +99,16 @@ class BaseLLM:
raise ValueError("llm_client is not set")
return self._llm_client
async def save_model_output(
self, current_dag_context: DAGContext, model_output: ModelOutput
) -> None:
"""Save the model output to the share data."""
await current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output
)
class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
"""The operator for a LLM.
Args:
@@ -105,10 +125,12 @@ class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_NAME, request.model
)
return await self.llm_client.generate(request)
model_output = await self.llm_client.generate(request)
await self.save_model_output(self.current_dag_context, model_output)
return model_output
class StreamingLLMOperator(
class BaseStreamingLLMOperator(
BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC
):
"""The streaming operator for a LLM.
@@ -127,8 +149,12 @@ class StreamingLLMOperator(
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_NAME, request.model
)
model_output = None
async for output in self.llm_client.generate_stream(request):
model_output = output
yield output
if model_output:
await self.save_model_output(self.current_dag_context, model_output)
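
A hedged sketch of a downstream operator reading the output that these operators save to share data; the operator class itself is illustrative, and MapOperator, ModelOutput, and BaseLLM are the names already used in this module:

from typing import Any

class PrintModelOutputOperator(MapOperator[Any, str]):
    async def map(self, _: Any) -> str:
        # Read the output saved by BaseLLMOperator / BaseStreamingLLMOperator.
        model_output: ModelOutput = await self.current_dag_context.get_from_share_data(
            BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT
        )
        return model_output.text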
class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):

View File

@@ -1,19 +1,17 @@
import uuid
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, List, Optional
from abc import ABC
from typing import Any, Dict, List, Optional, Union
from dbgpt.core import (
MessageStorageItem,
ModelMessage,
ModelMessageRoleType,
ModelOutput,
ModelRequest,
ModelRequestContext,
StorageConversation,
StorageInterface,
)
from dbgpt.core.awel import BaseOperator, MapOperator, TransformStreamAbsOperator
from dbgpt.core.interface.message import _MultiRoundMessageMapper
from dbgpt.core.awel import BaseOperator, MapOperator
from dbgpt.core.interface.message import BaseMessage, _MultiRoundMessageMapper
class BaseConversationOperator(BaseOperator, ABC):
@@ -21,32 +19,41 @@ class BaseConversationOperator(BaseOperator, ABC):
SHARE_DATA_KEY_STORAGE_CONVERSATION = "share_data_key_storage_conversation"
SHARE_DATA_KEY_MODEL_REQUEST = "share_data_key_model_request"
SHARE_DATA_KEY_MODEL_REQUEST_CONTEXT = "share_data_key_model_request_context"
_check_storage: bool = True
def __init__(
self,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
check_storage: bool = True,
**kwargs,
):
self._check_storage = check_storage
super().__init__(**kwargs)
self._storage = storage
self._message_storage = message_storage
@property
def storage(self) -> StorageInterface[StorageConversation, Any]:
def storage(self) -> Optional[StorageInterface[StorageConversation, Any]]:
"""Return the LLM client."""
if not self._storage:
raise ValueError("Storage is not set")
if self._check_storage:
raise ValueError("Storage is not set")
return None
return self._storage
@property
def message_storage(self) -> StorageInterface[MessageStorageItem, Any]:
def message_storage(self) -> Optional[StorageInterface[MessageStorageItem, Any]]:
"""Return the LLM client."""
if not self._message_storage:
raise ValueError("Message storage is not set")
if self._check_storage:
raise ValueError("Message storage is not set")
return None
return self._message_storage
async def get_storage_conversation(self) -> StorageConversation:
async def get_storage_conversation(self) -> Optional[StorageConversation]:
"""Get the storage conversation from share data.
Returns:
@@ -58,104 +65,11 @@ class BaseConversationOperator(BaseOperator, ABC):
)
)
if not storage_conv:
raise ValueError("Storage conversation is not set")
if self._check_storage:
raise ValueError("Storage conversation is not set")
return None
return storage_conv
async def get_model_request(self) -> ModelRequest:
"""Get the model request from share data.
Returns:
ModelRequest: The model request.
"""
model_request: ModelRequest = (
await self.current_dag_context.get_from_share_data(
self.SHARE_DATA_KEY_MODEL_REQUEST
)
)
if not model_request:
raise ValueError("Model request is not set")
return model_request
class PreConversationOperator(
BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
):
"""The operator to prepare the storage conversation.
In DB-GPT, the conversation record and its messages are stored in storage,
and they can be stored in different storages (for better performance).
"""
def __init__(
self,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs,
):
super().__init__(storage=storage, message_storage=message_storage)
MapOperator.__init__(self, **kwargs)
async def map(self, input_value: ModelRequest) -> ModelRequest:
"""Map the input value to a ModelRequest.
Args:
input_value (ModelRequest): The input value.
Returns:
ModelRequest: The mapped ModelRequest.
"""
if input_value.context is None:
input_value.context = ModelRequestContext()
if not input_value.context.conv_uid:
input_value.context.conv_uid = str(uuid.uuid4())
if not input_value.context.extra:
input_value.context.extra = {}
chat_mode = input_value.context.chat_mode
# Create a new storage conversation, this will load the conversation from storage, so we must do this async
storage_conv: StorageConversation = await self.blocking_func_to_async(
StorageConversation,
conv_uid=input_value.context.conv_uid,
chat_mode=chat_mode,
user_name=input_value.context.user_name,
sys_code=input_value.context.sys_code,
conv_storage=self.storage,
message_storage=self.message_storage,
)
input_messages = input_value.get_messages()
await self.save_to_storage(storage_conv, input_messages)
# Get all messages from current storage conversation, and overwrite the input value
messages: List[ModelMessage] = storage_conv.get_model_messages()
input_value.messages = messages
# Save the storage conversation to share data, for the child operators
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv
)
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_REQUEST, input_value
)
return input_value
async def save_to_storage(
self, storage_conv: StorageConversation, input_messages: List[ModelMessage]
) -> None:
"""Save the messages to storage.
Args:
storage_conv (StorageConversation): The storage conversation.
input_messages (List[ModelMessage]): The input messages.
"""
# check first
self.check_messages(input_messages)
storage_conv.start_new_round()
for message in input_messages:
if message.role == ModelMessageRoleType.HUMAN:
storage_conv.add_user_message(message.content)
else:
storage_conv.add_system_message(message.content)
def check_messages(self, messages: List[ModelMessage]) -> None:
"""Check the messages.
@@ -174,164 +88,147 @@ class PreConversationOperator(
]:
raise ValueError(f"Message role {message.role} is not supported")
async def after_dag_end(self):
"""The callback after DAG end"""
# Save the storage conversation to storage after the whole DAG finished
storage_conv: StorageConversation = await self.get_storage_conversation()
# TODO dont save if the conversation has some internal error
storage_conv.end_current_round()
ChatHistoryLoadType = Union[ModelRequestContext, Dict[str, Any]]
class PostConversationOperator(
BaseConversationOperator, MapOperator[ModelOutput, ModelOutput]
class PreChatHistoryLoadOperator(
BaseConversationOperator, MapOperator[ChatHistoryLoadType, List[BaseMessage]]
):
def __init__(self, **kwargs):
"""The operator to prepare the storage conversation.
In DB-GPT, the conversation record and its messages are stored in storage,
and they can be stored in different storages (for better performance).
This operator only loads the conversation and messages from storage.
"""
def __init__(
self,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
include_system_message: bool = False,
**kwargs,
):
super().__init__(storage=storage, message_storage=message_storage)
MapOperator.__init__(self, **kwargs)
self._include_system_message = include_system_message
async def map(self, input_value: ModelOutput) -> ModelOutput:
"""Map the input value to a ModelOutput.
async def map(self, input_value: ChatHistoryLoadType) -> List[BaseMessage]:
"""Map the input value to a ModelRequest.
Args:
input_value (ModelOutput): The input value.
input_value (ChatHistoryLoadType): The input value.
Returns:
ModelOutput: The mapped ModelOutput.
List[BaseMessage]: The messages stored in the storage.
"""
# Get the storage conversation from share data
storage_conv: StorageConversation = await self.get_storage_conversation()
storage_conv.add_ai_message(input_value.text)
return input_value
if not input_value:
raise ValueError("Model request context can't be None")
if isinstance(input_value, dict):
input_value = ModelRequestContext(**input_value)
if not input_value.conv_uid:
input_value.conv_uid = str(uuid.uuid4())
if not input_value.extra:
input_value.extra = {}
chat_mode = input_value.chat_mode
class PostStreamingConversationOperator(
BaseConversationOperator, TransformStreamAbsOperator[ModelOutput, ModelOutput]
):
def __init__(self, **kwargs):
TransformStreamAbsOperator.__init__(self, **kwargs)
# Creating a new storage conversation loads the conversation from storage, so do it asynchronously
storage_conv: StorageConversation = await self.blocking_func_to_async(
StorageConversation,
conv_uid=input_value.conv_uid,
chat_mode=chat_mode,
user_name=input_value.user_name,
sys_code=input_value.sys_code,
conv_storage=self.storage,
message_storage=self.message_storage,
)
async def transform_stream(
self, input_value: AsyncIterator[ModelOutput]
) -> ModelOutput:
"""Transform the input value to a ModelOutput.
Args:
input_value (ModelOutput): The input value.
Returns:
ModelOutput: The transformed ModelOutput.
"""
full_text = ""
async for model_output in input_value:
# Here model_output.text is the full text; if it were a delta, all deltas would need to be merged into the full text
full_text = model_output.text
yield model_output
# Get the storage conversation from share data
storage_conv: StorageConversation = await self.get_storage_conversation()
storage_conv.add_ai_message(full_text)
# Save the storage conversation to share data, for the child operators
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv
)
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_REQUEST_CONTEXT, input_value
)
# Get history messages from storage
history_messages: List[BaseMessage] = storage_conv.get_history_message(
include_system_message=self._include_system_message
)
return history_messages
class ConversationMapperOperator(
BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
BaseConversationOperator, MapOperator[List[BaseMessage], List[BaseMessage]]
):
def __init__(self, message_mapper: _MultiRoundMessageMapper = None, **kwargs):
MapOperator.__init__(self, **kwargs)
self._message_mapper = message_mapper
async def map(self, input_value: ModelRequest) -> ModelRequest:
"""Map the input value to a ModelRequest.
async def map(self, input_value: List[BaseMessage]) -> List[BaseMessage]:
return self.map_messages(input_value)
Args:
input_value (ModelRequest): The input value.
Returns:
ModelRequest: The mapped ModelRequest.
"""
input_value = input_value.copy()
messages: List[ModelMessage] = self.map_messages(input_value.messages)
# Overwrite the input value
input_value.messages = messages
return input_value
def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
"""Map the input messages to a list of ModelMessage.
Args:
messages (List[ModelMessage]): The input messages.
Returns:
List[ModelMessage]: The mapped ModelMessage.
"""
messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round(
def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
messages_by_round: List[List[BaseMessage]] = self._split_messages_by_round(
messages
)
message_mapper = self._message_mapper or self.map_multi_round_messages
return message_mapper(messages_by_round)
def map_multi_round_messages(
self, messages_by_round: List[List[ModelMessage]]
) -> List[ModelMessage]:
"""Map multi round messages to a list of ModelMessage
self, messages_by_round: List[List[BaseMessage]]
) -> List[BaseMessage]:
"""Map multi round messages to a list of BaseMessage.
By default, just merge all multi round messages to a list of ModelMessage according origin order.
By default, all multi-round messages are merged into a single list of BaseMessage in their original order.
You can override this method to implement your own logic.
Examples:
Merge multi round messages to a list of ModelMessage according origin order.
Merge multi-round messages into a list of BaseMessage in their original order.
.. code-block:: python
>>> from dbgpt.core.interface.message import (
... AIMessage,
... HumanMessage,
... SystemMessage,
... )
>>> messages_by_round = [
... [
... HumanMessage(content="Hi", round_index=1),
... AIMessage(content="Hello!", round_index=1),
... ],
... [
... HumanMessage(content="What's the error?", round_index=2),
... AIMessage(content="Just a joke.", round_index=2),
... ],
... ]
>>> operator = ConversationMapperOperator()
>>> messages = operator.map_multi_round_messages(messages_by_round)
>>> assert messages == [
... HumanMessage(content="Hi", round_index=1),
... AIMessage(content="Hello!", round_index=1),
... HumanMessage(content="What's the error?", round_index=2),
... AIMessage(content="Just a joke.", round_index=2),
... ]
import asyncio
from dbgpt.core.operator import ConversationMapperOperator
Map multi-round messages to a list of BaseMessage, keeping only the last round.
messages_by_round = [
[
ModelMessage(role="human", content="Hi", round_index=1),
ModelMessage(role="ai", content="Hello!", round_index=1),
],
[
ModelMessage(role="system", content="Error 404", round_index=2),
ModelMessage(
role="human", content="What's the error?", round_index=2
),
ModelMessage(role="ai", content="Just a joke.", round_index=2),
],
[
ModelMessage(role="human", content="Funny!", round_index=3),
],
]
operator = ConversationMapperOperator()
messages = operator.map_multi_round_messages(messages_by_round)
assert messages == [
ModelMessage(role="human", content="Hi", round_index=1),
ModelMessage(role="ai", content="Hello!", round_index=1),
ModelMessage(role="system", content="Error 404", round_index=2),
ModelMessage(
role="human", content="What's the error?", round_index=2
),
ModelMessage(role="ai", content="Just a joke.", round_index=2),
ModelMessage(role="human", content="Funny!", round_index=3),
]
Map multi round messages to a list of ModelMessage just keep the last one round.
.. code-block:: python
class MyMapper(ConversationMapperOperator):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def map_multi_round_messages(
self, messages_by_round: List[List[ModelMessage]]
) -> List[ModelMessage]:
return messages_by_round[-1]
operator = MyMapper()
messages = operator.map_multi_round_messages(messages_by_round)
assert messages == [
ModelMessage(role="human", content="Funny!", round_index=3),
]
>>> class MyMapper(ConversationMapperOperator):
... def __init__(self, **kwargs):
... super().__init__(**kwargs)
...
... def map_multi_round_messages(
... self, messages_by_round: List[List[BaseMessage]]
... ) -> List[BaseMessage]:
... return messages_by_round[-1]
...
>>> operator = MyMapper()
>>> messages = operator.map_multi_round_messages(messages_by_round)
>>> assert messages == [
... HumanMessage(content="What's the error?", round_index=2),
... AIMessage(content="Just a joke.", round_index=2),
... ]
Args:
"""
@@ -340,17 +237,17 @@ class ConversationMapperOperator(
return sum(messages_by_round, [])
def _split_messages_by_round(
self, messages: List[ModelMessage]
) -> List[List[ModelMessage]]:
self, messages: List[BaseMessage]
) -> List[List[BaseMessage]]:
"""Split the messages by round index.
Args:
messages (List[ModelMessage]): The input messages.
messages (List[BaseMessage]): The messages.
Returns:
List[List[ModelMessage]]: The split messages.
List[List[BaseMessage]]: The messages split by round.
"""
messages_by_round: List[List[ModelMessage]] = []
messages_by_round: List[List[BaseMessage]] = []
last_round_index = 0
for message in messages:
if not message.round_index:
@@ -366,7 +263,7 @@ class ConversationMapperOperator(
class BufferedConversationMapperOperator(ConversationMapperOperator):
"""The buffered conversation mapper operator.
This Operator must be used after the PreConversationOperator,
This Operator must be used after the PreChatHistoryLoadOperator,
and it will map the messages in the storage conversation.
Examples:
@@ -419,8 +316,8 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
if message_mapper:
def new_message_mapper(
messages_by_round: List[List[ModelMessage]],
) -> List[ModelMessage]:
messages_by_round: List[List[BaseMessage]],
) -> List[BaseMessage]:
# Keep the last k rounds of messages first, then apply the custom message mapper
messages_by_round = self._keep_last_round_messages(messages_by_round)
return message_mapper(messages_by_round)
@@ -428,23 +325,23 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
else:
def new_message_mapper(
messages_by_round: List[List[ModelMessage]],
) -> List[ModelMessage]:
messages_by_round: List[List[BaseMessage]],
) -> List[BaseMessage]:
messages_by_round = self._keep_last_round_messages(messages_by_round)
return sum(messages_by_round, [])
super().__init__(new_message_mapper, **kwargs)
def _keep_last_round_messages(
self, messages_by_round: List[List[ModelMessage]]
) -> List[List[ModelMessage]]:
self, messages_by_round: List[List[BaseMessage]]
) -> List[List[BaseMessage]]:
"""Keep the last k round messages.
Args:
messages_by_round (List[List[ModelMessage]]): The messages by round.
messages_by_round (List[List[BaseMessage]]): The messages by round.
Returns:
List[List[ModelMessage]]: The latest round messages.
List[List[BaseMessage]]: The latest round messages.
"""
index = self._last_k_round + 1
return messages_by_round[-index:]
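
To make the slicing concrete, a small sketch of what _keep_last_round_messages keeps for last_k_round=1 (plain lists stand in for rounds of messages):

rounds = [["round 1"], ["round 2"], ["round 3"]]
last_k_round = 1
index = last_k_round + 1  # keep the current round plus one previous round
assert rounds[-index:] == [["round 2"], ["round 3"]]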

View File

@@ -0,0 +1,255 @@
from abc import ABC
from typing import Any, Dict, List, Optional, Union
from dbgpt.core import (
BasePromptTemplate,
ChatPromptTemplate,
ModelMessage,
ModelMessageRoleType,
ModelOutput,
StorageConversation,
)
from dbgpt.core.awel import JoinOperator, MapOperator
from dbgpt.core.interface.message import BaseMessage
from dbgpt.core.interface.operator.llm_operator import BaseLLM
from dbgpt.core.interface.operator.message_operator import BaseConversationOperator
from dbgpt.core.interface.prompt import HumanPromptTemplate, MessageType
from dbgpt.util.function_utils import rearrange_args_by_type
class BasePromptBuilderOperator(BaseConversationOperator, ABC):
"""The base prompt builder operator."""
def __init__(self, check_storage: bool, **kwargs):
super().__init__(check_storage=check_storage, **kwargs)
async def format_prompt(
self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any]
) -> List[ModelMessage]:
"""Format the prompt.
Args:
prompt (ChatPromptTemplate): The prompt.
prompt_dict (Dict[str, Any]): The prompt dict.
Returns:
List[ModelMessage]: The formatted prompt.
"""
kwargs = {}
kwargs.update(prompt_dict)
pass_kwargs = {k: v for k, v in kwargs.items() if k in prompt.input_variables}
messages = prompt.format_messages(**pass_kwargs)
messages = ModelMessage.from_base_messages(messages)
# Start new round conversation, and save user message to storage
await self.start_new_round_conv(messages)
return messages
async def start_new_round_conv(self, messages: List[ModelMessage]) -> None:
"""Start a new round conversation.
Args:
messages (List[ModelMessage]): The messages.
"""
last_user_message = None
for message in messages[::-1]:
if message.role == ModelMessageRoleType.HUMAN:
last_user_message = message.content
break
if not last_user_message:
raise ValueError("No user message")
storage_conv: StorageConversation = await self.get_storage_conversation()
if not storage_conv:
return
# Start new round
storage_conv.start_new_round()
storage_conv.add_user_message(last_user_message)
async def after_dag_end(self):
"""The callback after DAG end"""
# TODO remove this to start_new_round()
# Save the storage conversation to storage after the whole DAG finished
storage_conv: StorageConversation = await self.get_storage_conversation()
if not storage_conv:
return
model_output: ModelOutput = await self.current_dag_context.get_from_share_data(
BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT
)
if model_output:
# Save model output message to storage
storage_conv.add_ai_message(model_output.text)
# End current conversation round and flush to storage
storage_conv.end_current_round()
PromptTemplateType = Union[ChatPromptTemplate, BasePromptTemplate, MessageType, str]
class PromptBuilderOperator(
BasePromptBuilderOperator, MapOperator[Dict[str, Any], List[ModelMessage]]
):
"""The operator to build the prompt with static prompt.
Examples:
.. code-block:: python
import asyncio
from dbgpt.core.awel import DAG
from dbgpt.core import (
ModelMessage,
HumanMessage,
SystemMessage,
HumanPromptTemplate,
SystemPromptTemplate,
ChatPromptTemplate,
)
from dbgpt.core.operator import PromptBuilderOperator
with DAG("prompt_test") as dag:
str_prompt = PromptBuilderOperator(
"Please write a {dialect} SQL count the length of a field"
)
tp_prompt = PromptBuilderOperator(
HumanPromptTemplate.from_template(
"Please write a {dialect} SQL count the length of a field"
)
)
chat_prompt = PromptBuilderOperator(
ChatPromptTemplate(
messages=[
HumanPromptTemplate.from_template(
"Please write a {dialect} SQL count the length of a field"
)
]
)
)
with_sys_prompt = PromptBuilderOperator(
ChatPromptTemplate(
messages=[
SystemPromptTemplate.from_template(
"You are a {dialect} SQL expert"
),
HumanPromptTemplate.from_template(
"Please write a {dialect} SQL count the length of a field"
),
],
)
)
single_input = {"data": {"dialect": "mysql"}}
single_expected_messages = [
ModelMessage(
content="Please write a mysql SQL count the length of a field",
role="human",
)
]
with_sys_expected_messages = [
ModelMessage(content="You are a mysql SQL expert", role="system"),
ModelMessage(
content="Please write a mysql SQL count the length of a field",
role="human",
),
]
assert (
asyncio.run(str_prompt.call(call_data=single_input))
== single_expected_messages
)
assert (
asyncio.run(tp_prompt.call(call_data=single_input))
== single_expected_messages
)
assert (
asyncio.run(chat_prompt.call(call_data=single_input))
== single_expected_messages
)
assert (
asyncio.run(with_sys_prompt.call(call_data=single_input))
== with_sys_expected_messages
)
"""
def __init__(self, prompt: PromptTemplateType, **kwargs):
if isinstance(prompt, str):
prompt = ChatPromptTemplate(
messages=[HumanPromptTemplate.from_template(prompt)]
)
elif isinstance(prompt, BasePromptTemplate) and not isinstance(
prompt, ChatPromptTemplate
):
prompt = ChatPromptTemplate(
messages=[HumanPromptTemplate.from_template(prompt.template)]
)
elif isinstance(prompt, MessageType):
prompt = ChatPromptTemplate(messages=[prompt])
self._prompt = prompt
super().__init__(check_storage=False, **kwargs)
MapOperator.__init__(self, map_function=self.merge_prompt, **kwargs)
@rearrange_args_by_type
async def merge_prompt(self, prompt_dict: Dict[str, Any]) -> List[ModelMessage]:
return await self.format_prompt(self._prompt, prompt_dict)
class DynamicPromptBuilderOperator(
BasePromptBuilderOperator, JoinOperator[List[ModelMessage]]
):
"""The operator to build the prompt with dynamic prompt.
The prompt template is dynamic, and it created by parent operator.
"""
def __init__(self, **kwargs):
super().__init__(check_storage=False, **kwargs)
JoinOperator.__init__(self, combine_function=self.merge_prompt, **kwargs)
@rearrange_args_by_type
async def merge_prompt(
self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any]
) -> List[ModelMessage]:
return await self.format_prompt(prompt, prompt_dict)
class HistoryPromptBuilderOperator(
BasePromptBuilderOperator, JoinOperator[List[ModelMessage]]
):
def __init__(
self, prompt: ChatPromptTemplate, history_key: Optional[str] = None, **kwargs
):
self._prompt = prompt
self._history_key = history_key
JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)
@rearrange_args_by_type
async def merge_history(
self, history: List[BaseMessage], prompt_dict: Dict[str, Any]
) -> List[ModelMessage]:
prompt_dict[self._history_key] = history
return await self.format_prompt(self._prompt, prompt_dict)
class HistoryDynamicPromptBuilderOperator(
BasePromptBuilderOperator, JoinOperator[List[ModelMessage]]
):
"""The operator to build the prompt with dynamic prompt.
The prompt template is dynamic, and it created by parent operator.
"""
def __init__(self, history_key: Optional[str] = None, **kwargs):
self._history_key = history_key
JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)
@rearrange_args_by_type
async def merge_history(
self,
prompt: ChatPromptTemplate,
history: List[BaseMessage],
prompt_dict: Dict[str, Any],
) -> List[ModelMessage]:
prompt_dict[self._history_key] = history
return await self.format_prompt(prompt, prompt_dict)

View File

@@ -1,11 +1,14 @@
from __future__ import annotations
import dataclasses
import json
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional
from string import Formatter
from typing import Any, Callable, Dict, List, Optional, Set, Union
from dbgpt._private.pydantic import BaseModel
from dbgpt._private.pydantic import BaseModel, root_validator
from dbgpt.core._private.example_base import ExampleSelector
from dbgpt.core.awel import MapOperator
from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage
from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt.core.interface.storage import (
InMemoryStorage,
@@ -38,15 +41,40 @@ _DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
}
class PromptTemplate(BaseModel, ABC):
class BasePromptTemplate(BaseModel):
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
template: Optional[str]
"""The prompt template."""
template_format: Optional[str] = "f-string"
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs."""
if self.template:
return _DEFAULT_FORMATTER_MAPPING[self.template_format](True)(
self.template, **kwargs
)
@classmethod
def from_template(
cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any
) -> BasePromptTemplate:
"""Create a prompt template from a template string."""
input_variables = get_template_vars(template, template_format)
return cls(
template=template,
input_variables=input_variables,
template_format=template_format,
**kwargs,
)
class PromptTemplate(BasePromptTemplate):
template_scene: Optional[str]
template_define: Optional[str]
"""this template define"""
template: Optional[str]
"""The prompt template."""
template_format: str = "f-string"
"""strict template will check template args"""
template_is_strict: bool = True
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
@@ -86,12 +114,114 @@ class PromptTemplate(BaseModel, ABC):
self.template_is_strict
)(self.template, **kwargs)
@staticmethod
def from_template(template: str) -> "PromptTemplateOperator":
class BaseChatPromptTemplate(BaseModel, ABC):
prompt: BasePromptTemplate
@property
def input_variables(self) -> List[str]:
"""A list of the names of the variables the prompt template expects."""
return self.prompt.input_variables
@abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the prompt with the inputs."""
@classmethod
def from_template(
cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any
) -> BaseChatPromptTemplate:
"""Create a prompt template from a template string."""
return PromptTemplateOperator(
PromptTemplate(template=template, input_variables=[])
)
prompt = BasePromptTemplate.from_template(template, template_format)
return cls(prompt=prompt, **kwargs)
class SystemPromptTemplate(BaseChatPromptTemplate):
"""The system prompt template."""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
content = self.prompt.format(**kwargs)
return [SystemMessage(content=content)]
class HumanPromptTemplate(BaseChatPromptTemplate):
"""The human prompt template."""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
content = self.prompt.format(**kwargs)
return [HumanMessage(content=content)]
class MessagesPlaceholder(BaseChatPromptTemplate):
"""The messages placeholder template.
Mostly used for the chat history.
"""
variable_name: str
prompt: BasePromptTemplate = None
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
messages = kwargs.get(self.variable_name, [])
if not isinstance(messages, list):
raise ValueError(
f"Unsupported messages type: {type(messages)}, should be list."
)
for message in messages:
if not isinstance(message, BaseMessage):
raise ValueError(
f"Unsupported message type: {type(message)}, should be BaseMessage."
)
return messages
@property
def input_variables(self) -> List[str]:
"""A list of the names of the variables the prompt template expects.
Returns:
List[str]: The input variables.
"""
return [self.variable_name]
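
A brief illustrative sketch of MessagesPlaceholder: it simply passes through the message list bound to its variable name (the history values are illustrative):

from dbgpt.core.interface.message import AIMessage, HumanMessage

placeholder = MessagesPlaceholder(variable_name="chat_history")
history = [HumanMessage(content="Hi"), AIMessage(content="Hello!")]
assert placeholder.format_messages(chat_history=history) == history
assert placeholder.input_variables == ["chat_history"]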
MessageType = Union[BaseChatPromptTemplate, BaseMessage]
class ChatPromptTemplate(BasePromptTemplate):
messages: List[MessageType]
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the prompt with the inputs."""
result_messages = []
for message in self.messages:
if isinstance(message, BaseMessage):
result_messages.append(message)
elif isinstance(message, BaseChatPromptTemplate):
pass_kwargs = {
k: v for k, v in kwargs.items() if k in message.input_variables
}
result_messages.extend(message.format_messages(**pass_kwargs))
elif isinstance(message, MessagesPlaceholder):
pass_kwargs = {
k: v for k, v in kwargs.items() if k in message.input_variables
}
result_messages.extend(message.format_messages(**pass_kwargs))
else:
raise ValueError(f"Unsupported message type: {type(message)}")
return result_messages
@root_validator(pre=True)
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre-fill the messages."""
input_variables = values.get("input_variables", {})
messages = values.get("messages", [])
if not input_variables:
input_variables = set()
for message in messages:
if isinstance(message, BaseChatPromptTemplate):
input_variables.update(message.input_variables)
values["input_variables"] = sorted(input_variables)
return values
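
An illustrative sketch of the pre_fill validator: input_variables are collected from the sub-templates when not given explicitly:

chat_prompt = ChatPromptTemplate(
    messages=[
        SystemPromptTemplate.from_template("You are a {dialect} SQL expert"),
        HumanPromptTemplate.from_template("{user_input}"),
    ]
)
assert chat_prompt.input_variables == ["dialect", "user_input"]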
@dataclasses.dataclass
@@ -547,10 +677,36 @@ class PromptManager:
self.storage.delete(identifier)
class PromptTemplateOperator(MapOperator[Dict, str]):
def __init__(self, prompt_template: PromptTemplate, **kwargs: Any):
super().__init__(**kwargs)
self._prompt_template = prompt_template
def _get_string_template_vars(template_str: str) -> Set[str]:
"""Get template variables from a template string."""
variables = set()
formatter = Formatter()
async def map(self, input_value: Dict) -> str:
return self._prompt_template.format(**input_value)
for _, variable_name, _, _ in formatter.parse(template_str):
if variable_name:
variables.add(variable_name)
return variables
def _get_jinja2_template_vars(template_str: str) -> Set[str]:
"""Get template variables from a template string."""
from jinja2 import Environment, meta
env = Environment()
ast = env.parse(template_str)
variables = meta.find_undeclared_variables(ast)
return variables
def get_template_vars(
template_str: str, template_format: str = "f-string"
) -> List[str]:
"""Get template variables from a template string."""
if template_format == "f-string":
result = _get_string_template_vars(template_str)
elif template_format == "jinja2":
result = _get_jinja2_template_vars(template_str)
else:
raise ValueError(f"Unsupported template format: {template_format}")
return sorted(result)
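
A quick illustrative sketch of get_template_vars for both supported formats (the jinja2 case requires jinja2 to be installed):

assert get_template_vars("Hello {name}, I am {bot}") == ["bot", "name"]
assert get_template_vars("Hello {{ name }}", template_format="jinja2") == ["name"]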

View File

@@ -413,13 +413,18 @@ def test_to_openai_messages(
{"role": "user", "content": human_model_message.content},
]
def test_to_openai_messages_convert_to_compatible_format(
human_model_message, ai_model_message, system_model_message
):
shuffle_messages = ModelMessage.to_openai_messages(
[
system_model_message,
human_model_message,
human_model_message,
ai_model_message,
]
],
convert_to_compatible_format=True,
)
assert shuffle_messages == [
{"role": "system", "content": system_model_message.content},

View File

@@ -99,12 +99,6 @@ class TestPromptTemplate:
formatted_output = prompt.format(response="hello")
assert "Response: " in formatted_output
def test_from_template(self):
template_str = "Hello {name}"
prompt = PromptTemplate.from_template(template_str)
assert prompt._prompt_template.template == template_str
assert prompt._prompt_template.input_variables == []
def test_format_missing_variable(self):
template_str = "Hello {name}"
prompt = PromptTemplate(