mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-03 12:07:36 +00:00
prompt ergonomics (#7799)
This commit is contained in:
@@ -56,6 +56,10 @@ class BaseMessagePromptTemplate(Serializable, ABC):
|
||||
List of input variables.
|
||||
"""
|
||||
|
||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||
prompt = ChatPromptTemplate(messages=[self])
|
||||
return prompt + other
|
||||
|
||||
|
||||
class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
"""Prompt template that assumes variable is already list of messages."""
|
||||
@@ -261,6 +265,18 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
|
||||
"""List of messages."""
|
||||
|
||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||
# Allow for easy combining
|
||||
if isinstance(other, ChatPromptTemplate):
|
||||
return ChatPromptTemplate(messages=self.messages + other.messages)
|
||||
elif isinstance(other, (BaseMessagePromptTemplate, BaseMessage)):
|
||||
return ChatPromptTemplate(messages=self.messages + [other])
|
||||
elif isinstance(other, str):
|
||||
prompt = HumanMessagePromptTemplate.from_template(other)
|
||||
return ChatPromptTemplate(messages=self.messages + [prompt])
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_input_variables(cls, values: dict) -> dict:
|
||||
"""
|
||||
|
@@ -43,6 +43,42 @@ class PromptTemplate(StringPromptTemplate):
|
||||
validate_template: bool = True
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
def __add__(self, other: Any) -> PromptTemplate:
|
||||
# Allow for easy combining
|
||||
if isinstance(other, PromptTemplate):
|
||||
if self.template_format != "f-string":
|
||||
raise ValueError(
|
||||
"Adding prompt templates only supported for f-strings."
|
||||
)
|
||||
if other.template_format != "f-string":
|
||||
raise ValueError(
|
||||
"Adding prompt templates only supported for f-strings."
|
||||
)
|
||||
input_variables = list(
|
||||
set(self.input_variables) | set(other.input_variables)
|
||||
)
|
||||
template = self.template + other.template
|
||||
# If any do not want to validate, then don't
|
||||
validate_template = self.validate_template and other.validate_template
|
||||
partial_variables = {k: v for k, v in self.partial_variables.items()}
|
||||
for k, v in other.partial_variables.items():
|
||||
if k in partial_variables:
|
||||
raise ValueError("Cannot have same variable partialed twice.")
|
||||
else:
|
||||
partial_variables[k] = v
|
||||
return PromptTemplate(
|
||||
template=template,
|
||||
input_variables=input_variables,
|
||||
partial_variables=partial_variables,
|
||||
template_format="f-string",
|
||||
validate_template=validate_template,
|
||||
)
|
||||
elif isinstance(other, str):
|
||||
prompt = PromptTemplate.from_template(other)
|
||||
return self + prompt
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
"""Return the prompt type key."""
|
||||
|
@@ -1,12 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import List, Sequence
|
||||
from typing import TYPE_CHECKING, Any, List, Sequence
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
|
||||
def get_buffer_string(
|
||||
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
|
||||
@@ -77,6 +80,12 @@ class BaseMessage(Serializable):
|
||||
"""Whether this class is LangChain serializable."""
|
||||
return True
|
||||
|
||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
prompt = ChatPromptTemplate(messages=[self])
|
||||
return prompt + other
|
||||
|
||||
|
||||
class HumanMessage(BaseMessage):
|
||||
"""A Message from a human."""
|
||||
|
Reference in New Issue
Block a user