# DB-GPT/dbgpt/core/interface/prompt.py
"""The prompt template interface."""
from __future__ import annotations
import dataclasses
import json
from abc import ABC, abstractmethod
from string import Formatter
from typing import Any, Callable, Dict, List, Optional, Set, Type, TypeVar, Union
from dbgpt._private.pydantic import BaseModel, ConfigDict, model_validator
from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage
from dbgpt.core.interface.storage import (
InMemoryStorage,
QuerySpec,
ResourceIdentifier,
StorageInterface,
StorageItem,
)
from dbgpt.util.formatting import formatter, no_strict_formatter
T = TypeVar("T", bound="BasePromptTemplate")
def _jinja2_formatter(template: str, **kwargs: Any) -> str:
"""Format a template using jinja2."""
try:
from jinja2 import Template
except ImportError:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
)
return Template(template).render(**kwargs)
_DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
"f-string": lambda is_strict: (
formatter.format if is_strict else no_strict_formatter.format
),
"jinja2": lambda is_strict: _jinja2_formatter,
}
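# Illustrative sketch (not part of the module API): the mapping is looked up by
# template format and called with the strictness flag to pick a formatter; the
# returned callable then renders the template, e.g.:
#   render = _DEFAULT_FORMATTER_MAPPING["f-string"](True)
#   text = render("hello {name}", name="DB-GPT")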
class BasePromptTemplate(BaseModel):
"""Base class for all prompt templates, returning a prompt."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
class PromptTemplate(BasePromptTemplate):
"""Prompt template."""
model_config = ConfigDict(arbitrary_types_allowed=True)
template: str
"""The prompt template."""
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
response_key: str = "response"
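    """The template variable that the serialized response format is filled into."""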
template_is_strict: bool = True
"""strict template will check template args"""
response_format: Optional[str] = None
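    """The expected response format, injected into the template as JSON when set."""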
template_scene: Optional[str] = None
template_define: Optional[str] = None
"""this template define"""
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
return "prompt"
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs."""
if self.response_format:
kwargs[self.response_key] = json.dumps(
self.response_format, ensure_ascii=False, indent=4
)
return _DEFAULT_FORMATTER_MAPPING[self.template_format](
self.template_is_strict
)(self.template, **kwargs)
@classmethod
def from_template(
cls: Type[T], template: str, template_format: str = "f-string", **kwargs: Any
) -> T:
"""Create a prompt template from a template string."""
input_variables = get_template_vars(template, template_format)
return cls(
template=template,
input_variables=input_variables,
template_format=template_format,
**kwargs,
)
class BaseChatPromptTemplate(BaseModel, ABC):
"""The base chat prompt template."""
prompt: BasePromptTemplate
@property
def input_variables(self) -> List[str]:
"""Return a list of the names of the variables the prompt template expects."""
return self.prompt.input_variables
@abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the prompt with the inputs."""
@classmethod
def from_template(
cls: Type[T],
template: str,
template_format: str = "f-string",
response_format: Optional[str] = None,
response_key: str = "response",
template_is_strict: bool = True,
**kwargs: Any,
) -> T:
"""Create a prompt template from a template string."""
prompt = PromptTemplate.from_template(
template,
template_format,
response_format=response_format,
response_key=response_key,
template_is_strict=template_is_strict,
)
return cls(prompt=prompt, **kwargs)
class SystemPromptTemplate(BaseChatPromptTemplate):
"""The system prompt template."""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the prompt with the inputs.
Returns:
List[BaseMessage]: The formatted messages.
"""
content = self.prompt.format(**kwargs)
return [SystemMessage(content=content)]
class HumanPromptTemplate(BaseChatPromptTemplate):
"""The human prompt template."""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the prompt with the inputs.
Returns:
List[BaseMessage]: The formatted messages.
"""
content = self.prompt.format(**kwargs)
return [HumanMessage(content=content)]
class MessagesPlaceholder(BaseModel):
"""The messages placeholder template.
Mostly used for the chat history.
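    Examples:
        A usage sketch (``history_messages`` is assumed to be a list of
        ``BaseMessage``):
        .. code-block:: python
            placeholder = MessagesPlaceholder(variable_name="chat_history")
            messages = placeholder.format_messages(chat_history=history_messages)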
"""
variable_name: str
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the prompt with the inputs.
Just return the messages from the kwargs with the variable name.
Returns:
List[BaseMessage]: The messages.
"""
messages = kwargs.get(self.variable_name, [])
if not isinstance(messages, list):
raise ValueError(
f"Unsupported messages type: {type(messages)}, should be list."
)
for message in messages:
if not isinstance(message, BaseMessage):
raise ValueError(
f"Unsupported message type: {type(message)}, should be BaseMessage."
)
return messages
@property
def input_variables(self) -> List[str]:
"""Return a list of the names of the variables the prompt template expects.
Returns:
List[str]: The input variables.
"""
return [self.variable_name]
MessageType = Union[BaseChatPromptTemplate, MessagesPlaceholder, BaseMessage]
class ChatPromptTemplate(BasePromptTemplate):
"""The chat prompt template.
Examples:
.. code-block:: python
prompt_template = ChatPromptTemplate(
messages=[
SystemPromptTemplate.from_template(
"You are a helpful AI assistant."
),
MessagesPlaceholder(variable_name="chat_history"),
HumanPromptTemplate.from_template("{question}"),
]
)
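            # Sketch: render into a flat message list (chat_history is a list of
            # BaseMessage; empty here for illustration).
            messages = prompt_template.format_messages(
                chat_history=[], question="What is DB-GPT?"
            )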
"""
messages: List[MessageType]
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the prompt with the inputs."""
result_messages = []
for message in self.messages:
if isinstance(message, BaseMessage):
result_messages.append(message)
elif isinstance(message, (BaseChatPromptTemplate, MessagesPlaceholder)):
pass_kwargs = {
k: v for k, v in kwargs.items() if k in message.input_variables
}
result_messages.extend(message.format_messages(**pass_kwargs))
else:
raise ValueError(f"Unsupported message type: {type(message)}")
return result_messages
@model_validator(mode="before")
@classmethod
def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre-fill the messages."""
if not isinstance(values, dict):
return values
input_variables = values.get("input_variables", {})
messages = values.get("messages", [])
if not input_variables:
input_variables = set()
for message in messages:
if isinstance(message, (BaseChatPromptTemplate, MessagesPlaceholder)):
input_variables.update(message.input_variables)
values["input_variables"] = sorted(input_variables)
return values
def get_placeholders(self) -> List[str]:
"""Get all placeholders in the prompt template.
Returns:
List[str]: The placeholders.
"""
placeholders = set()
for message in self.messages:
if isinstance(message, MessagesPlaceholder):
placeholders.add(message.variable_name)
return sorted(placeholders)
@dataclasses.dataclass
class PromptTemplateIdentifier(ResourceIdentifier):
"""The identifier of a prompt template."""
identifier_split: str = dataclasses.field(default="___$$$$___", init=False)
prompt_name: str
prompt_language: Optional[str] = None
sys_code: Optional[str] = None
model: Optional[str] = None
def __post_init__(self):
"""Post init method."""
if self.prompt_name is None:
raise ValueError("prompt_name cannot be None")
if any(
self.identifier_split in key
for key in [
self.prompt_name,
self.prompt_language,
self.sys_code,
self.model,
]
if key is not None
):
raise ValueError(
f"identifier_split {self.identifier_split} is not allowed in "
f"prompt_name, prompt_language, sys_code, model"
)
@property
def str_identifier(self) -> str:
"""Return the string identifier of the identifier."""
return self.identifier_split.join(
key
for key in [
self.prompt_name,
self.prompt_language,
self.sys_code,
self.model,
]
if key is not None
)
def to_dict(self) -> Dict:
"""Convert the identifier to a dict.
Returns:
Dict: The dict of the identifier.
"""
return {
"prompt_name": self.prompt_name,
"prompt_language": self.prompt_language,
"sys_code": self.sys_code,
"model": self.model,
}
@dataclasses.dataclass
class StoragePromptTemplate(StorageItem):
"""The storage prompt template."""
prompt_name: str
content: Optional[str] = None
prompt_language: Optional[str] = None
prompt_format: Optional[str] = None
input_variables: Optional[str] = None
model: Optional[str] = None
chat_scene: Optional[str] = None
sub_chat_scene: Optional[str] = None
prompt_type: Optional[str] = None
user_name: Optional[str] = None
sys_code: Optional[str] = None
_identifier: PromptTemplateIdentifier = dataclasses.field(init=False)
def __post_init__(self):
"""Post init method."""
self._identifier = PromptTemplateIdentifier(
prompt_name=self.prompt_name,
prompt_language=self.prompt_language,
sys_code=self.sys_code,
model=self.model,
)
        # Validate required fields after initialization.
self._check()
def to_prompt_template(self) -> PromptTemplate:
"""Convert the storage prompt template to a prompt template."""
input_variables = (
[] if not self.input_variables else self.input_variables.strip().split(",")
)
template_format = self.prompt_format or "f-string"
return PromptTemplate(
input_variables=input_variables,
template=self.content,
template_scene=self.chat_scene,
# prompt_name=self.prompt_name,
template_format=template_format,
)
@staticmethod
def from_prompt_template(
prompt_template: PromptTemplate,
prompt_name: str,
prompt_language: Optional[str] = None,
prompt_type: Optional[str] = None,
sys_code: Optional[str] = None,
user_name: Optional[str] = None,
sub_chat_scene: Optional[str] = None,
model: Optional[str] = None,
**kwargs,
) -> "StoragePromptTemplate":
"""Convert a prompt template to a storage prompt template.
Args:
prompt_template (PromptTemplate): The prompt template to convert from.
prompt_name (str): The name of the prompt.
prompt_language (Optional[str], optional): The language of the prompt.
Defaults to None. e.g. zh-cn, en.
prompt_type (Optional[str], optional): The type of the prompt.
Defaults to None. e.g. common, private.
sys_code (Optional[str], optional): The system code of the prompt.
Defaults to None.
user_name (Optional[str], optional): The username of the prompt.
Defaults to None.
sub_chat_scene (Optional[str], optional): The sub chat scene of the prompt.
Defaults to None.
model (Optional[str], optional): The model name of the prompt.
Defaults to None.
kwargs (Dict): Other params to build the storage prompt template.
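        Examples:
            A round-trip sketch:
            .. code-block:: python
                prompt = PromptTemplate.from_template("hello {input}")
                stored = StoragePromptTemplate.from_prompt_template(prompt, "hello")
                prompt_again = stored.to_prompt_template()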
"""
input_variables = prompt_template.input_variables or kwargs.get(
"input_variables"
)
if input_variables and isinstance(input_variables, list):
input_variables = ",".join(input_variables)
return StoragePromptTemplate(
prompt_name=prompt_name,
sys_code=sys_code,
user_name=user_name,
input_variables=input_variables,
model=model,
content=prompt_template.template or kwargs.get("content"),
prompt_language=prompt_language,
prompt_format=prompt_template.template_format
or kwargs.get("prompt_format"),
chat_scene=prompt_template.template_scene or kwargs.get("chat_scene"),
sub_chat_scene=sub_chat_scene,
prompt_type=prompt_type,
)
@property
def identifier(self) -> PromptTemplateIdentifier:
"""Return the identifier of the storage prompt template."""
return self._identifier
def merge(self, other: "StorageItem") -> None:
"""Merge the other item into the current item.
Args:
other (StorageItem): The other item to merge
"""
if not isinstance(other, StoragePromptTemplate):
raise ValueError(
f"Cannot merge {type(other)} into {type(self)} because they are not "
f"the same type."
)
self.from_object(other)
def to_dict(self) -> Dict:
"""Convert the storage prompt template to a dict.
Returns:
Dict: The dict of the storage prompt template.
"""
return {
"prompt_name": self.prompt_name,
"content": self.content,
"prompt_language": self.prompt_language,
"prompt_format": self.prompt_format,
"input_variables": self.input_variables,
"model": self.model,
"chat_scene": self.chat_scene,
"sub_chat_scene": self.sub_chat_scene,
"prompt_type": self.prompt_type,
"user_name": self.user_name,
"sys_code": self.sys_code,
}
def _check(self):
if self.prompt_name is None:
raise ValueError("prompt_name cannot be None")
if self.content is None:
raise ValueError("content cannot be None")
def from_object(self, template: "StoragePromptTemplate") -> None:
"""Load the prompt template from an existing prompt template object.
Args:
            template (StoragePromptTemplate): The storage prompt template to load from.
"""
self.content = template.content
self.prompt_format = template.prompt_format
self.input_variables = template.input_variables
self.model = template.model
self.chat_scene = template.chat_scene
self.sub_chat_scene = template.sub_chat_scene
self.prompt_type = template.prompt_type
self.user_name = template.user_name
class PromptManager:
"""The manager class for prompt templates.
Simple wrapper for the storage interface.
Examples:
.. code-block:: python
# Default use InMemoryStorage
prompt_manager = PromptManager()
prompt_template = PromptTemplate(
template="hello {input}",
input_variables=["input"],
template_scene="chat_normal",
)
prompt_manager.save(prompt_template, prompt_name="hello")
prompt_template_list = prompt_manager.list()
prompt_template_list = prompt_manager.prefer_query("hello")
With a custom storage interface.
.. code-block:: python
from dbgpt.core.interface.storage import InMemoryStorage
prompt_manager = PromptManager(InMemoryStorage())
prompt_template = PromptTemplate(
template="hello {input}",
input_variables=["input"],
template_scene="chat_normal",
)
prompt_manager.save(prompt_template, prompt_name="hello")
prompt_template_list = prompt_manager.list()
prompt_template_list = prompt_manager.prefer_query("hello")
"""
def __init__(
self, storage: Optional[StorageInterface[StoragePromptTemplate, Any]] = None
):
"""Create a new prompt manager."""
if storage is None:
storage = InMemoryStorage()
self._storage = storage
@property
def storage(self) -> StorageInterface[StoragePromptTemplate, Any]:
"""Return the storage interface for prompt templates."""
return self._storage
def prefer_query(
self,
prompt_name: str,
sys_code: Optional[str] = None,
prefer_prompt_language: Optional[str] = None,
prefer_model: Optional[str] = None,
**kwargs,
) -> List[StoragePromptTemplate]:
"""Query prompt templates from storage with prefer params.
        Sometimes we want to query prompt templates with preferred params (e.g. a
        specific language or model).
        This method first queries by the exact conditions, then applies the
        preferred filters; if a preferred filter matches nothing, the unfiltered
        results are returned.
Examples:
Query a prompt template.
.. code-block:: python
prompt_template_list = prompt_manager.prefer_query("hello")
Query with sys_code and username.
.. code-block:: python
prompt_template_list = prompt_manager.prefer_query(
"hello", sys_code="sys_code", user_name="user_name"
)
Query with prefer prompt language.
.. code-block:: python
                # First query by the prompt name "hello" exactly.
                # Then filter by the prompt language "zh-cn"; if nothing matches,
                # all queried prompt templates are returned.
prompt_template_list = prompt_manager.prefer_query(
"hello", prefer_prompt_language="zh-cn"
)
Query with prefer model.
.. code-block:: python
                # First query by the prompt name "hello" exactly.
                # Then filter by the model "vicuna-13b-v1.5"; if nothing matches,
                # all queried prompt templates are returned.
prompt_template_list = prompt_manager.prefer_query(
"hello", prefer_model="vicuna-13b-v1.5"
)
Args:
prompt_name (str): The name of the prompt template.
sys_code (Optional[str], optional): The system code of the prompt template.
Defaults to None.
prefer_prompt_language (Optional[str], optional): The language of the
prompt template. Defaults to None.
prefer_model (Optional[str], optional): The model of the prompt template.
Defaults to None.
            kwargs (Dict): Other query params (keys with non-None values are
                matched exactly).
"""
query_spec = QuerySpec(
conditions={
"prompt_name": prompt_name,
"sys_code": sys_code,
**kwargs,
}
)
queries: List[StoragePromptTemplate] = self.storage.query(
query_spec, StoragePromptTemplate
)
if not queries:
return []
if prefer_prompt_language:
prefer_prompt_language = prefer_prompt_language.lower()
temp_queries = [
query
for query in queries
if query.prompt_language
and query.prompt_language.lower() == prefer_prompt_language
]
if temp_queries:
queries = temp_queries
if prefer_model:
prefer_model = prefer_model.lower()
temp_queries = [
query
for query in queries
if query.model and query.model.lower() == prefer_model
]
if temp_queries:
queries = temp_queries
return queries
def save(self, prompt_template: PromptTemplate, prompt_name: str, **kwargs) -> None:
"""Save a prompt template to storage.
Examples:
.. code-block:: python
prompt_template = PromptTemplate(
template="hello {input}",
input_variables=["input"],
template_scene="chat_normal",
prompt_name="hello",
)
prompt_manager.save(prompt_template)
Save with sys_code and username.
.. code-block:: python
prompt_template = PromptTemplate(
template="hello {input}",
input_variables=["input"],
template_scene="chat_normal",
prompt_name="hello",
)
prompt_manager.save(
prompt_template, sys_code="sys_code", user_name="user_name"
)
Args:
prompt_template (PromptTemplate): The prompt template to save.
prompt_name (str): The name of the prompt template.
kwargs (Dict): Other params to build the storage prompt template.
More details in :meth:`~StoragePromptTemplate.from_prompt_template`.
"""
storage_prompt_template = StoragePromptTemplate.from_prompt_template(
prompt_template, prompt_name, **kwargs
)
self.storage.save(storage_prompt_template)
def query_or_save(
self, prompt_template: PromptTemplate, prompt_name: str, **kwargs
) -> StoragePromptTemplate:
"""Query a prompt template from storage, if not found, save it.
Args:
prompt_template (PromptTemplate): The prompt template to save.
prompt_name (str): The name of the prompt template.
kwargs (Dict): Other params to build the storage prompt template.
More details in :meth:`~StoragePromptTemplate.from_prompt_template`.
Returns:
StoragePromptTemplate: The storage prompt template.
"""
storage_prompt_template = StoragePromptTemplate.from_prompt_template(
prompt_template, prompt_name, **kwargs
)
exist_prompt_template = self.storage.load(
storage_prompt_template.identifier, StoragePromptTemplate
)
if exist_prompt_template:
return exist_prompt_template
self.save(prompt_template, prompt_name, **kwargs)
prompt = self.storage.load(
storage_prompt_template.identifier, StoragePromptTemplate
)
if not prompt:
raise ValueError("Can't read prompt from storage")
return prompt
def list(self, **kwargs) -> List[StoragePromptTemplate]:
"""List prompt templates from storage.
Examples:
List all prompt templates.
.. code-block:: python
all_prompt_templates = prompt_manager.list()
List with sys_code and username.
.. code-block:: python
templates = prompt_manager.list(
sys_code="sys_code", user_name="user_name"
)
Args:
kwargs (Dict): Other query params.
"""
query_spec = QuerySpec(conditions=kwargs)
return self.storage.query(query_spec, StoragePromptTemplate)
def delete(
self,
prompt_name: str,
prompt_language: Optional[str] = None,
sys_code: Optional[str] = None,
model: Optional[str] = None,
) -> None:
"""Delete a prompt template from storage.
Examples:
Delete a prompt template.
.. code-block:: python
prompt_manager.delete("hello")
            Delete with sys_code.
            .. code-block:: python
                prompt_manager.delete("hello", sys_code="sys_code")
Args:
prompt_name (str): The name of the prompt template.
prompt_language (Optional[str], optional): The language of the prompt
template. Defaults to None.
sys_code (Optional[str], optional): The system code of the prompt template.
Defaults to None.
model (Optional[str], optional): The model of the prompt template.
Defaults to None.
"""
identifier = PromptTemplateIdentifier(
prompt_name=prompt_name,
prompt_language=prompt_language,
sys_code=sys_code,
model=model,
)
self.storage.delete(identifier)
def _get_string_template_vars(template_str: str) -> Set[str]:
"""Get template variables from a template string."""
variables = set()
    fmt = Formatter()
    for _, variable_name, _, _ in fmt.parse(template_str):
if variable_name:
variables.add(variable_name)
return variables
def _get_jinja2_template_vars(template_str: str) -> Set[str]:
"""Get template variables from a template string."""
from jinja2 import Environment, meta
env = Environment()
ast = env.parse(template_str)
variables = meta.find_undeclared_variables(ast)
return variables
def get_template_vars(
template_str: str, template_format: str = "f-string"
) -> List[str]:
"""Get template variables from a template string."""
if template_format == "f-string":
result = _get_string_template_vars(template_str)
elif template_format == "jinja2":
result = _get_jinja2_template_vars(template_str)
else:
raise ValueError(f"Unsupported template format: {template_format}")
return sorted(result)