feat(awel): New MessageConverter and more AWEL operators (#1039)
@@ -1,14 +1,21 @@
import collections
import copy
import logging
import time
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from typing import Any, AsyncIterator, Dict, List, Optional, Union

from cachetools import TTLCache

from dbgpt._private.pydantic import BaseModel
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.util import BaseParameters
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.model_utils import GPUInfo

logger = logging.getLogger(__name__)


@dataclass
@PublicAPI(stability="beta")
@@ -223,6 +230,29 @@ class ModelRequest:
            raise ValueError("The messages is not a single user message")
        return messages[0]

    @staticmethod
    def build_request(
        model: str,
        messages: List[ModelMessage],
        context: Union[ModelRequestContext, Dict[str, Any], BaseModel],
        stream: Optional[bool] = False,
        **kwargs,
    ):
        context_dict = None
        if isinstance(context, dict):
            context_dict = context
        elif isinstance(context, BaseModel):
            context_dict = context.dict()
        if context_dict and "stream" not in context_dict:
            context_dict["stream"] = stream
        context = ModelRequestContext(**context_dict)
        return ModelRequest(
            model=model,
            messages=messages,
            context=context,
            **kwargs,
        )

    @staticmethod
    def _build(model: str, prompt: str, **kwargs):
        return ModelRequest(
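
For orientation, a minimal usage sketch of the new build_request helper (the model name, message content, and conv_uid are placeholders, not part of the diff; ModelRequest, ModelMessage and ModelMessageRoleType are the classes defined in this file and in dbgpt.core.interface.message):

# Illustrative only: a dict context is normalized into a ModelRequestContext.
messages = [ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")]
request = ModelRequest.build_request(
    model="example-model",
    messages=messages,
    context={"conv_uid": "conv-1"},
    stream=True,
)
# "stream" was not in the dict, so it is filled in from the keyword argument.
assert request.context.stream
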
@@ -271,6 +301,43 @@ class ModelRequest:
        return ModelMessage.to_openai_messages(messages)


@dataclass
class ModelExtraMedata(BaseParameters):
    """A class to represent the extra metadata of an LLM."""

    prompt_roles: Optional[List[str]] = field(
        default_factory=lambda: [
            ModelMessageRoleType.SYSTEM,
            ModelMessageRoleType.HUMAN,
            ModelMessageRoleType.AI,
        ],
        metadata={"help": "The roles of the prompt"},
    )

    prompt_sep: Optional[str] = field(
        default="\n",
        metadata={"help": "The separator of the prompt between multiple rounds"},
    )

    # You can see the chat template in your model repo tokenizer config,
    # typically in the tokenizer_config.json
    prompt_chat_template: Optional[str] = field(
        default=None,
        metadata={
            "help": "The chat template, see: https://huggingface.co/docs/transformers/main/en/chat_templating"
        },
    )

    @property
    def support_system_message(self) -> bool:
        """Whether the model supports system messages.

        Returns:
            bool: Whether the model supports system messages.
        """
        return ModelMessageRoleType.SYSTEM in self.prompt_roles
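
As a quick aside (an illustrative sketch, not part of the diff), restricting prompt_roles is how a model declares that it cannot take system messages:

# Illustrative only: the default roles include SYSTEM, so system messages are supported.
assert ModelExtraMedata().support_system_message

# A model that only accepts human/AI turns reports no system-message support.
no_system = ModelExtraMedata(
    prompt_roles=[ModelMessageRoleType.HUMAN, ModelMessageRoleType.AI]
)
assert not no_system.support_system_message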


@dataclass
@PublicAPI(stability="beta")
class ModelMetadata(BaseParameters):
@@ -295,18 +362,294 @@ class ModelMetadata(BaseParameters):
        default_factory=dict,
        metadata={"help": "Model metadata"},
    )
    ext_metadata: Optional[ModelExtraMedata] = field(
        default_factory=ModelExtraMedata,
        metadata={"help": "Model extra metadata"},
    )

    @classmethod
    def from_dict(
        cls, data: dict, ignore_extra_fields: bool = False
    ) -> "ModelMetadata":
        if "ext_metadata" in data:
            data["ext_metadata"] = ModelExtraMedata(**data["ext_metadata"])
        return cls(**data)

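For orientation, a small sketch of the new from_dict hook (field values are placeholders): nested ext_metadata dictionaries are promoted to ModelExtraMedata before the dataclass is built.

# Illustrative only.
meta = ModelMetadata.from_dict(
    {
        "model": "example-model",
        "ext_metadata": {"prompt_sep": "\n", "prompt_chat_template": None},
    }
)
assert isinstance(meta.ext_metadata, ModelExtraMedata)
assert meta.ext_metadata.support_system_message
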
class MessageConverter(ABC):
|
||||
"""An abstract class for message converter.
|
||||
|
||||
Different LLMs may have different message formats; this class is used to convert the messages
to the format of the LLM.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> from typing import List
|
||||
>>> from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
|
||||
>>> from dbgpt.core.interface.llm import MessageConverter, ModelMetadata
|
||||
>>> class RemoveSystemMessageConverter(MessageConverter):
|
||||
... def convert(
|
||||
... self,
|
||||
... messages: List[ModelMessage],
|
||||
... model_metadata: Optional[ModelMetadata] = None,
|
||||
... ) -> List[ModelMessage]:
|
||||
... # Convert the messages, merge system messages to the last user message.
|
||||
... system_message = None
|
||||
... other_messages = []
|
||||
... sep = "\\n"
|
||||
... for message in messages:
|
||||
... if message.role == ModelMessageRoleType.SYSTEM:
|
||||
... system_message = message
|
||||
... else:
|
||||
... other_messages.append(message)
|
||||
... if system_message and other_messages:
|
||||
... other_messages[-1].content = (
|
||||
... system_message.content + sep + other_messages[-1].content
|
||||
... )
|
||||
... return other_messages
|
||||
...
|
||||
>>> messages = [
|
||||
... ModelMessage(
|
||||
... role=ModelMessageRoleType.SYSTEM,
|
||||
... content="You are a helpful assistant",
|
||||
... ),
|
||||
... ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are you"),
|
||||
... ]
|
||||
>>> converter = RemoveSystemMessageConverter()
|
||||
>>> converted_messages = converter.convert(messages, None)
|
||||
>>> assert converted_messages == [
|
||||
... ModelMessage(
|
||||
... role=ModelMessageRoleType.HUMAN,
|
||||
... content="You are a helpful assistant\\nWho are you",
|
||||
... ),
|
||||
... ]
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def convert(
|
||||
self,
|
||||
messages: List[ModelMessage],
|
||||
model_metadata: Optional[ModelMetadata] = None,
|
||||
) -> List[ModelMessage]:
|
||||
"""Convert the messages.
|
||||
|
||||
Args:
|
||||
messages(List[ModelMessage]): The messages.
|
||||
model_metadata(ModelMetadata): The model metadata.
|
||||
|
||||
Returns:
|
||||
List[ModelMessage]: The converted messages.
|
||||
"""
|
||||
|
||||
|
||||
class DefaultMessageConverter(MessageConverter):
    """The default message converter."""

    def __init__(self, prompt_sep: Optional[str] = None):
        self._prompt_sep = prompt_sep

    def convert(
        self,
        messages: List[ModelMessage],
        model_metadata: Optional[ModelMetadata] = None,
    ) -> List[ModelMessage]:
        """Convert the messages.

        There are three steps to convert the messages:

        1. Just keep system, human and AI messages

        2. Move the last user's message to the end of the list

        3. Convert the messages to no system message if the model does not support system message

        Args:
            messages(List[ModelMessage]): The messages.
            model_metadata(ModelMetadata): The model metadata.

        Returns:
            List[ModelMessage]: The converted messages.
        """
        # 1. Just keep system, human and AI messages
        messages = list(filter(lambda m: m.pass_to_model, messages))
        # 2. Move the last user's message to the end of the list
        messages = self.move_last_user_message_to_end(messages)

        if not model_metadata or not model_metadata.ext_metadata:
            logger.warning("No model metadata, skipping system message conversion")
            return messages
        if not model_metadata.ext_metadata.support_system_message:
            # 3. Convert the messages to no system message
            return self.convert_to_no_system_message(messages, model_metadata)
        return messages
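
A tiny illustrative sketch of the default converter (message content is a placeholder); without model metadata only the filter and reorder steps run:

converter = DefaultMessageConverter(prompt_sep="\n\n")
msgs = [ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi")]
# No metadata: the messages are only filtered and reordered, not merged.
assert converter.convert(msgs, None) == msgs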
|
||||
def convert_to_no_system_message(
|
||||
self,
|
||||
messages: List[ModelMessage],
|
||||
model_metadata: Optional[ModelMetadata] = None,
|
||||
) -> List[ModelMessage]:
|
||||
"""Convert the messages to no system message.
|
||||
|
||||
Examples:
|
||||
>>> # Convert the messages to no system message, just merge system messages to the last user message
|
||||
>>> from typing import List
|
||||
>>> from dbgpt.core.interface.message import (
|
||||
... ModelMessage,
|
||||
... ModelMessageRoleType,
|
||||
... )
|
||||
>>> from dbgpt.core.interface.llm import (
|
||||
... DefaultMessageConverter,
|
||||
... ModelMetadata,
|
||||
... )
|
||||
>>> messages = [
|
||||
... ModelMessage(
|
||||
... role=ModelMessageRoleType.SYSTEM,
|
||||
... content="You are a helpful assistant",
|
||||
... ),
|
||||
... ModelMessage(
|
||||
... role=ModelMessageRoleType.HUMAN, content="Who are you"
|
||||
... ),
|
||||
... ]
|
||||
>>> converter = DefaultMessageConverter()
|
||||
>>> model_metadata = ModelMetadata(model="test")
|
||||
>>> converted_messages = converter.convert_to_no_system_message(
|
||||
... messages, model_metadata
|
||||
... )
|
||||
>>> assert converted_messages == [
|
||||
... ModelMessage(
|
||||
... role=ModelMessageRoleType.HUMAN,
|
||||
... content="You are a helpful assistant\\nWho are you",
|
||||
... ),
|
||||
... ]
|
||||
"""
|
||||
if not model_metadata or not model_metadata.ext_metadata:
|
||||
logger.warning("No model metadata, skip message conversion")
|
||||
return messages
|
||||
ext_metadata = model_metadata.ext_metadata
|
||||
system_messages = []
|
||||
result_messages = []
|
||||
for message in messages:
|
||||
if message.role == ModelMessageRoleType.SYSTEM:
|
||||
# Collect system messages; they are merged into the last user message below for models without system-message support
|
||||
system_messages.append(message)
|
||||
elif message.role in [
|
||||
ModelMessageRoleType.HUMAN,
|
||||
ModelMessageRoleType.AI,
|
||||
]:
|
||||
result_messages.append(message)
|
||||
prompt_sep = self._prompt_sep or ext_metadata.prompt_sep or "\n"
|
||||
system_message_str = None
|
||||
if len(system_messages) > 1:
|
||||
logger.warning("Your system messages have more than one message")
|
||||
system_message_str = prompt_sep.join([m.content for m in system_messages])
|
||||
elif len(system_messages) == 1:
|
||||
system_message_str = system_messages[0].content
|
||||
|
||||
if system_message_str and result_messages:
|
||||
# Not support system messages, merge system messages to the last user message
|
||||
result_messages[-1].content = (
|
||||
system_message_str + prompt_sep + result_messages[-1].content
|
||||
)
|
||||
return result_messages
|
||||
|
||||
def move_last_user_message_to_end(
|
||||
self, messages: List[ModelMessage]
|
||||
) -> List[ModelMessage]:
|
||||
"""Move the last user message to the end of the list.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> from typing import List
|
||||
>>> from dbgpt.core.interface.message import (
|
||||
... ModelMessage,
|
||||
... ModelMessageRoleType,
|
||||
... )
|
||||
>>> from dbgpt.core.interface.llm import DefaultMessageConverter
|
||||
>>> messages = [
|
||||
... ModelMessage(
|
||||
... role=ModelMessageRoleType.SYSTEM,
|
||||
... content="You are a helpful assistant",
|
||||
... ),
|
||||
... ModelMessage(
|
||||
... role=ModelMessageRoleType.HUMAN, content="Who are you"
|
||||
... ),
|
||||
... ModelMessage(role=ModelMessageRoleType.AI, content="I'm a robot"),
|
||||
... ModelMessage(
|
||||
... role=ModelMessageRoleType.HUMAN, content="What's your name"
|
||||
... ),
|
||||
... ModelMessage(
|
||||
... role=ModelMessageRoleType.SYSTEM,
|
||||
... content="You are a helpful assistant",
|
||||
... ),
|
||||
... ]
|
||||
>>> converter = DefaultMessageConverter()
|
||||
>>> converted_messages = converter.move_last_user_message_to_end(messages)
|
||||
>>> assert converted_messages == [
|
||||
... ModelMessage(
|
||||
... role=ModelMessageRoleType.SYSTEM,
|
||||
... content="You are a helpful assistant",
|
||||
... ),
|
||||
... ModelMessage(
|
||||
... role=ModelMessageRoleType.HUMAN, content="Who are you"
|
||||
... ),
|
||||
... ModelMessage(role=ModelMessageRoleType.AI, content="I'm a robot"),
|
||||
... ModelMessage(
|
||||
... role=ModelMessageRoleType.SYSTEM,
|
||||
... content="You are a helpful assistant",
|
||||
... ),
|
||||
... ModelMessage(
|
||||
... role=ModelMessageRoleType.HUMAN, content="What's your name"
|
||||
... ),
|
||||
... ]
|
||||
|
||||
Args:
|
||||
messages(List[ModelMessage]): The messages.
|
||||
|
||||
Returns:
|
||||
List[ModelMessage]: The converted messages.
|
||||
"""
|
||||
last_user_input_index = None
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
if messages[i].role == ModelMessageRoleType.HUMAN:
|
||||
last_user_input_index = i
|
||||
break
|
||||
if last_user_input_index is not None:
|
||||
last_user_input = messages.pop(last_user_input_index)
|
||||
messages.append(last_user_input)
|
||||
return messages
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
class LLMClient(ABC):
|
||||
"""An abstract class for LLM client."""
|
||||
|
||||
# Cache the model metadata for 60 seconds
|
||||
_MODEL_CACHE_ = TTLCache(maxsize=100, ttl=60)
|
||||
|
||||
@property
|
||||
def cache(self) -> collections.abc.MutableMapping:
|
||||
"""The cache object to cache the model metadata.
|
||||
|
||||
You can override this property to use your own cache object.
|
||||
Returns:
|
||||
collections.abc.MutableMapping: The cache object.
|
||||
"""
|
||||
return self._MODEL_CACHE_
|
||||
|
||||
@abstractmethod
|
||||
async def generate(self, request: ModelRequest) -> ModelOutput:
|
||||
async def generate(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
message_converter: Optional[MessageConverter] = None,
|
||||
) -> ModelOutput:
|
||||
"""Generate a response for a given model request.
|
||||
|
||||
Sometimes different LLMs may have different message formats;
you can use the message converter to convert the messages to the format of the LLM.
|
||||
|
||||
Args:
|
||||
request(ModelRequest): The model request.
|
||||
message_converter(MessageConverter): The message converter.
|
||||
|
||||
Returns:
|
||||
ModelOutput: The model output.
|
||||
@@ -315,12 +658,18 @@ class LLMClient(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def generate_stream(
|
||||
self, request: ModelRequest
|
||||
self,
|
||||
request: ModelRequest,
|
||||
message_converter: Optional[MessageConverter] = None,
|
||||
) -> AsyncIterator[ModelOutput]:
|
||||
"""Generate a stream of responses for a given model request.
|
||||
|
||||
Sometimes different LLMs may have different message formats;
you can use the message converter to convert the messages to the format of the LLM.
|
||||
|
||||
Args:
|
||||
request(ModelRequest): The model request.
|
||||
message_converter(MessageConverter): The message converter.
|
||||
|
||||
Returns:
|
||||
AsyncIterator[ModelOutput]: The model output stream.
|
||||
@@ -345,3 +694,65 @@ class LLMClient(ABC):
|
||||
Returns:
|
||||
int: The number of tokens.
|
||||
"""
|
||||
|
||||
async def covert_message(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
message_converter: Optional[MessageConverter] = None,
|
||||
) -> ModelRequest:
|
||||
"""Covert the message.
|
||||
If no message converter is provided, the original request will be returned.
|
||||
|
||||
Args:
|
||||
request(ModelRequest): The model request.
|
||||
message_converter(MessageConverter): The message converter.
|
||||
|
||||
Returns:
|
||||
ModelRequest: The converted model request.
|
||||
"""
|
||||
if not message_converter:
|
||||
return request
|
||||
new_request = request.copy()
|
||||
model_metadata = await self.get_model_metadata(request.model)
|
||||
new_messages = message_converter.convert(request.messages, model_metadata)
|
||||
new_request.messages = new_messages
|
||||
return new_request
|
||||
|
||||
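
To make the intended call pattern concrete, a sketch follows (the subclass name and its transport method are invented for illustration) of a concrete client applying the converter before dispatching a request:

class MyLLMClient(LLMClient):  # illustrative subclass, not part of the diff
    async def generate(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> ModelOutput:
        # Normalize the messages for the target model first.
        request = await self.covert_message(request, message_converter)
        return await self._send_to_backend(request)  # hypothetical transport call
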
async def cached_models(self) -> List[ModelMetadata]:
|
||||
"""Get all the models from the cache or the llm server.
|
||||
|
||||
If the model metadata is not in the cache, it will be fetched from the llm server.
|
||||
|
||||
Returns:
|
||||
List[ModelMetadata]: A list of model metadata.
|
||||
"""
|
||||
key = "____$llm_client_models$____"
|
||||
if key not in self.cache:
|
||||
models = await self.models()
|
||||
self.cache[key] = models
|
||||
for model in models:
|
||||
model_metadata_key = (
|
||||
f"____$llm_client_models_metadata_{model.model}$____"
|
||||
)
|
||||
self.cache[model_metadata_key] = model
|
||||
return self.cache[key]
|
||||
|
||||
async def get_model_metadata(self, model: str) -> ModelMetadata:
|
||||
"""Get the model metadata.
|
||||
|
||||
Args:
|
||||
model(str): The model name.
|
||||
|
||||
Returns:
|
||||
ModelMetadata: The model metadata.
|
||||
|
||||
Raises:
|
||||
ValueError: If the model is not found.
|
||||
"""
|
||||
model_metadata_key = f"____$llm_client_models_metadata_{model}$____"
|
||||
if model_metadata_key not in self.cache:
|
||||
await self.cached_models()
|
||||
model_metadata = self.cache.get(model_metadata_key)
|
||||
if not model_metadata:
|
||||
raise ValueError(f"Model {model} not found")
|
||||
return model_metadata
|
||||
|
@@ -5,7 +5,6 @@ from datetime import datetime
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.interface.storage import (
|
||||
InMemoryStorage,
|
||||
ResourceIdentifier,
|
||||
@@ -114,6 +113,50 @@ class ModelMessage(BaseModel):
|
||||
content: str
|
||||
round_index: Optional[int] = 0
|
||||
|
||||
@property
|
||||
def pass_to_model(self) -> bool:
|
||||
"""Whether the message will be passed to the model
|
||||
|
||||
The view message will not be passed to the model
|
||||
|
||||
Returns:
|
||||
bool: Whether the message will be passed to the model
|
||||
"""
|
||||
return self.role in [
|
||||
ModelMessageRoleType.SYSTEM,
|
||||
ModelMessageRoleType.HUMAN,
|
||||
ModelMessageRoleType.AI,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def from_base_messages(messages: List[BaseMessage]) -> List["ModelMessage"]:
|
||||
result = []
|
||||
for message in messages:
|
||||
content, round_index = message.content, message.round_index
|
||||
if isinstance(message, HumanMessage):
|
||||
result.append(
|
||||
ModelMessage(
|
||||
role=ModelMessageRoleType.HUMAN,
|
||||
content=content,
|
||||
round_index=round_index,
|
||||
)
|
||||
)
|
||||
elif isinstance(message, AIMessage):
|
||||
result.append(
|
||||
ModelMessage(
|
||||
role=ModelMessageRoleType.AI,
|
||||
content=content,
|
||||
round_index=round_index,
|
||||
)
|
||||
)
|
||||
elif isinstance(message, SystemMessage):
|
||||
result.append(
|
||||
ModelMessage(
|
||||
role=ModelMessageRoleType.SYSTEM, content=message.content
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def from_openai_messages(
|
||||
messages: Union[str, List[Dict[str, str]]]
|
||||
@@ -142,9 +185,15 @@ class ModelMessage(BaseModel):
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def to_openai_messages(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
|
||||
def to_openai_messages(
|
||||
messages: List["ModelMessage"], convert_to_compatible_format: bool = False
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Convert to OpenAI message format and
|
||||
hugggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
|
||||
|
||||
Args:
|
||||
messages (List["ModelMessage"]): The model messages
|
||||
convert_to_compatible_format (bool): Whether to convert to compatible format
|
||||
"""
|
||||
history = []
|
||||
# Add history conversation
|
||||
@@ -157,15 +206,16 @@ class ModelMessage(BaseModel):
|
||||
history.append({"role": "assistant", "content": message.content})
|
||||
else:
|
||||
pass
|
||||
# Move the last user's information to the end
|
||||
last_user_input_index = None
|
||||
for i in range(len(history) - 1, -1, -1):
|
||||
if history[i]["role"] == "user":
|
||||
last_user_input_index = i
|
||||
break
|
||||
if last_user_input_index:
|
||||
last_user_input = history.pop(last_user_input_index)
|
||||
history.append(last_user_input)
|
||||
if convert_to_compatible_format:
|
||||
# Move the last user's information to the end
|
||||
last_user_input_index = None
|
||||
for i in range(len(history) - 1, -1, -1):
|
||||
if history[i]["role"] == "user":
|
||||
last_user_input_index = i
|
||||
break
|
||||
if last_user_input_index is not None:
|
||||
last_user_input = history.pop(last_user_input_index)
|
||||
history.append(last_user_input)
|
||||
return history
|
||||
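
A brief sketch of the new flag (contents are placeholders): with the default convert_to_compatible_format=False the original order is kept, while True preserves the previous behaviour of moving the last user message to the end.

msgs = [
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Q1"),
    ModelMessage(role=ModelMessageRoleType.AI, content="A1"),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Q2"),
    ModelMessage(role=ModelMessageRoleType.AI, content="A2"),
]
plain = ModelMessage.to_openai_messages(msgs)
compat = ModelMessage.to_openai_messages(msgs, convert_to_compatible_format=True)
# plain keeps the original order; compat moves the "Q2" user message to the end.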
|
||||
@staticmethod
|
||||
@@ -189,8 +239,8 @@ class ModelMessage(BaseModel):
|
||||
return str_msg
|
||||
|
||||
|
||||
_SingleRoundMessage = List[ModelMessage]
|
||||
_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[ModelMessage]]
|
||||
_SingleRoundMessage = List[BaseMessage]
|
||||
_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[BaseMessage]]
|
||||
|
||||
|
||||
def _message_to_dict(message: BaseMessage) -> Dict:
|
||||
@@ -338,7 +388,8 @@ class OnceConversation:
|
||||
"""Start a new round of conversation
|
||||
|
||||
Example:
|
||||
>>> conversation = OnceConversation()
|
||||
|
||||
>>> conversation = OnceConversation("chat_normal")
|
||||
>>> # The chat order will be 0, then we start a new round of conversation
|
||||
>>> assert conversation.chat_order == 0
|
||||
>>> conversation.start_new_round()
|
||||
@@ -585,6 +636,28 @@ class OnceConversation:
|
||||
)
|
||||
return messages
|
||||
|
||||
def get_history_message(
|
||||
self, include_system_message: bool = False
|
||||
) -> List[BaseMessage]:
|
||||
"""Get the history message
|
||||
|
||||
By default the system messages are not included.
|
||||
|
||||
Args:
|
||||
include_system_message (bool): Whether to include the system message
|
||||
|
||||
Returns:
|
||||
List[BaseMessage]: The history messages
|
||||
"""
|
||||
messages = []
|
||||
for message in self.messages:
|
||||
if message.pass_to_model:
|
||||
if include_system_message:
|
||||
messages.append(message)
|
||||
elif message.type != "system":
|
||||
messages.append(message)
|
||||
return messages
|
||||
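
An illustrative call of the new helper:

conversation = OnceConversation("chat_normal")
# Only messages flagged pass_to_model are returned; system messages are
# skipped unless explicitly requested.
history = conversation.get_history_message()
full_history = conversation.get_history_message(include_system_message=True)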
|
||||
|
||||
class ConversationIdentifier(ResourceIdentifier):
|
||||
"""Conversation identifier"""
|
||||
|
dbgpt/core/interface/operator/composer_operator.py (new file, 114 lines)
@@ -0,0 +1,114 @@
|
||||
import dataclasses
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from dbgpt.core import (
|
||||
ChatPromptTemplate,
|
||||
MessageStorageItem,
|
||||
ModelMessage,
|
||||
ModelRequest,
|
||||
StorageConversation,
|
||||
StorageInterface,
|
||||
)
|
||||
from dbgpt.core.awel import (
|
||||
DAG,
|
||||
BaseOperator,
|
||||
InputOperator,
|
||||
JoinOperator,
|
||||
MapOperator,
|
||||
SimpleCallDataInputSource,
|
||||
)
|
||||
from dbgpt.core.interface.operator.prompt_operator import HistoryPromptBuilderOperator
|
||||
|
||||
from .message_operator import (
|
||||
BufferedConversationMapperOperator,
|
||||
ChatHistoryLoadType,
|
||||
PreChatHistoryLoadOperator,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ChatComposerInput:
|
||||
"""The composer input."""
|
||||
|
||||
prompt_dict: Dict[str, Any]
|
||||
model_dict: Dict[str, Any]
|
||||
context: ChatHistoryLoadType
|
||||
|
||||
|
||||
class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequest]):
|
||||
"""The chat history prompt composer operator.
|
||||
|
||||
For simple use, you can use this operator to compose the chat history prompt.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_template: ChatPromptTemplate,
|
||||
history_key: str = "chat_history",
|
||||
last_k_round: int = 2,
|
||||
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
|
||||
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._prompt_template = prompt_template
|
||||
self._history_key = history_key
|
||||
self._last_k_round = last_k_round
|
||||
self._storage = storage
|
||||
self._message_storage = message_storage
|
||||
self._sub_compose_dag = self._build_composer_dag()
|
||||
|
||||
async def map(self, input_value: ChatComposerInput) -> ModelRequest:
|
||||
end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
|
||||
# Sub dag, use the same dag context in the parent dag
|
||||
return await end_node.call(
|
||||
call_data={"data": input_value}, dag_ctx=self.current_dag_context
|
||||
)
|
||||
|
||||
def _build_composer_dag(self) -> DAG:
|
||||
with DAG("dbgpt_awel_chat_history_prompt_composer") as composer_dag:
|
||||
input_task = InputOperator(input_source=SimpleCallDataInputSource())
|
||||
# Load and store chat history, default use InMemoryStorage.
|
||||
chat_history_load_task = PreChatHistoryLoadOperator(
|
||||
storage=self._storage, message_storage=self._message_storage
|
||||
)
|
||||
# History transform task: keep only the last k rounds of messages
|
||||
history_transform_task = BufferedConversationMapperOperator(
|
||||
last_k_round=self._last_k_round
|
||||
)
|
||||
history_prompt_build_task = HistoryPromptBuilderOperator(
|
||||
prompt=self._prompt_template, history_key=self._history_key
|
||||
)
|
||||
model_request_build_task = JoinOperator(self._build_model_request)
|
||||
|
||||
# Build composer dag
|
||||
(
|
||||
input_task
|
||||
>> MapOperator(lambda x: x.context)
|
||||
>> chat_history_load_task
|
||||
>> history_transform_task
|
||||
>> history_prompt_build_task
|
||||
)
|
||||
(
|
||||
input_task
|
||||
>> MapOperator(lambda x: x.prompt_dict)
|
||||
>> history_prompt_build_task
|
||||
)
|
||||
|
||||
history_prompt_build_task >> model_request_build_task
|
||||
(
|
||||
input_task
|
||||
>> MapOperator(lambda x: x.model_dict)
|
||||
>> model_request_build_task
|
||||
)
|
||||
|
||||
return composer_dag
|
||||
|
||||
def _build_model_request(
|
||||
self, messages: List[ModelMessage], model_dict: Dict[str, Any]
|
||||
) -> ModelRequest:
|
||||
return ModelRequest.build_request(messages=messages, **model_dict)
|
||||
|
||||
async def after_dag_end(self):
|
||||
# Should call after_dag_end() of sub dag
|
||||
await self._sub_compose_dag._after_dag_end()
|
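
To show how the composer is meant to be driven from a parent DAG, a sketch follows; the prompt text, model name, conversation id, and the MessagesPlaceholder used for the history slot are assumptions for illustration, and the downstream LLM operator is omitted.

import asyncio

from dbgpt.core import (
    ChatPromptTemplate,
    HumanPromptTemplate,
    MessagesPlaceholder,  # assumed helper for the chat_history slot
    SystemPromptTemplate,
)
from dbgpt.core.awel import DAG
from dbgpt.core.interface.operator.composer_operator import (
    ChatComposerInput,
    ChatHistoryPromptComposerOperator,
)

with DAG("example_composer_dag"):
    prompt = ChatPromptTemplate(
        messages=[
            SystemPromptTemplate.from_template("You are a helpful assistant."),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanPromptTemplate.from_template("{user_input}"),
        ]
    )
    composer = ChatHistoryPromptComposerOperator(prompt_template=prompt)

# Illustrative call data; the composer returns a ModelRequest ready for an LLM operator.
req = asyncio.run(
    composer.call(
        call_data={
            "data": ChatComposerInput(
                prompt_dict={"user_input": "Who are you?"},
                model_dict={"model": "example-model", "context": {"conv_uid": "conv-1"}},
                context={"conv_uid": "conv-1"},
            )
        }
    )
)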
@@ -1,11 +1,12 @@
|
||||
import dataclasses
|
||||
from abc import ABC
|
||||
from typing import Any, AsyncIterator, Dict, Optional, Union
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt.core.awel import (
|
||||
BranchFunc,
|
||||
BranchOperator,
|
||||
DAGContext,
|
||||
MapOperator,
|
||||
StreamifyAbsOperator,
|
||||
)
|
||||
@@ -22,20 +23,30 @@ RequestInput = Union[
|
||||
str,
|
||||
Dict[str, Any],
|
||||
BaseModel,
|
||||
ModelMessage,
|
||||
List[ModelMessage],
|
||||
]
|
||||
|
||||
|
||||
class RequestBuildOperator(MapOperator[RequestInput, ModelRequest], ABC):
|
||||
class RequestBuilderOperator(MapOperator[RequestInput, ModelRequest], ABC):
|
||||
"""Build the model request from the input value."""
|
||||
|
||||
def __init__(self, model: Optional[str] = None, **kwargs):
|
||||
self._model = model
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, input_value: RequestInput) -> ModelRequest:
|
||||
req_dict = {}
|
||||
if not input_value:
|
||||
raise ValueError("input_value is not set")
|
||||
if isinstance(input_value, str):
|
||||
req_dict = {"messages": [ModelMessage.build_human_message(input_value)]}
|
||||
elif isinstance(input_value, dict):
|
||||
req_dict = input_value
|
||||
elif isinstance(input_value, ModelMessage):
|
||||
req_dict = {"messages": [input_value]}
|
||||
elif isinstance(input_value, list) and isinstance(input_value[0], ModelMessage):
|
||||
req_dict = {"messages": input_value}
|
||||
elif dataclasses.is_dataclass(input_value):
|
||||
req_dict = dataclasses.asdict(input_value)
|
||||
elif isinstance(input_value, BaseModel):
|
||||
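
For context, a sketch of the renamed operator in use (model name and input text are placeholders); a plain string input is wrapped into a single human message:

import asyncio

from dbgpt.core.awel import DAG
from dbgpt.core.interface.operator.llm_operator import RequestBuilderOperator

with DAG("example_request_build_dag"):
    build_task = RequestBuilderOperator(model="example-model")

# A plain string becomes a single human message in the resulting ModelRequest.
request = asyncio.run(
    build_task.call(call_data={"data": "Write a SQL query that counts users"})
)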
@@ -76,6 +87,7 @@ class BaseLLM:
|
||||
"""The abstract operator for a LLM."""
|
||||
|
||||
SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name"
|
||||
SHARE_DATA_KEY_MODEL_OUTPUT = "share_data_key_model_output"
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||||
self._llm_client = llm_client
|
||||
@@ -87,8 +99,16 @@ class BaseLLM:
|
||||
raise ValueError("llm_client is not set")
|
||||
return self._llm_client
|
||||
|
||||
async def save_model_output(
|
||||
self, current_dag_context: DAGContext, model_output: ModelOutput
|
||||
) -> None:
|
||||
"""Save the model output to the share data."""
|
||||
await current_dag_context.save_to_share_data(
|
||||
self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output
|
||||
)
|
||||
|
||||
class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
|
||||
|
||||
class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
|
||||
"""The operator for a LLM.
|
||||
|
||||
Args:
|
||||
@@ -105,10 +125,12 @@ class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
self.SHARE_DATA_KEY_MODEL_NAME, request.model
|
||||
)
|
||||
return await self.llm_client.generate(request)
|
||||
model_output = await self.llm_client.generate(request)
|
||||
await self.save_model_output(self.current_dag_context, model_output)
|
||||
return model_output
|
||||
|
||||
|
||||
class StreamingLLMOperator(
|
||||
class BaseStreamingLLMOperator(
|
||||
BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC
|
||||
):
|
||||
"""The streaming operator for a LLM.
|
||||
@@ -127,8 +149,12 @@ class StreamingLLMOperator(
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
self.SHARE_DATA_KEY_MODEL_NAME, request.model
|
||||
)
|
||||
model_output = None
|
||||
async for output in self.llm_client.generate_stream(request):
|
||||
model_output = output
|
||||
yield output
|
||||
if model_output:
|
||||
await self.save_model_output(self.current_dag_context, model_output)
|
||||
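
For orientation, a sketch of a minimal concrete operator built on the renamed base (the class name and constructor chaining are illustrative; a real implementation would typically come from the model serving layer):

class SimpleLLMOperator(BaseLLMOperator):  # illustrative subclass
    def __init__(self, llm_client: LLMClient, **kwargs):
        # Store the client via BaseLLM, then initialize the MapOperator machinery.
        super().__init__(llm_client=llm_client)
        MapOperator.__init__(self, **kwargs)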
|
||||
|
||||
class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
||||
|
@@ -1,19 +1,17 @@
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncIterator, List, Optional
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt.core import (
|
||||
MessageStorageItem,
|
||||
ModelMessage,
|
||||
ModelMessageRoleType,
|
||||
ModelOutput,
|
||||
ModelRequest,
|
||||
ModelRequestContext,
|
||||
StorageConversation,
|
||||
StorageInterface,
|
||||
)
|
||||
from dbgpt.core.awel import BaseOperator, MapOperator, TransformStreamAbsOperator
|
||||
from dbgpt.core.interface.message import _MultiRoundMessageMapper
|
||||
from dbgpt.core.awel import BaseOperator, MapOperator
|
||||
from dbgpt.core.interface.message import BaseMessage, _MultiRoundMessageMapper
|
||||
|
||||
|
||||
class BaseConversationOperator(BaseOperator, ABC):
|
||||
@@ -21,32 +19,41 @@ class BaseConversationOperator(BaseOperator, ABC):
|
||||
|
||||
SHARE_DATA_KEY_STORAGE_CONVERSATION = "share_data_key_storage_conversation"
|
||||
SHARE_DATA_KEY_MODEL_REQUEST = "share_data_key_model_request"
|
||||
SHARE_DATA_KEY_MODEL_REQUEST_CONTEXT = "share_data_key_model_request_context"
|
||||
|
||||
_check_storage: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
|
||||
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
|
||||
check_storage: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
self._check_storage = check_storage
|
||||
super().__init__(**kwargs)
|
||||
self._storage = storage
|
||||
self._message_storage = message_storage
|
||||
|
||||
@property
|
||||
def storage(self) -> StorageInterface[StorageConversation, Any]:
|
||||
def storage(self) -> Optional[StorageInterface[StorageConversation, Any]]:
|
||||
"""Return the LLM client."""
|
||||
if not self._storage:
|
||||
raise ValueError("Storage is not set")
|
||||
if self._check_storage:
|
||||
raise ValueError("Storage is not set")
|
||||
return None
|
||||
return self._storage
|
||||
|
||||
@property
|
||||
def message_storage(self) -> StorageInterface[MessageStorageItem, Any]:
|
||||
def message_storage(self) -> Optional[StorageInterface[MessageStorageItem, Any]]:
|
||||
"""Return the LLM client."""
|
||||
if not self._message_storage:
|
||||
raise ValueError("Message storage is not set")
|
||||
if self._check_storage:
|
||||
raise ValueError("Message storage is not set")
|
||||
return None
|
||||
return self._message_storage
|
||||
|
||||
async def get_storage_conversation(self) -> StorageConversation:
|
||||
async def get_storage_conversation(self) -> Optional[StorageConversation]:
|
||||
"""Get the storage conversation from share data.
|
||||
|
||||
Returns:
|
||||
@@ -58,104 +65,11 @@ class BaseConversationOperator(BaseOperator, ABC):
|
||||
)
|
||||
)
|
||||
if not storage_conv:
|
||||
raise ValueError("Storage conversation is not set")
|
||||
if self._check_storage:
|
||||
raise ValueError("Storage conversation is not set")
|
||||
return None
|
||||
return storage_conv
|
||||
|
||||
async def get_model_request(self) -> ModelRequest:
|
||||
"""Get the model request from share data.
|
||||
|
||||
Returns:
|
||||
ModelRequest: The model request.
|
||||
"""
|
||||
model_request: ModelRequest = (
|
||||
await self.current_dag_context.get_from_share_data(
|
||||
self.SHARE_DATA_KEY_MODEL_REQUEST
|
||||
)
|
||||
)
|
||||
if not model_request:
|
||||
raise ValueError("Model request is not set")
|
||||
return model_request
|
||||
|
||||
|
||||
class PreConversationOperator(
|
||||
BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
|
||||
):
|
||||
"""The operator to prepare the storage conversation.
|
||||
|
||||
In DB-GPT, the conversation record and the messages in the conversation are stored in storage,
and they can be stored in different storages (for high performance).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
|
||||
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(storage=storage, message_storage=message_storage)
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
|
||||
async def map(self, input_value: ModelRequest) -> ModelRequest:
|
||||
"""Map the input value to a ModelRequest.
|
||||
|
||||
Args:
|
||||
input_value (ModelRequest): The input value.
|
||||
|
||||
Returns:
|
||||
ModelRequest: The mapped ModelRequest.
|
||||
"""
|
||||
if input_value.context is None:
|
||||
input_value.context = ModelRequestContext()
|
||||
if not input_value.context.conv_uid:
|
||||
input_value.context.conv_uid = str(uuid.uuid4())
|
||||
if not input_value.context.extra:
|
||||
input_value.context.extra = {}
|
||||
|
||||
chat_mode = input_value.context.chat_mode
|
||||
|
||||
# Create a new storage conversation; this loads the conversation from storage, so it must run async
|
||||
storage_conv: StorageConversation = await self.blocking_func_to_async(
|
||||
StorageConversation,
|
||||
conv_uid=input_value.context.conv_uid,
|
||||
chat_mode=chat_mode,
|
||||
user_name=input_value.context.user_name,
|
||||
sys_code=input_value.context.sys_code,
|
||||
conv_storage=self.storage,
|
||||
message_storage=self.message_storage,
|
||||
)
|
||||
input_messages = input_value.get_messages()
|
||||
await self.save_to_storage(storage_conv, input_messages)
|
||||
# Get all messages from current storage conversation, and overwrite the input value
|
||||
messages: List[ModelMessage] = storage_conv.get_model_messages()
|
||||
input_value.messages = messages
|
||||
|
||||
# Save the storage conversation to share data, for the child operators
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv
|
||||
)
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
self.SHARE_DATA_KEY_MODEL_REQUEST, input_value
|
||||
)
|
||||
return input_value
|
||||
|
||||
async def save_to_storage(
|
||||
self, storage_conv: StorageConversation, input_messages: List[ModelMessage]
|
||||
) -> None:
|
||||
"""Save the messages to storage.
|
||||
|
||||
Args:
|
||||
storage_conv (StorageConversation): The storage conversation.
|
||||
input_messages (List[ModelMessage]): The input messages.
|
||||
"""
|
||||
# check first
|
||||
self.check_messages(input_messages)
|
||||
storage_conv.start_new_round()
|
||||
for message in input_messages:
|
||||
if message.role == ModelMessageRoleType.HUMAN:
|
||||
storage_conv.add_user_message(message.content)
|
||||
else:
|
||||
storage_conv.add_system_message(message.content)
|
||||
|
||||
def check_messages(self, messages: List[ModelMessage]) -> None:
|
||||
"""Check the messages.
|
||||
|
||||
@@ -174,164 +88,147 @@ class PreConversationOperator(
|
||||
]:
|
||||
raise ValueError(f"Message role {message.role} is not supported")
|
||||
|
||||
async def after_dag_end(self):
|
||||
"""The callback after DAG end"""
|
||||
# Save the storage conversation to storage after the whole DAG finished
|
||||
storage_conv: StorageConversation = await self.get_storage_conversation()
|
||||
# TODO: don't save if the conversation has some internal error
|
||||
storage_conv.end_current_round()
|
||||
|
||||
ChatHistoryLoadType = Union[ModelRequestContext, Dict[str, Any]]
|
||||
|
||||
|
||||
class PostConversationOperator(
|
||||
BaseConversationOperator, MapOperator[ModelOutput, ModelOutput]
|
||||
class PreChatHistoryLoadOperator(
|
||||
BaseConversationOperator, MapOperator[ChatHistoryLoadType, List[BaseMessage]]
|
||||
):
|
||||
def __init__(self, **kwargs):
|
||||
"""The operator to prepare the storage conversation.
|
||||
|
||||
In DB-GPT, the conversation record and the messages in the conversation are stored in storage,
and they can be stored in different storages (for high performance).
|
||||
|
||||
This operator only loads the conversation and its messages from storage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
|
||||
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
|
||||
include_system_message: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(storage=storage, message_storage=message_storage)
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
self._include_system_message = include_system_message
|
||||
|
||||
async def map(self, input_value: ModelOutput) -> ModelOutput:
|
||||
"""Map the input value to a ModelOutput.
|
||||
async def map(self, input_value: ChatHistoryLoadType) -> List[BaseMessage]:
|
||||
"""Map the input value to a ModelRequest.
|
||||
|
||||
Args:
|
||||
input_value (ModelOutput): The input value.
|
||||
input_value (ChatHistoryLoadType): The input value.
|
||||
|
||||
Returns:
|
||||
ModelOutput: The mapped ModelOutput.
|
||||
List[BaseMessage]: The messages stored in the storage.
|
||||
"""
|
||||
# Get the storage conversation from share data
|
||||
storage_conv: StorageConversation = await self.get_storage_conversation()
|
||||
storage_conv.add_ai_message(input_value.text)
|
||||
return input_value
|
||||
if not input_value:
|
||||
raise ValueError("Model request context can't be None")
|
||||
if isinstance(input_value, dict):
|
||||
input_value = ModelRequestContext(**input_value)
|
||||
if not input_value.conv_uid:
|
||||
input_value.conv_uid = str(uuid.uuid4())
|
||||
if not input_value.extra:
|
||||
input_value.extra = {}
|
||||
|
||||
chat_mode = input_value.chat_mode
|
||||
|
||||
class PostStreamingConversationOperator(
|
||||
BaseConversationOperator, TransformStreamAbsOperator[ModelOutput, ModelOutput]
|
||||
):
|
||||
def __init__(self, **kwargs):
|
||||
TransformStreamAbsOperator.__init__(self, **kwargs)
|
||||
# Create a new storage conversation; this loads the conversation from storage, so it must run async
|
||||
storage_conv: StorageConversation = await self.blocking_func_to_async(
|
||||
StorageConversation,
|
||||
conv_uid=input_value.conv_uid,
|
||||
chat_mode=chat_mode,
|
||||
user_name=input_value.user_name,
|
||||
sys_code=input_value.sys_code,
|
||||
conv_storage=self.storage,
|
||||
message_storage=self.message_storage,
|
||||
)
|
||||
|
||||
async def transform_stream(
|
||||
self, input_value: AsyncIterator[ModelOutput]
|
||||
) -> ModelOutput:
|
||||
"""Transform the input value to a ModelOutput.
|
||||
|
||||
Args:
|
||||
input_value (ModelOutput): The input value.
|
||||
|
||||
Returns:
|
||||
ModelOutput: The transformed ModelOutput.
|
||||
"""
|
||||
full_text = ""
|
||||
async for model_output in input_value:
|
||||
# Now model_output.text if full text, if it is a delta text, we should merge all delta text to a full text
|
||||
full_text = model_output.text
|
||||
yield model_output
|
||||
# Get the storage conversation from share data
|
||||
storage_conv: StorageConversation = await self.get_storage_conversation()
|
||||
storage_conv.add_ai_message(full_text)
|
||||
# Save the storage conversation to share data, for the child operators
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv
|
||||
)
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
self.SHARE_DATA_KEY_MODEL_REQUEST_CONTEXT, input_value
|
||||
)
|
||||
# Get history messages from storage
|
||||
history_messages: List[BaseMessage] = storage_conv.get_history_message(
|
||||
include_system_message=self._include_system_message
|
||||
)
|
||||
return history_messages
|
||||
|
||||
|
||||
class ConversationMapperOperator(
|
||||
BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
|
||||
BaseConversationOperator, MapOperator[List[BaseMessage], List[BaseMessage]]
|
||||
):
|
||||
def __init__(self, message_mapper: _MultiRoundMessageMapper = None, **kwargs):
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
self._message_mapper = message_mapper
|
||||
|
||||
async def map(self, input_value: ModelRequest) -> ModelRequest:
|
||||
"""Map the input value to a ModelRequest.
|
||||
async def map(self, input_value: List[BaseMessage]) -> List[BaseMessage]:
|
||||
return self.map_messages(input_value)
|
||||
|
||||
Args:
|
||||
input_value (ModelRequest): The input value.
|
||||
|
||||
Returns:
|
||||
ModelRequest: The mapped ModelRequest.
|
||||
"""
|
||||
input_value = input_value.copy()
|
||||
messages: List[ModelMessage] = self.map_messages(input_value.messages)
|
||||
# Overwrite the input value
|
||||
input_value.messages = messages
|
||||
return input_value
|
||||
|
||||
def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
|
||||
"""Map the input messages to a list of ModelMessage.
|
||||
|
||||
Args:
|
||||
messages (List[ModelMessage]): The input messages.
|
||||
|
||||
Returns:
|
||||
List[ModelMessage]: The mapped ModelMessage.
|
||||
"""
|
||||
messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round(
|
||||
def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
|
||||
messages_by_round: List[List[BaseMessage]] = self._split_messages_by_round(
|
||||
messages
|
||||
)
|
||||
message_mapper = self._message_mapper or self.map_multi_round_messages
|
||||
return message_mapper(messages_by_round)
|
||||
|
||||
def map_multi_round_messages(
|
||||
self, messages_by_round: List[List[ModelMessage]]
|
||||
) -> List[ModelMessage]:
|
||||
"""Map multi round messages to a list of ModelMessage
|
||||
self, messages_by_round: List[List[BaseMessage]]
|
||||
) -> List[BaseMessage]:
|
||||
"""Map multi round messages to a list of BaseMessage.
|
||||
|
||||
By default, just merge all multi round messages to a list of ModelMessage according origin order.
|
||||
By default, just merge all multi round messages to a list of BaseMessage according to the original order.
|
||||
And you can overwrite this method to implement your own logic.
|
||||
|
||||
Examples:
|
||||
|
||||
Merge multi round messages to a list of ModelMessage according origin order.
|
||||
Merge multi round messages to a list of BaseMessage according to the original order.
|
||||
|
||||
.. code-block:: python
|
||||
>>> from dbgpt.core.interface.message import (
|
||||
... AIMessage,
|
||||
... HumanMessage,
|
||||
... SystemMessage,
|
||||
... )
|
||||
>>> messages_by_round = [
|
||||
... [
|
||||
... HumanMessage(content="Hi", round_index=1),
|
||||
... AIMessage(content="Hello!", round_index=1),
|
||||
... ],
|
||||
... [
|
||||
... HumanMessage(content="What's the error?", round_index=2),
|
||||
... AIMessage(content="Just a joke.", round_index=2),
|
||||
... ],
|
||||
... ]
|
||||
>>> operator = ConversationMapperOperator()
|
||||
>>> messages = operator.map_multi_round_messages(messages_by_round)
|
||||
>>> assert messages == [
|
||||
... HumanMessage(content="Hi", round_index=1),
|
||||
... AIMessage(content="Hello!", round_index=1),
|
||||
... HumanMessage(content="What's the error?", round_index=2),
|
||||
... AIMessage(content="Just a joke.", round_index=2),
|
||||
... ]
|
||||
|
||||
import asyncio
|
||||
from dbgpt.core.operator import ConversationMapperOperator
|
||||
Map multi round messages to a list of BaseMessage just keep the last one round.
|
||||
|
||||
messages_by_round = [
|
||||
[
|
||||
ModelMessage(role="human", content="Hi", round_index=1),
|
||||
ModelMessage(role="ai", content="Hello!", round_index=1),
|
||||
],
|
||||
[
|
||||
ModelMessage(role="system", content="Error 404", round_index=2),
|
||||
ModelMessage(
|
||||
role="human", content="What's the error?", round_index=2
|
||||
),
|
||||
ModelMessage(role="ai", content="Just a joke.", round_index=2),
|
||||
],
|
||||
[
|
||||
ModelMessage(role="human", content="Funny!", round_index=3),
|
||||
],
|
||||
]
|
||||
operator = ConversationMapperOperator()
|
||||
messages = operator.map_multi_round_messages(messages_by_round)
|
||||
assert messages == [
|
||||
ModelMessage(role="human", content="Hi", round_index=1),
|
||||
ModelMessage(role="ai", content="Hello!", round_index=1),
|
||||
ModelMessage(role="system", content="Error 404", round_index=2),
|
||||
ModelMessage(
|
||||
role="human", content="What's the error?", round_index=2
|
||||
),
|
||||
ModelMessage(role="ai", content="Just a joke.", round_index=2),
|
||||
ModelMessage(role="human", content="Funny!", round_index=3),
|
||||
]
|
||||
|
||||
Map multi round messages to a list of ModelMessage just keep the last one round.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MyMapper(ConversationMapperOperator):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def map_multi_round_messages(
|
||||
self, messages_by_round: List[List[ModelMessage]]
|
||||
) -> List[ModelMessage]:
|
||||
return messages_by_round[-1]
|
||||
|
||||
|
||||
operator = MyMapper()
|
||||
messages = operator.map_multi_round_messages(messages_by_round)
|
||||
assert messages == [
|
||||
ModelMessage(role="human", content="Funny!", round_index=3),
|
||||
]
|
||||
>>> class MyMapper(ConversationMapperOperator):
|
||||
... def __init__(self, **kwargs):
|
||||
... super().__init__(**kwargs)
|
||||
...
|
||||
... def map_multi_round_messages(
|
||||
... self, messages_by_round: List[List[BaseMessage]]
|
||||
... ) -> List[BaseMessage]:
|
||||
... return messages_by_round[-1]
|
||||
...
|
||||
>>> operator = MyMapper()
|
||||
>>> messages = operator.map_multi_round_messages(messages_by_round)
|
||||
>>> assert messages == [
|
||||
... HumanMessage(content="What's the error?", round_index=2),
|
||||
... AIMessage(content="Just a joke.", round_index=2),
|
||||
... ]
|
||||
|
||||
Args:
|
||||
"""
|
||||
@@ -340,17 +237,17 @@ class ConversationMapperOperator(
|
||||
return sum(messages_by_round, [])
|
||||
|
||||
def _split_messages_by_round(
|
||||
self, messages: List[ModelMessage]
|
||||
) -> List[List[ModelMessage]]:
|
||||
self, messages: List[BaseMessage]
|
||||
) -> List[List[BaseMessage]]:
|
||||
"""Split the messages by round index.
|
||||
|
||||
Args:
|
||||
messages (List[ModelMessage]): The input messages.
|
||||
messages (List[BaseMessage]): The messages.
|
||||
|
||||
Returns:
|
||||
List[List[ModelMessage]]: The split messages.
|
||||
List[List[BaseMessage]]: The messages split by round.
|
||||
"""
|
||||
messages_by_round: List[List[ModelMessage]] = []
|
||||
messages_by_round: List[List[BaseMessage]] = []
|
||||
last_round_index = 0
|
||||
for message in messages:
|
||||
if not message.round_index:
|
||||
@@ -366,7 +263,7 @@ class ConversationMapperOperator(
|
||||
class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
"""The buffered conversation mapper operator.
|
||||
|
||||
This Operator must be used after the PreConversationOperator,
|
||||
This Operator must be used after the PreChatHistoryLoadOperator,
|
||||
and it will map the messages in the storage conversation.
|
||||
|
||||
Examples:
|
||||
@@ -419,8 +316,8 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
if message_mapper:
|
||||
|
||||
def new_message_mapper(
|
||||
messages_by_round: List[List[ModelMessage]],
|
||||
) -> List[ModelMessage]:
|
||||
messages_by_round: List[List[BaseMessage]],
|
||||
) -> List[BaseMessage]:
|
||||
# Apply keep k round messages first, then apply the custom message mapper
|
||||
messages_by_round = self._keep_last_round_messages(messages_by_round)
|
||||
return message_mapper(messages_by_round)
|
||||
@@ -428,23 +325,23 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
|
||||
else:
|
||||
|
||||
def new_message_mapper(
|
||||
messages_by_round: List[List[ModelMessage]],
|
||||
) -> List[ModelMessage]:
|
||||
messages_by_round: List[List[BaseMessage]],
|
||||
) -> List[BaseMessage]:
|
||||
messages_by_round = self._keep_last_round_messages(messages_by_round)
|
||||
return sum(messages_by_round, [])
|
||||
|
||||
super().__init__(new_message_mapper, **kwargs)
|
||||
|
||||
def _keep_last_round_messages(
|
||||
self, messages_by_round: List[List[ModelMessage]]
|
||||
) -> List[List[ModelMessage]]:
|
||||
self, messages_by_round: List[List[BaseMessage]]
|
||||
) -> List[List[BaseMessage]]:
|
||||
"""Keep the last k round messages.
|
||||
|
||||
Args:
|
||||
messages_by_round (List[List[ModelMessage]]): The messages by round.
|
||||
messages_by_round (List[List[BaseMessage]]): The messages by round.
|
||||
|
||||
Returns:
|
||||
List[List[ModelMessage]]: The latest round messages.
|
||||
List[List[BaseMessage]]: The latest round messages.
|
||||
"""
|
||||
index = self._last_k_round + 1
|
||||
return messages_by_round[-index:]
|
||||
|
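
A small illustrative check of the slicing rule above: with last_k_round=1 the previous round plus the current round are kept (index = last_k_round + 1). The round contents are stand-ins for lists of BaseMessage.

rounds = [
    ["round-1 human", "round-1 ai"],
    ["round-2 human", "round-2 ai"],
    ["round-3 human"],
]
op = BufferedConversationMapperOperator(last_k_round=1)
# Keeps the last 1 + 1 = 2 rounds: round 2 and round 3.
assert op._keep_last_round_messages(rounds) == rounds[-2:]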
dbgpt/core/interface/operator/prompt_operator.py (new file, 255 lines)
@@ -0,0 +1,255 @@
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt.core import (
|
||||
BasePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
ModelMessage,
|
||||
ModelMessageRoleType,
|
||||
ModelOutput,
|
||||
StorageConversation,
|
||||
)
|
||||
from dbgpt.core.awel import JoinOperator, MapOperator
|
||||
from dbgpt.core.interface.message import BaseMessage
|
||||
from dbgpt.core.interface.operator.llm_operator import BaseLLM
|
||||
from dbgpt.core.interface.operator.message_operator import BaseConversationOperator
|
||||
from dbgpt.core.interface.prompt import HumanPromptTemplate, MessageType
|
||||
from dbgpt.util.function_utils import rearrange_args_by_type
|
||||
|
||||
|
||||
class BasePromptBuilderOperator(BaseConversationOperator, ABC):
|
||||
"""The base prompt builder operator."""
|
||||
|
||||
def __init__(self, check_storage: bool, **kwargs):
|
||||
super().__init__(check_storage=check_storage, **kwargs)
|
||||
|
||||
async def format_prompt(
|
||||
self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any]
|
||||
) -> List[ModelMessage]:
|
||||
"""Format the prompt.
|
||||
|
||||
Args:
|
||||
prompt (ChatPromptTemplate): The prompt.
|
||||
prompt_dict (Dict[str, Any]): The prompt dict.
|
||||
|
||||
Returns:
|
||||
List[ModelMessage]: The formatted prompt.
|
||||
"""
|
||||
kwargs = {}
|
||||
kwargs.update(prompt_dict)
|
||||
pass_kwargs = {k: v for k, v in kwargs.items() if k in prompt.input_variables}
|
||||
messages = prompt.format_messages(**pass_kwargs)
|
||||
messages = ModelMessage.from_base_messages(messages)
|
||||
# Start new round conversation, and save user message to storage
|
||||
await self.start_new_round_conv(messages)
|
||||
return messages
|
||||
|
||||
async def start_new_round_conv(self, messages: List[ModelMessage]) -> None:
|
||||
"""Start a new round conversation.
|
||||
|
||||
Args:
|
||||
messages (List[ModelMessage]): The messages.
|
||||
"""
|
||||
|
||||
last_user_message = None
|
||||
for message in messages[::-1]:
|
||||
if message.role == ModelMessageRoleType.HUMAN:
|
||||
last_user_message = message.content
|
||||
break
|
||||
if not last_user_message:
|
||||
raise ValueError("No user message")
|
||||
storage_conv: StorageConversation = await self.get_storage_conversation()
|
||||
if not storage_conv:
|
||||
return
|
||||
# Start new round
|
||||
storage_conv.start_new_round()
|
||||
storage_conv.add_user_message(last_user_message)
|
||||
|
||||
async def after_dag_end(self):
|
||||
"""The callback after DAG end"""
|
||||
# TODO remove this to start_new_round()
|
||||
# Save the storage conversation to storage after the whole DAG finished
|
||||
storage_conv: StorageConversation = await self.get_storage_conversation()
|
||||
if not storage_conv:
|
||||
return
|
||||
model_output: ModelOutput = await self.current_dag_context.get_from_share_data(
|
||||
BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT
|
||||
)
|
||||
if model_output:
|
||||
# Save model output message to storage
|
||||
storage_conv.add_ai_message(model_output.text)
|
||||
# End current conversation round and flush to storage
|
||||
storage_conv.end_current_round()
|
||||
|
||||
|
||||
PromptTemplateType = Union[ChatPromptTemplate, BasePromptTemplate, MessageType, str]
|
||||
|
||||
|
||||
class PromptBuilderOperator(
    BasePromptBuilderOperator, MapOperator[Dict[str, Any], List[ModelMessage]]
):
    """The operator to build the prompt with a static prompt.

    Examples:

        .. code-block:: python

            import asyncio
            from dbgpt.core.awel import DAG
            from dbgpt.core import (
                ModelMessage,
                HumanMessage,
                SystemMessage,
                HumanPromptTemplate,
                SystemPromptTemplate,
                ChatPromptTemplate,
            )
            from dbgpt.core.operator import PromptBuilderOperator

            with DAG("prompt_test") as dag:
                str_prompt = PromptBuilderOperator(
                    "Please write a {dialect} SQL count the length of a field"
                )
                tp_prompt = PromptBuilderOperator(
                    HumanPromptTemplate.from_template(
                        "Please write a {dialect} SQL count the length of a field"
                    )
                )
                chat_prompt = PromptBuilderOperator(
                    ChatPromptTemplate(
                        messages=[
                            HumanPromptTemplate.from_template(
                                "Please write a {dialect} SQL count the length of a field"
                            )
                        ]
                    )
                )
                with_sys_prompt = PromptBuilderOperator(
                    ChatPromptTemplate(
                        messages=[
                            SystemPromptTemplate.from_template(
                                "You are a {dialect} SQL expert"
                            ),
                            HumanPromptTemplate.from_template(
                                "Please write a {dialect} SQL count the length of a field"
                            ),
                        ],
                    )
                )

            single_input = {"data": {"dialect": "mysql"}}
            single_expected_messages = [
                ModelMessage(
                    content="Please write a mysql SQL count the length of a field",
                    role="human",
                )
            ]
            with_sys_expected_messages = [
                ModelMessage(content="You are a mysql SQL expert", role="system"),
                ModelMessage(
                    content="Please write a mysql SQL count the length of a field",
                    role="human",
                ),
            ]
            assert (
                asyncio.run(str_prompt.call(call_data=single_input))
                == single_expected_messages
            )
            assert (
                asyncio.run(tp_prompt.call(call_data=single_input))
                == single_expected_messages
            )
            assert (
                asyncio.run(chat_prompt.call(call_data=single_input))
                == single_expected_messages
            )
            assert (
                asyncio.run(with_sys_prompt.call(call_data=single_input))
                == with_sys_expected_messages
            )

    """

    def __init__(self, prompt: PromptTemplateType, **kwargs):
        if isinstance(prompt, str):
            prompt = ChatPromptTemplate(
                messages=[HumanPromptTemplate.from_template(prompt)]
            )
        elif isinstance(prompt, BasePromptTemplate) and not isinstance(
            prompt, ChatPromptTemplate
        ):
            prompt = ChatPromptTemplate(
                messages=[HumanPromptTemplate.from_template(prompt.template)]
            )
        elif isinstance(prompt, MessageType):
            prompt = ChatPromptTemplate(messages=[prompt])
        self._prompt = prompt

        super().__init__(check_storage=False, **kwargs)
        MapOperator.__init__(self, map_function=self.merge_prompt, **kwargs)

    @rearrange_args_by_type
    async def merge_prompt(self, prompt_dict: Dict[str, Any]) -> List[ModelMessage]:
        return await self.format_prompt(self._prompt, prompt_dict)

class DynamicPromptBuilderOperator(
    BasePromptBuilderOperator, JoinOperator[List[ModelMessage]]
):
    """The operator to build the prompt with a dynamic prompt.

    The prompt template is dynamic; it is created by the parent operator.
    """

    def __init__(self, **kwargs):
        super().__init__(check_storage=False, **kwargs)
        JoinOperator.__init__(self, combine_function=self.merge_prompt, **kwargs)

    @rearrange_args_by_type
    async def merge_prompt(
        self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any]
    ) -> List[ModelMessage]:
        return await self.format_prompt(prompt, prompt_dict)

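A minimal wiring sketch (editorial addition, not part of this change) for DynamicPromptBuilderOperator. It assumes the operator is exported from dbgpt.core.operator like PromptBuilderOperator above, and uses two hypothetical upstream MapOperator tasks: one builds the ChatPromptTemplate at runtime, the other extracts the variable dict; rearrange_args_by_type then dispatches both results to merge_prompt by type.

    from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate
    from dbgpt.core.awel import DAG, MapOperator
    from dbgpt.core.operator import DynamicPromptBuilderOperator  # assumed export path

    with DAG("dynamic_prompt_example") as dag:
        # Parent 1: build the prompt template at runtime from the incoming request.
        template_task = MapOperator(
            map_function=lambda req: ChatPromptTemplate(
                messages=[HumanPromptTemplate.from_template(req["template"])]
            )
        )
        # Parent 2: extract the variables used to render the template.
        vars_task = MapOperator(map_function=lambda req: req["vars"])
        prompt_task = DynamicPromptBuilderOperator()
        # Both parents feed the join operator; merge_prompt receives the
        # ChatPromptTemplate and the dict matched by their types.
        template_task >> prompt_task
        vars_task >> prompt_task
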
class HistoryPromptBuilderOperator(
    BasePromptBuilderOperator, JoinOperator[List[ModelMessage]]
):
    """The operator to build the prompt with a static prompt and chat history."""

    def __init__(
        self, prompt: ChatPromptTemplate, history_key: Optional[str] = None, **kwargs
    ):
        self._prompt = prompt
        self._history_key = history_key

        JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)

    @rearrange_args_by_type
    async def merge_history(
        self, history: List[BaseMessage], prompt_dict: Dict[str, Any]
    ) -> List[ModelMessage]:
        prompt_dict[self._history_key] = history
        return await self.format_prompt(self._prompt, prompt_dict)

class HistoryDynamicPromptBuilderOperator(
    BasePromptBuilderOperator, JoinOperator[List[ModelMessage]]
):
    """The operator to build the prompt with a dynamic prompt and chat history.

    The prompt template is dynamic; it is created by the parent operator.
    """

    def __init__(self, history_key: Optional[str] = None, **kwargs):
        self._history_key = history_key

        JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)

    @rearrange_args_by_type
    async def merge_history(
        self,
        prompt: ChatPromptTemplate,
        history: List[BaseMessage],
        prompt_dict: Dict[str, Any],
    ) -> List[ModelMessage]:
        prompt_dict[self._history_key] = history
        return await self.format_prompt(prompt, prompt_dict)

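A similar editorial sketch for HistoryPromptBuilderOperator defined above, with assumed import paths: a parent operator produces the history as a List[BaseMessage], merge_history injects it into the prompt dict under history_key, so the template should contain a MessagesPlaceholder with the same variable name.

    from dbgpt.core import (
        ChatPromptTemplate,
        HumanPromptTemplate,
        MessagesPlaceholder,
        SystemPromptTemplate,
    )
    from dbgpt.core.awel import DAG, MapOperator
    from dbgpt.core.operator import HistoryPromptBuilderOperator  # assumed export path

    prompt = ChatPromptTemplate(
        messages=[
            SystemPromptTemplate.from_template("You are a {dialect} SQL expert"),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanPromptTemplate.from_template("{user_input}"),
        ]
    )

    with DAG("history_prompt_example") as dag:
        # Parent 1: load the previous conversation rounds as a List[BaseMessage].
        history_task = MapOperator(map_function=lambda req: req["messages"])
        # Parent 2: the variables for the non-history parts of the template.
        vars_task = MapOperator(map_function=lambda req: req["vars"])
        prompt_task = HistoryPromptBuilderOperator(
            prompt=prompt, history_key="chat_history"
        )
        history_task >> prompt_task
        vars_task >> prompt_task
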
@@ -1,11 +1,14 @@
from __future__ import annotations

import dataclasses
import json
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional
from string import Formatter
from typing import Any, Callable, Dict, List, Optional, Set, Union

from dbgpt._private.pydantic import BaseModel
from dbgpt._private.pydantic import BaseModel, root_validator
from dbgpt.core._private.example_base import ExampleSelector
from dbgpt.core.awel import MapOperator
from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage
from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt.core.interface.storage import (
    InMemoryStorage,
@@ -38,15 +41,40 @@ _DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
}


class PromptTemplate(BaseModel, ABC):
class BasePromptTemplate(BaseModel):
    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""

    template: Optional[str]
    """The prompt template."""

    template_format: Optional[str] = "f-string"

    def format(self, **kwargs: Any) -> str:
        """Format the prompt with the inputs."""
        if self.template:
            return _DEFAULT_FORMATTER_MAPPING[self.template_format](True)(
                self.template, **kwargs
            )

    @classmethod
    def from_template(
        cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any
    ) -> BasePromptTemplate:
        """Create a prompt template from a template string."""
        input_variables = get_template_vars(template, template_format)
        return cls(
            template=template,
            input_variables=input_variables,
            template_format=template_format,
            **kwargs,
        )

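As a quick sanity check of the new from_template path (an editorial sketch based only on the code above): get_template_vars discovers the input variables and format renders them with the f-string formatter.

    template = BasePromptTemplate.from_template("Write a {dialect} SQL for table {table}")
    assert template.input_variables == ["dialect", "table"]
    assert (
        template.format(dialect="mysql", table="user")
        == "Write a mysql SQL for table user"
    )
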

class PromptTemplate(BasePromptTemplate):
    template_scene: Optional[str]
    template_define: Optional[str]
    """this template define"""
    template: Optional[str]
    """The prompt template."""
    template_format: str = "f-string"
    """strict template will check template args"""
    template_is_strict: bool = True
    """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
@@ -86,12 +114,114 @@ class PromptTemplate(BaseModel, ABC):
            self.template_is_strict
        )(self.template, **kwargs)

    @staticmethod
    def from_template(template: str) -> "PromptTemplateOperator":

class BaseChatPromptTemplate(BaseModel, ABC):
    prompt: BasePromptTemplate

    @property
    def input_variables(self) -> List[str]:
        """A list of the names of the variables the prompt template expects."""
        return self.prompt.input_variables

    @abstractmethod
    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Format the prompt with the inputs."""

    @classmethod
    def from_template(
        cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any
    ) -> BaseChatPromptTemplate:
        """Create a prompt template from a template string."""
        return PromptTemplateOperator(
            PromptTemplate(template=template, input_variables=[])
        )
        prompt = BasePromptTemplate.from_template(template, template_format)
        return cls(prompt=prompt, **kwargs)


class SystemPromptTemplate(BaseChatPromptTemplate):
    """The system prompt template."""

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        content = self.prompt.format(**kwargs)
        return [SystemMessage(content=content)]


class HumanPromptTemplate(BaseChatPromptTemplate):
    """The human prompt template."""

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        content = self.prompt.format(**kwargs)
        return [HumanMessage(content=content)]


class MessagesPlaceholder(BaseChatPromptTemplate):
    """The messages placeholder template.

    Mostly used for the chat history.
    """

    variable_name: str
    prompt: BasePromptTemplate = None

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        messages = kwargs.get(self.variable_name, [])
        if not isinstance(messages, list):
            raise ValueError(
                f"Unsupported messages type: {type(messages)}, should be list."
            )
        for message in messages:
            if not isinstance(message, BaseMessage):
                raise ValueError(
                    f"Unsupported message type: {type(message)}, should be BaseMessage."
                )
        return messages

    @property
    def input_variables(self) -> List[str]:
        """A list of the names of the variables the prompt template expects.

        Returns:
            List[str]: The input variables.
        """
        return [self.variable_name]

MessageType = Union[BaseChatPromptTemplate, BaseMessage]


class ChatPromptTemplate(BasePromptTemplate):
    messages: List[MessageType]

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Format the prompt with the inputs."""
        result_messages = []
        for message in self.messages:
            if isinstance(message, BaseMessage):
                result_messages.append(message)
            elif isinstance(message, BaseChatPromptTemplate):
                pass_kwargs = {
                    k: v for k, v in kwargs.items() if k in message.input_variables
                }
                result_messages.extend(message.format_messages(**pass_kwargs))
            elif isinstance(message, MessagesPlaceholder):
                pass_kwargs = {
                    k: v for k, v in kwargs.items() if k in message.input_variables
                }
                result_messages.extend(message.format_messages(**pass_kwargs))
            else:
                raise ValueError(f"Unsupported message type: {type(message)}")
        return result_messages

    @root_validator(pre=True)
    def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Pre-fill the messages."""
        input_variables = values.get("input_variables", {})
        messages = values.get("messages", [])
        if not input_variables:
            input_variables = set()
            for message in messages:
                if isinstance(message, BaseChatPromptTemplate):
                    input_variables.update(message.input_variables)
            values["input_variables"] = sorted(input_variables)
        return values

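An editorial usage sketch of ChatPromptTemplate with a MessagesPlaceholder, following the classes above; AIMessage is assumed to be exported from dbgpt.core.interface.message alongside HumanMessage and SystemMessage.

    from dbgpt.core.interface.message import AIMessage, HumanMessage

    prompt = ChatPromptTemplate(
        messages=[
            SystemPromptTemplate.from_template("You are a {dialect} SQL expert"),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanPromptTemplate.from_template("{user_input}"),
        ]
    )
    # The root validator pre-fills input_variables from the sub-templates.
    assert prompt.input_variables == ["chat_history", "dialect", "user_input"]

    history = [HumanMessage(content="Hi"), AIMessage(content="Hello, how can I help?")]
    messages = prompt.format_messages(
        dialect="mysql",
        user_input="Count the rows of table user",
        chat_history=history,
    )
    # messages: the SystemMessage, the two history messages, then the final HumanMessage.
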

@dataclasses.dataclass
@@ -547,10 +677,36 @@ class PromptManager:
        self.storage.delete(identifier)


class PromptTemplateOperator(MapOperator[Dict, str]):
    def __init__(self, prompt_template: PromptTemplate, **kwargs: Any):
        super().__init__(**kwargs)
        self._prompt_template = prompt_template
def _get_string_template_vars(template_str: str) -> Set[str]:
    """Get template variables from a template string."""
    variables = set()
    formatter = Formatter()

    async def map(self, input_value: Dict) -> str:
        return self._prompt_template.format(**input_value)
    for _, variable_name, _, _ in formatter.parse(template_str):
        if variable_name:
            variables.add(variable_name)

    return variables


def _get_jinja2_template_vars(template_str: str) -> Set[str]:
    """Get template variables from a template string."""
    from jinja2 import Environment, meta

    env = Environment()
    ast = env.parse(template_str)
    variables = meta.find_undeclared_variables(ast)
    return variables


def get_template_vars(
    template_str: str, template_format: str = "f-string"
) -> List[str]:
    """Get template variables from a template string."""
    if template_format == "f-string":
        result = _get_string_template_vars(template_str)
    elif template_format == "jinja2":
        result = _get_jinja2_template_vars(template_str)
    else:
        raise ValueError(f"Unsupported template format: {template_format}")
    return sorted(result)

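A brief editorial sketch of the expected behavior of get_template_vars for both supported template formats:

    assert get_template_vars("Hello {name}, welcome to {city}") == ["city", "name"]
    assert get_template_vars(
        "Hello {{ name }}, welcome to {{ city }}", template_format="jinja2"
    ) == ["city", "name"]
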
@@ -413,13 +413,18 @@ def test_to_openai_messages(
        {"role": "user", "content": human_model_message.content},
    ]


def test_to_openai_messages_convert_to_compatible_format(
    human_model_message, ai_model_message, system_model_message
):
    shuffle_messages = ModelMessage.to_openai_messages(
        [
            system_model_message,
            human_model_message,
            human_model_message,
            ai_model_message,
        ]
        ],
        convert_to_compatible_format=True,
    )
    assert shuffle_messages == [
        {"role": "system", "content": system_model_message.content},

@@ -99,12 +99,6 @@ class TestPromptTemplate:
        formatted_output = prompt.format(response="hello")
        assert "Response: " in formatted_output

    def test_from_template(self):
        template_str = "Hello {name}"
        prompt = PromptTemplate.from_template(template_str)
        assert prompt._prompt_template.template == template_str
        assert prompt._prompt_template.input_variables == []

    def test_format_missing_variable(self):
        template_str = "Hello {name}"
        prompt = PromptTemplate(