mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-23 01:49:58 +00:00
chore: Add pylint for DB-GPT core lib (#1076)
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
"""The prompt template interface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
@@ -7,9 +9,7 @@ from string import Formatter
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, root_validator
|
||||
from dbgpt.core._private.example_base import ExampleSelector
|
||||
from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage
|
||||
from dbgpt.core.interface.output_parser import BaseOutputParser
|
||||
from dbgpt.core.interface.storage import (
|
||||
InMemoryStorage,
|
||||
QuerySpec,
|
||||
@@ -42,64 +42,32 @@ _DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
|
||||
|
||||
|
||||
class BasePromptTemplate(BaseModel):
|
||||
"""Base class for all prompt templates, returning a prompt."""
|
||||
|
||||
input_variables: List[str]
|
||||
"""A list of the names of the variables the prompt template expects."""
|
||||
|
||||
template: Optional[str]
|
||||
|
||||
class PromptTemplate(BasePromptTemplate):
|
||||
"""Prompt template."""
|
||||
|
||||
template: str
|
||||
"""The prompt template."""
|
||||
|
||||
template_format: Optional[str] = "f-string"
|
||||
template_format: str = "f-string"
|
||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||
|
||||
response_key: str = "response"
|
||||
|
||||
template_is_strict: bool = True
|
||||
"""strict template will check template args"""
|
||||
|
||||
response_format: Optional[str] = None
|
||||
|
||||
response_key: Optional[str] = "response"
|
||||
template_scene: Optional[str] = None
|
||||
|
||||
template_is_strict: Optional[bool] = True
|
||||
"""strict template will check template args"""
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the prompt with the inputs."""
|
||||
if self.template:
|
||||
if self.response_format:
|
||||
kwargs[self.response_key] = json.dumps(
|
||||
self.response_format, ensure_ascii=False, indent=4
|
||||
)
|
||||
return _DEFAULT_FORMATTER_MAPPING[self.template_format](
|
||||
self.template_is_strict
|
||||
)(self.template, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_template(
|
||||
cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any
|
||||
) -> BasePromptTemplate:
|
||||
"""Create a prompt template from a template string."""
|
||||
input_variables = get_template_vars(template, template_format)
|
||||
return cls(
|
||||
template=template,
|
||||
input_variables=input_variables,
|
||||
template_format=template_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class PromptTemplate(BasePromptTemplate):
|
||||
template_scene: Optional[str]
|
||||
template_define: Optional[str]
|
||||
template_define: Optional[str] = None
|
||||
"""this template define"""
|
||||
"""default use stream out"""
|
||||
stream_out: bool = True
|
||||
""""""
|
||||
output_parser: BaseOutputParser = None
|
||||
""""""
|
||||
sep: str = "###"
|
||||
|
||||
example_selector: ExampleSelector = None
|
||||
|
||||
need_historical_messages: bool = False
|
||||
|
||||
temperature: float = 0.6
|
||||
max_new_tokens: int = 1024
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -111,13 +79,38 @@ class PromptTemplate(BasePromptTemplate):
|
||||
"""Return the prompt type key."""
|
||||
return "prompt"
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the prompt with the inputs."""
|
||||
if self.response_format:
|
||||
kwargs[self.response_key] = json.dumps(
|
||||
self.response_format, ensure_ascii=False, indent=4
|
||||
)
|
||||
return _DEFAULT_FORMATTER_MAPPING[self.template_format](
|
||||
self.template_is_strict
|
||||
)(self.template, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_template(
|
||||
cls, template: str, template_format: str = "f-string", **kwargs: Any
|
||||
) -> BasePromptTemplate:
|
||||
"""Create a prompt template from a template string."""
|
||||
input_variables = get_template_vars(template, template_format)
|
||||
return cls(
|
||||
template=template,
|
||||
input_variables=input_variables,
|
||||
template_format=template_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class BaseChatPromptTemplate(BaseModel, ABC):
|
||||
"""The base chat prompt template."""
|
||||
|
||||
prompt: BasePromptTemplate
|
||||
|
||||
@property
|
||||
def input_variables(self) -> List[str]:
|
||||
"""A list of the names of the variables the prompt template expects."""
|
||||
"""Return a list of the names of the variables the prompt template expects."""
|
||||
return self.prompt.input_variables
|
||||
|
||||
@abstractmethod
|
||||
@@ -128,14 +121,14 @@ class BaseChatPromptTemplate(BaseModel, ABC):
|
||||
def from_template(
|
||||
cls,
|
||||
template: str,
|
||||
template_format: Optional[str] = "f-string",
|
||||
template_format: str = "f-string",
|
||||
response_format: Optional[str] = None,
|
||||
response_key: Optional[str] = "response",
|
||||
response_key: str = "response",
|
||||
template_is_strict: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> BaseChatPromptTemplate:
|
||||
"""Create a prompt template from a template string."""
|
||||
prompt = BasePromptTemplate.from_template(
|
||||
prompt = PromptTemplate.from_template(
|
||||
template,
|
||||
template_format,
|
||||
response_format=response_format,
|
||||
@@ -149,6 +142,11 @@ class SystemPromptTemplate(BaseChatPromptTemplate):
|
||||
"""The system prompt template."""
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format the prompt with the inputs.
|
||||
|
||||
Returns:
|
||||
List[BaseMessage]: The formatted messages.
|
||||
"""
|
||||
content = self.prompt.format(**kwargs)
|
||||
return [SystemMessage(content=content)]
|
||||
|
||||
@@ -157,20 +155,31 @@ class HumanPromptTemplate(BaseChatPromptTemplate):
|
||||
"""The human prompt template."""
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format the prompt with the inputs.
|
||||
|
||||
Returns:
|
||||
List[BaseMessage]: The formatted messages.
|
||||
"""
|
||||
content = self.prompt.format(**kwargs)
|
||||
return [HumanMessage(content=content)]
|
||||
|
||||
|
||||
class MessagesPlaceholder(BaseChatPromptTemplate):
|
||||
class MessagesPlaceholder(BaseModel):
|
||||
"""The messages placeholder template.
|
||||
|
||||
Mostly used for the chat history.
|
||||
"""
|
||||
|
||||
variable_name: str
|
||||
prompt: BasePromptTemplate = None
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format the prompt with the inputs.
|
||||
|
||||
Just return the messages from the kwargs with the variable name.
|
||||
|
||||
Returns:
|
||||
List[BaseMessage]: The messages.
|
||||
"""
|
||||
messages = kwargs.get(self.variable_name, [])
|
||||
if not isinstance(messages, list):
|
||||
raise ValueError(
|
||||
@@ -185,7 +194,7 @@ class MessagesPlaceholder(BaseChatPromptTemplate):
|
||||
|
||||
@property
|
||||
def input_variables(self) -> List[str]:
|
||||
"""A list of the names of the variables the prompt template expects.
|
||||
"""Return a list of the names of the variables the prompt template expects.
|
||||
|
||||
Returns:
|
||||
List[str]: The input variables.
|
||||
@@ -193,10 +202,26 @@ class MessagesPlaceholder(BaseChatPromptTemplate):
|
||||
return [self.variable_name]
|
||||
|
||||
|
||||
MessageType = Union[BaseChatPromptTemplate, BaseMessage]
|
||||
MessageType = Union[BaseChatPromptTemplate, MessagesPlaceholder, BaseMessage]
|
||||
|
||||
|
||||
class ChatPromptTemplate(BasePromptTemplate):
|
||||
"""The chat prompt template.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
prompt_template = ChatPromptTemplate(
|
||||
messages=[
|
||||
SystemPromptTemplate.from_template(
|
||||
"You are a helpful AI assistant."
|
||||
),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
HumanPromptTemplate.from_template("{question}"),
|
||||
]
|
||||
)
|
||||
"""
|
||||
|
||||
messages: List[MessageType]
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
@@ -205,12 +230,7 @@ class ChatPromptTemplate(BasePromptTemplate):
|
||||
for message in self.messages:
|
||||
if isinstance(message, BaseMessage):
|
||||
result_messages.append(message)
|
||||
elif isinstance(message, BaseChatPromptTemplate):
|
||||
pass_kwargs = {
|
||||
k: v for k, v in kwargs.items() if k in message.input_variables
|
||||
}
|
||||
result_messages.extend(message.format_messages(**pass_kwargs))
|
||||
elif isinstance(message, MessagesPlaceholder):
|
||||
elif isinstance(message, (BaseChatPromptTemplate, MessagesPlaceholder)):
|
||||
pass_kwargs = {
|
||||
k: v for k, v in kwargs.items() if k in message.input_variables
|
||||
}
|
||||
@@ -227,7 +247,7 @@ class ChatPromptTemplate(BasePromptTemplate):
|
||||
if not input_variables:
|
||||
input_variables = set()
|
||||
for message in messages:
|
||||
if isinstance(message, BaseChatPromptTemplate):
|
||||
if isinstance(message, (BaseChatPromptTemplate, MessagesPlaceholder)):
|
||||
input_variables.update(message.input_variables)
|
||||
values["input_variables"] = sorted(input_variables)
|
||||
return values
|
||||
@@ -235,6 +255,8 @@ class ChatPromptTemplate(BasePromptTemplate):
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PromptTemplateIdentifier(ResourceIdentifier):
|
||||
"""The identifier of a prompt template."""
|
||||
|
||||
identifier_split: str = dataclasses.field(default="___$$$$___", init=False)
|
||||
prompt_name: str
|
||||
prompt_language: Optional[str] = None
|
||||
@@ -242,6 +264,7 @@ class PromptTemplateIdentifier(ResourceIdentifier):
|
||||
model: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Post init method."""
|
||||
if self.prompt_name is None:
|
||||
raise ValueError("prompt_name cannot be None")
|
||||
|
||||
@@ -256,11 +279,13 @@ class PromptTemplateIdentifier(ResourceIdentifier):
|
||||
if key is not None
|
||||
):
|
||||
raise ValueError(
|
||||
f"identifier_split {self.identifier_split} is not allowed in prompt_name, prompt_language, sys_code, model"
|
||||
f"identifier_split {self.identifier_split} is not allowed in "
|
||||
f"prompt_name, prompt_language, sys_code, model"
|
||||
)
|
||||
|
||||
@property
|
||||
def str_identifier(self) -> str:
|
||||
"""Return the string identifier of the identifier."""
|
||||
return self.identifier_split.join(
|
||||
key
|
||||
for key in [
|
||||
@@ -273,6 +298,11 @@ class PromptTemplateIdentifier(ResourceIdentifier):
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert the identifier to a dict.
|
||||
|
||||
Returns:
|
||||
Dict: The dict of the identifier.
|
||||
"""
|
||||
return {
|
||||
"prompt_name": self.prompt_name,
|
||||
"prompt_language": self.prompt_language,
|
||||
@@ -283,6 +313,8 @@ class PromptTemplateIdentifier(ResourceIdentifier):
|
||||
|
||||
@dataclasses.dataclass
|
||||
class StoragePromptTemplate(StorageItem):
|
||||
"""The storage prompt template."""
|
||||
|
||||
prompt_name: str
|
||||
content: Optional[str] = None
|
||||
prompt_language: Optional[str] = None
|
||||
@@ -297,25 +329,28 @@ class StoragePromptTemplate(StorageItem):
|
||||
_identifier: PromptTemplateIdentifier = dataclasses.field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Post init method."""
|
||||
self._identifier = PromptTemplateIdentifier(
|
||||
prompt_name=self.prompt_name,
|
||||
prompt_language=self.prompt_language,
|
||||
sys_code=self.sys_code,
|
||||
model=self.model,
|
||||
)
|
||||
self._check() # Assuming _check() is a method you need to call after initialization
|
||||
# Assuming _check() is a method you need to call after initialization
|
||||
self._check()
|
||||
|
||||
def to_prompt_template(self) -> PromptTemplate:
|
||||
"""Convert the storage prompt template to a prompt template."""
|
||||
input_variables = (
|
||||
[] if not self.input_variables else self.input_variables.strip().split(",")
|
||||
)
|
||||
template_format = self.prompt_format or "f-string"
|
||||
return PromptTemplate(
|
||||
input_variables=input_variables,
|
||||
template=self.content,
|
||||
template_scene=self.chat_scene,
|
||||
prompt_name=self.prompt_name,
|
||||
template_format=self.prompt_format,
|
||||
# prompt_name=self.prompt_name,
|
||||
template_format=template_format,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -335,12 +370,18 @@ class StoragePromptTemplate(StorageItem):
|
||||
Args:
|
||||
prompt_template (PromptTemplate): The prompt template to convert from.
|
||||
prompt_name (str): The name of the prompt.
|
||||
prompt_language (Optional[str], optional): The language of the prompt. Defaults to None. e.g. zh-cn, en.
|
||||
prompt_type (Optional[str], optional): The type of the prompt. Defaults to None. e.g. common, private.
|
||||
sys_code (Optional[str], optional): The system code of the prompt. Defaults to None.
|
||||
user_name (Optional[str], optional): The username of the prompt. Defaults to None.
|
||||
sub_chat_scene (Optional[str], optional): The sub chat scene of the prompt. Defaults to None.
|
||||
model (Optional[str], optional): The model name of the prompt. Defaults to None.
|
||||
prompt_language (Optional[str], optional): The language of the prompt.
|
||||
Defaults to None. e.g. zh-cn, en.
|
||||
prompt_type (Optional[str], optional): The type of the prompt.
|
||||
Defaults to None. e.g. common, private.
|
||||
sys_code (Optional[str], optional): The system code of the prompt.
|
||||
Defaults to None.
|
||||
user_name (Optional[str], optional): The username of the prompt.
|
||||
Defaults to None.
|
||||
sub_chat_scene (Optional[str], optional): The sub chat scene of the prompt.
|
||||
Defaults to None.
|
||||
model (Optional[str], optional): The model name of the prompt.
|
||||
Defaults to None.
|
||||
kwargs (Dict): Other params to build the storage prompt template.
|
||||
"""
|
||||
input_variables = prompt_template.input_variables or kwargs.get(
|
||||
@@ -365,6 +406,7 @@ class StoragePromptTemplate(StorageItem):
|
||||
|
||||
@property
|
||||
def identifier(self) -> PromptTemplateIdentifier:
|
||||
"""Return the identifier of the storage prompt template."""
|
||||
return self._identifier
|
||||
|
||||
def merge(self, other: "StorageItem") -> None:
|
||||
@@ -375,11 +417,17 @@ class StoragePromptTemplate(StorageItem):
|
||||
"""
|
||||
if not isinstance(other, StoragePromptTemplate):
|
||||
raise ValueError(
|
||||
f"Cannot merge {type(other)} into {type(self)} because they are not the same type."
|
||||
f"Cannot merge {type(other)} into {type(self)} because they are not "
|
||||
f"the same type."
|
||||
)
|
||||
self.from_object(other)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert the storage prompt template to a dict.
|
||||
|
||||
Returns:
|
||||
Dict: The dict of the storage prompt template.
|
||||
"""
|
||||
return {
|
||||
"prompt_name": self.prompt_name,
|
||||
"content": self.content,
|
||||
@@ -422,7 +470,6 @@ class PromptManager:
|
||||
Simple wrapper for the storage interface.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Default use InMemoryStorage
|
||||
@@ -458,13 +505,14 @@ class PromptManager:
|
||||
def __init__(
|
||||
self, storage: Optional[StorageInterface[StoragePromptTemplate, Any]] = None
|
||||
):
|
||||
"""Create a new prompt manager."""
|
||||
if storage is None:
|
||||
storage = InMemoryStorage()
|
||||
self._storage = storage
|
||||
|
||||
@property
|
||||
def storage(self) -> StorageInterface[StoragePromptTemplate, Any]:
|
||||
"""The storage interface for prompt templates."""
|
||||
"""Return the storage interface for prompt templates."""
|
||||
return self._storage
|
||||
|
||||
def prefer_query(
|
||||
@@ -477,11 +525,12 @@ class PromptManager:
|
||||
) -> List[StoragePromptTemplate]:
|
||||
"""Query prompt templates from storage with prefer params.
|
||||
|
||||
Sometimes, we want to query prompt templates with prefer params(e.g. some language or some model).
|
||||
This method will query prompt templates with prefer params first, if not found, will query all prompt templates.
|
||||
Sometimes, we want to query prompt templates with prefer params(e.g. some
|
||||
language or some model).
|
||||
This method will query prompt templates with prefer params first, if not found,
|
||||
will query all prompt templates.
|
||||
|
||||
Examples:
|
||||
|
||||
Query a prompt template.
|
||||
.. code-block:: python
|
||||
|
||||
@@ -500,7 +549,8 @@ class PromptManager:
|
||||
.. code-block:: python
|
||||
|
||||
# First query with prompt name "hello" exactly.
|
||||
# Second filter with prompt language "zh-cn", if not found, will return all prompt templates.
|
||||
# Second filter with prompt language "zh-cn", if not found, will return
|
||||
# all prompt templates.
|
||||
prompt_template_list = prompt_manager.prefer_query(
|
||||
"hello", prefer_prompt_language="zh-cn"
|
||||
)
|
||||
@@ -510,17 +560,22 @@ class PromptManager:
|
||||
.. code-block:: python
|
||||
|
||||
# First query with prompt name "hello" exactly.
|
||||
# Second filter with model "vicuna-13b-v1.5", if not found, will return all prompt templates.
|
||||
# Second filter with model "vicuna-13b-v1.5", if not found, will return
|
||||
# all prompt templates.
|
||||
prompt_template_list = prompt_manager.prefer_query(
|
||||
"hello", prefer_model="vicuna-13b-v1.5"
|
||||
)
|
||||
|
||||
Args:
|
||||
prompt_name (str): The name of the prompt template.
|
||||
sys_code (Optional[str], optional): The system code of the prompt template. Defaults to None.
|
||||
prefer_prompt_language (Optional[str], optional): The language of the prompt template. Defaults to None.
|
||||
prefer_model (Optional[str], optional): The model of the prompt template. Defaults to None.
|
||||
kwargs (Dict): Other query params(If some key and value not None, wo we query it exactly).
|
||||
sys_code (Optional[str], optional): The system code of the prompt template.
|
||||
Defaults to None.
|
||||
prefer_prompt_language (Optional[str], optional): The language of the
|
||||
prompt template. Defaults to None.
|
||||
prefer_model (Optional[str], optional): The model of the prompt template.
|
||||
Defaults to None.
|
||||
kwargs (Dict): Other query params(If some key and value not None, wo we
|
||||
query it exactly).
|
||||
"""
|
||||
query_spec = QuerySpec(
|
||||
conditions={
|
||||
@@ -559,7 +614,6 @@ class PromptManager:
|
||||
"""Save a prompt template to storage.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
prompt_template = PromptTemplate(
|
||||
@@ -618,15 +672,17 @@ class PromptManager:
|
||||
if exist_prompt_template:
|
||||
return exist_prompt_template
|
||||
self.save(prompt_template, prompt_name, **kwargs)
|
||||
return self.storage.load(
|
||||
prompt = self.storage.load(
|
||||
storage_prompt_template.identifier, StoragePromptTemplate
|
||||
)
|
||||
if not prompt:
|
||||
raise ValueError("Can't read prompt from storage")
|
||||
return prompt
|
||||
|
||||
def list(self, **kwargs) -> List[StoragePromptTemplate]:
|
||||
"""List prompt templates from storage.
|
||||
|
||||
Examples:
|
||||
|
||||
List all prompt templates.
|
||||
.. code-block:: python
|
||||
|
||||
@@ -656,7 +712,6 @@ class PromptManager:
|
||||
"""Delete a prompt template from storage.
|
||||
|
||||
Examples:
|
||||
|
||||
Delete a prompt template.
|
||||
|
||||
.. code-block:: python
|
||||
@@ -673,9 +728,12 @@ class PromptManager:
|
||||
|
||||
Args:
|
||||
prompt_name (str): The name of the prompt template.
|
||||
prompt_language (Optional[str], optional): The language of the prompt template. Defaults to None.
|
||||
sys_code (Optional[str], optional): The system code of the prompt template. Defaults to None.
|
||||
model (Optional[str], optional): The model of the prompt template. Defaults to None.
|
||||
prompt_language (Optional[str], optional): The language of the prompt
|
||||
template. Defaults to None.
|
||||
sys_code (Optional[str], optional): The system code of the prompt template.
|
||||
Defaults to None.
|
||||
model (Optional[str], optional): The model of the prompt template.
|
||||
Defaults to None.
|
||||
"""
|
||||
identifier = PromptTemplateIdentifier(
|
||||
prompt_name=prompt_name,
|
||||
|
Reference in New Issue
Block a user