prompt ergonomics (#7799)

This commit is contained in:
Harrison Chase
2023-07-22 14:19:17 -07:00
committed by GitHub
parent d81d6e874f
commit cbf2fc8af8
4 changed files with 420 additions and 1 deletions

View File

@@ -56,6 +56,10 @@ class BaseMessagePromptTemplate(Serializable, ABC):
List of input variables.
"""
def __add__(self, other: Any) -> ChatPromptTemplate:
prompt = ChatPromptTemplate(messages=[self])
return prompt + other
class MessagesPlaceholder(BaseMessagePromptTemplate):
"""Prompt template that assumes variable is already list of messages."""
@@ -261,6 +265,18 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
"""List of messages."""
def __add__(self, other: Any) -> ChatPromptTemplate:
# Allow for easy combining
if isinstance(other, ChatPromptTemplate):
return ChatPromptTemplate(messages=self.messages + other.messages)
elif isinstance(other, (BaseMessagePromptTemplate, BaseMessage)):
return ChatPromptTemplate(messages=self.messages + [other])
elif isinstance(other, str):
prompt = HumanMessagePromptTemplate.from_template(other)
return ChatPromptTemplate(messages=self.messages + [prompt])
else:
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
@root_validator(pre=True)
def validate_input_variables(cls, values: dict) -> dict:
"""

View File

@@ -43,6 +43,42 @@ class PromptTemplate(StringPromptTemplate):
validate_template: bool = True
"""Whether or not to try validating the template."""
def __add__(self, other: Any) -> PromptTemplate:
# Allow for easy combining
if isinstance(other, PromptTemplate):
if self.template_format != "f-string":
raise ValueError(
"Adding prompt templates only supported for f-strings."
)
if other.template_format != "f-string":
raise ValueError(
"Adding prompt templates only supported for f-strings."
)
input_variables = list(
set(self.input_variables) | set(other.input_variables)
)
template = self.template + other.template
# If any do not want to validate, then don't
validate_template = self.validate_template and other.validate_template
partial_variables = {k: v for k, v in self.partial_variables.items()}
for k, v in other.partial_variables.items():
if k in partial_variables:
raise ValueError("Cannot have same variable partialed twice.")
else:
partial_variables[k] = v
return PromptTemplate(
template=template,
input_variables=input_variables,
partial_variables=partial_variables,
template_format="f-string",
validate_template=validate_template,
)
elif isinstance(other, str):
prompt = PromptTemplate.from_template(other)
return self + prompt
else:
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""

View File

@@ -1,12 +1,15 @@
from __future__ import annotations
from abc import abstractmethod
from typing import List, Sequence
from typing import TYPE_CHECKING, Any, List, Sequence
from pydantic import Field
from langchain.load.serializable import Serializable
if TYPE_CHECKING:
from langchain.prompts.chat import ChatPromptTemplate
def get_buffer_string(
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
@@ -77,6 +80,12 @@ class BaseMessage(Serializable):
"""Whether this class is LangChain serializable."""
return True
def __add__(self, other: Any) -> ChatPromptTemplate:
from langchain.prompts.chat import ChatPromptTemplate
prompt = ChatPromptTemplate(messages=[self])
return prompt + other
class HumanMessage(BaseMessage):
"""A Message from a human."""