feat(awel): New MessageConverter and more AWEL operators (#1039)

This commit is contained in:
Fangyin Cheng
2024-01-08 09:40:05 +08:00
committed by GitHub
parent 765fb181f6
commit e8861bd8fa
48 changed files with 2333 additions and 719 deletions

View File

@@ -1,11 +1,14 @@
from __future__ import annotations
import dataclasses
import json
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional
from string import Formatter
from typing import Any, Callable, Dict, List, Optional, Set, Union
from dbgpt._private.pydantic import BaseModel
from dbgpt._private.pydantic import BaseModel, root_validator
from dbgpt.core._private.example_base import ExampleSelector
from dbgpt.core.awel import MapOperator
from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage
from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt.core.interface.storage import (
InMemoryStorage,
@@ -38,15 +41,40 @@ _DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
}
class PromptTemplate(BaseModel, ABC):
class BasePromptTemplate(BaseModel):
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
template: Optional[str]
"""The prompt template."""
template_format: Optional[str] = "f-string"
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs."""
if self.template:
return _DEFAULT_FORMATTER_MAPPING[self.template_format](True)(
self.template, **kwargs
)
@classmethod
def from_template(
cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any
) -> BasePromptTemplate:
"""Create a prompt template from a template string."""
input_variables = get_template_vars(template, template_format)
return cls(
template=template,
input_variables=input_variables,
template_format=template_format,
**kwargs,
)
class PromptTemplate(BasePromptTemplate):
template_scene: Optional[str]
template_define: Optional[str]
"""this template define"""
template: Optional[str]
"""The prompt template."""
template_format: str = "f-string"
"""strict template will check template args"""
template_is_strict: bool = True
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
@@ -86,12 +114,114 @@ class PromptTemplate(BaseModel, ABC):
self.template_is_strict
)(self.template, **kwargs)
@staticmethod
def from_template(template: str) -> "PromptTemplateOperator":
class BaseChatPromptTemplate(BaseModel, ABC):
prompt: BasePromptTemplate
@property
def input_variables(self) -> List[str]:
"""A list of the names of the variables the prompt template expects."""
return self.prompt.input_variables
@abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the prompt with the inputs."""
@classmethod
def from_template(
cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any
) -> BaseChatPromptTemplate:
"""Create a prompt template from a template string."""
return PromptTemplateOperator(
PromptTemplate(template=template, input_variables=[])
)
prompt = BasePromptTemplate.from_template(template, template_format)
return cls(prompt=prompt, **kwargs)
class SystemPromptTemplate(BaseChatPromptTemplate):
"""The system prompt template."""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
content = self.prompt.format(**kwargs)
return [SystemMessage(content=content)]
class HumanPromptTemplate(BaseChatPromptTemplate):
"""The human prompt template."""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
content = self.prompt.format(**kwargs)
return [HumanMessage(content=content)]
class MessagesPlaceholder(BaseChatPromptTemplate):
"""The messages placeholder template.
Mostly used for the chat history.
"""
variable_name: str
prompt: BasePromptTemplate = None
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
messages = kwargs.get(self.variable_name, [])
if not isinstance(messages, list):
raise ValueError(
f"Unsupported messages type: {type(messages)}, should be list."
)
for message in messages:
if not isinstance(message, BaseMessage):
raise ValueError(
f"Unsupported message type: {type(message)}, should be BaseMessage."
)
return messages
@property
def input_variables(self) -> List[str]:
"""A list of the names of the variables the prompt template expects.
Returns:
List[str]: The input variables.
"""
return [self.variable_name]
MessageType = Union[BaseChatPromptTemplate, BaseMessage]
class ChatPromptTemplate(BasePromptTemplate):
messages: List[MessageType]
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the prompt with the inputs."""
result_messages = []
for message in self.messages:
if isinstance(message, BaseMessage):
result_messages.append(message)
elif isinstance(message, BaseChatPromptTemplate):
pass_kwargs = {
k: v for k, v in kwargs.items() if k in message.input_variables
}
result_messages.extend(message.format_messages(**pass_kwargs))
elif isinstance(message, MessagesPlaceholder):
pass_kwargs = {
k: v for k, v in kwargs.items() if k in message.input_variables
}
result_messages.extend(message.format_messages(**pass_kwargs))
else:
raise ValueError(f"Unsupported message type: {type(message)}")
return result_messages
@root_validator(pre=True)
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre-fill the messages."""
input_variables = values.get("input_variables", {})
messages = values.get("messages", [])
if not input_variables:
input_variables = set()
for message in messages:
if isinstance(message, BaseChatPromptTemplate):
input_variables.update(message.input_variables)
values["input_variables"] = sorted(input_variables)
return values
@dataclasses.dataclass
@@ -547,10 +677,36 @@ class PromptManager:
self.storage.delete(identifier)
class PromptTemplateOperator(MapOperator[Dict, str]):
def __init__(self, prompt_template: PromptTemplate, **kwargs: Any):
super().__init__(**kwargs)
self._prompt_template = prompt_template
def _get_string_template_vars(template_str: str) -> Set[str]:
"""Get template variables from a template string."""
variables = set()
formatter = Formatter()
async def map(self, input_value: Dict) -> str:
return self._prompt_template.format(**input_value)
for _, variable_name, _, _ in formatter.parse(template_str):
if variable_name:
variables.add(variable_name)
return variables
def _get_jinja2_template_vars(template_str: str) -> Set[str]:
"""Get template variables from a template string."""
from jinja2 import Environment, meta
env = Environment()
ast = env.parse(template_str)
variables = meta.find_undeclared_variables(ast)
return variables
def get_template_vars(
template_str: str, template_format: str = "f-string"
) -> List[str]:
"""Get template variables from a template string."""
if template_format == "f-string":
result = _get_string_template_vars(template_str)
elif template_format == "jinja2":
result = _get_jinja2_template_vars(template_str)
else:
raise ValueError(f"Unsupported template format: {template_format}")
return sorted(result)