mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-23 18:31:22 +00:00
feat(awel): New MessageConverter and more AWEL operators (#1039)
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from string import Formatter
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt._private.pydantic import BaseModel, root_validator
|
||||
from dbgpt.core._private.example_base import ExampleSelector
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage
|
||||
from dbgpt.core.interface.output_parser import BaseOutputParser
|
||||
from dbgpt.core.interface.storage import (
|
||||
InMemoryStorage,
|
||||
@@ -38,15 +41,40 @@ _DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
|
||||
}
|
||||
|
||||
|
||||
class PromptTemplate(BaseModel, ABC):
|
||||
class BasePromptTemplate(BaseModel):
    """A text template plus the names of the variables it expects."""

    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""

    template: Optional[str]
    """The prompt template."""

    template_format: Optional[str] = "f-string"

    def format(self, **kwargs: Any) -> str:
        """Render the template with the given keyword arguments.

        Args:
            **kwargs: Values substituted into the template.

        Returns:
            The rendered text. When ``template`` is empty or ``None`` nothing
            is rendered and ``None`` is returned.

        Raises:
            KeyError: If ``template_format`` has no entry in
                ``_DEFAULT_FORMATTER_MAPPING``.
        """
        if not self.template:
            # Nothing to render; callers receive None in this case.
            return None
        formatter_factory = _DEFAULT_FORMATTER_MAPPING[self.template_format]
        # NOTE(review): the ``True`` argument is presumably the "strict" flag
        # of the formatter factory -- confirm against the mapping definition.
        render = formatter_factory(True)
        return render(self.template, **kwargs)

    @classmethod
    def from_template(
        cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any
    ) -> BasePromptTemplate:
        """Build a prompt template from a template string.

        The input variables are derived from the template text via
        ``get_template_vars``.

        Args:
            template: The template text.
            template_format: Template syntax passed to ``get_template_vars``
                and stored on the instance.
            **kwargs: Extra fields forwarded to the class constructor.

        Returns:
            A new instance of ``cls``.
        """
        return cls(
            template=template,
            input_variables=get_template_vars(template, template_format),
            template_format=template_format,
            **kwargs,
        )
|
||||
|
||||
|
||||
class PromptTemplate(BasePromptTemplate):
|
||||
template_scene: Optional[str]
|
||||
template_define: Optional[str]
|
||||
"""this template define"""
|
||||
template: Optional[str]
|
||||
"""The prompt template."""
|
||||
template_format: str = "f-string"
|
||||
"""strict template will check template args"""
|
||||
template_is_strict: bool = True
|
||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||
@@ -86,12 +114,114 @@ class PromptTemplate(BaseModel, ABC):
|
||||
self.template_is_strict
|
||||
)(self.template, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def from_template(template: str) -> "PromptTemplateOperator":
|
||||
|
||||
class BaseChatPromptTemplate(BaseModel, ABC):
    """Abstract chat prompt that renders to a list of messages."""

    prompt: BasePromptTemplate

    @property
    def input_variables(self) -> List[str]:
        """A list of the names of the variables the prompt template expects."""
        wrapped_prompt = self.prompt
        return wrapped_prompt.input_variables

    @abstractmethod
    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Format the prompt with the inputs.

        Args:
            **kwargs: Values for the template variables.

        Returns:
            The messages produced by the concrete subclass.
        """

    @classmethod
    def from_template(
        cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any
    ) -> BaseChatPromptTemplate:
        """Create a chat prompt template from a template string.

        Args:
            template: The template text.
            template_format: Template syntax forwarded to
                ``BasePromptTemplate.from_template``.
            **kwargs: Extra fields forwarded to the class constructor.

        Returns:
            A new instance of ``cls`` wrapping the parsed prompt.
        """
        return cls(
            prompt=BasePromptTemplate.from_template(template, template_format),
            **kwargs,
        )
|
||||
|
||||
|
||||
class SystemPromptTemplate(BaseChatPromptTemplate):
    """The system prompt template."""

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Render the prompt and wrap the text in a single system message.

        Args:
            **kwargs: Values forwarded to ``self.prompt.format``.

        Returns:
            A one-element list holding the ``SystemMessage``.
        """
        return [SystemMessage(content=self.prompt.format(**kwargs))]
|
||||
|
||||
|
||||
class HumanPromptTemplate(BaseChatPromptTemplate):
    """The human prompt template."""

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Render the prompt and wrap the text in a single human message.

        Args:
            **kwargs: Values forwarded to ``self.prompt.format``.

        Returns:
            A one-element list holding the ``HumanMessage``.
        """
        return [HumanMessage(content=self.prompt.format(**kwargs))]
|
||||
|
||||
|
||||
class MessagesPlaceholder(BaseChatPromptTemplate):
    """The messages placeholder template.

    Mostly used for the chat history.
    """

    variable_name: str
    # A placeholder has no text template of its own, so the prompt defaults
    # to None.
    prompt: BasePromptTemplate = None

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Return the messages supplied under ``variable_name``.

        Args:
            **kwargs: Keyword arguments; only the entry named
                ``variable_name`` is read. A missing entry counts as an
                empty list.

        Returns:
            The supplied list of messages, unchanged.

        Raises:
            ValueError: If the value is not a list, or if any element is not
                a ``BaseMessage``.
        """
        history = kwargs.get(self.variable_name, [])
        if not isinstance(history, list):
            raise ValueError(
                f"Unsupported messages type: {type(history)}, should be list."
            )
        invalid_types = [
            type(item) for item in history if not isinstance(item, BaseMessage)
        ]
        if invalid_types:
            # Report the first offending element.
            raise ValueError(
                f"Unsupported message type: {invalid_types[0]}, should be BaseMessage."
            )
        return history

    @property
    def input_variables(self) -> List[str]:
        """A list of the names of the variables the prompt template expects.

        Returns:
            List[str]: The input variables.
        """
        name = self.variable_name
        return [name]
|
||||
|
||||
|
||||
MessageType = Union[BaseChatPromptTemplate, BaseMessage]
|
||||
|
||||
|
||||
class ChatPromptTemplate(BasePromptTemplate):
    """A prompt made of fixed messages and chat prompt templates."""

    messages: List[MessageType]

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Format the prompt with the inputs.

        Fixed ``BaseMessage`` entries are passed through unchanged. Every
        ``BaseChatPromptTemplate`` entry is rendered with only the keyword
        arguments named in its ``input_variables``.

        Args:
            **kwargs: Values for the template variables.

        Returns:
            The flattened list of messages, in the order of ``self.messages``.

        Raises:
            ValueError: If an entry is neither a ``BaseMessage`` nor a
                ``BaseChatPromptTemplate``.
        """
        result_messages: List[BaseMessage] = []
        for message in self.messages:
            if isinstance(message, BaseMessage):
                result_messages.append(message)
            elif isinstance(message, BaseChatPromptTemplate):
                # MessagesPlaceholder subclasses BaseChatPromptTemplate, so it
                # is handled here too; a separate branch would be unreachable.
                expected = message.input_variables
                pass_kwargs = {k: v for k, v in kwargs.items() if k in expected}
                result_messages.extend(message.format_messages(**pass_kwargs))
            else:
                raise ValueError(f"Unsupported message type: {type(message)}")
        return result_messages

    @root_validator(pre=True)
    def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Fill ``input_variables`` from the messages when not provided.

        If ``input_variables`` is missing or empty, it is set to the sorted
        union of the ``input_variables`` of every ``BaseChatPromptTemplate``
        found in ``messages``.

        Args:
            values: The raw field values passed to the constructor.

        Returns:
            The same mapping, possibly with ``input_variables`` filled in.
        """
        if not values.get("input_variables", {}):
            collected = set()
            for message in values.get("messages", []):
                if isinstance(message, BaseChatPromptTemplate):
                    collected.update(message.input_variables)
            values["input_variables"] = sorted(collected)
        return values
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -547,10 +677,36 @@ class PromptManager:
|
||||
self.storage.delete(identifier)
|
||||
|
||||
|
||||
class PromptTemplateOperator(MapOperator[Dict, str]):
|
||||
def __init__(self, prompt_template: PromptTemplate, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self._prompt_template = prompt_template
|
||||
def _get_string_template_vars(template_str: str) -> Set[str]:
|
||||
"""Get template variables from a template string."""
|
||||
variables = set()
|
||||
formatter = Formatter()
|
||||
|
||||
async def map(self, input_value: Dict) -> str:
|
||||
return self._prompt_template.format(**input_value)
|
||||
for _, variable_name, _, _ in formatter.parse(template_str):
|
||||
if variable_name:
|
||||
variables.add(variable_name)
|
||||
|
||||
return variables
|
||||
|
||||
|
||||
def _get_jinja2_template_vars(template_str: str) -> Set[str]:
    """Get the undeclared variable names referenced by a jinja2 template.

    jinja2 is imported inside the function, so it is only needed when this
    function is actually called.

    Args:
        template_str: The jinja2 template text to inspect.

    Returns:
        The names reported by ``jinja2.meta.find_undeclared_variables``.
    """
    from jinja2 import Environment, meta

    parsed_template = Environment().parse(template_str)
    return meta.find_undeclared_variables(parsed_template)
|
||||
|
||||
|
||||
def get_template_vars(
    template_str: str, template_format: str = "f-string"
) -> List[str]:
    """Get the sorted variable names used by a template string.

    Args:
        template_str: The template text to inspect.
        template_format: Template syntax, either ``"f-string"`` or
            ``"jinja2"``.

    Returns:
        The variable names in sorted order.

    Raises:
        ValueError: If ``template_format`` is not a supported format.
    """
    if template_format == "f-string":
        return sorted(_get_string_template_vars(template_str))
    if template_format == "jinja2":
        return sorted(_get_jinja2_template_vars(template_str))
    raise ValueError(f"Unsupported template format: {template_format}")
|
||||
|
||||
Reference in New Issue
Block a user