allow passing of messages into prompt template (#1505)

This commit is contained in:
Harrison Chase 2023-03-07 21:10:12 -08:00 committed by GitHub
parent a4a2d79087
commit 7ade419a0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 8 deletions

View File

@ -120,7 +120,7 @@ class ChatPromptValue(PromptValue):
class ChatPromptTemplate(BasePromptTemplate, ABC): class ChatPromptTemplate(BasePromptTemplate, ABC):
input_variables: List[str] input_variables: List[str]
messages: List[BaseMessagePromptTemplate] messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
@classmethod @classmethod
def from_role_strings( def from_role_strings(
@ -146,11 +146,12 @@ class ChatPromptTemplate(BasePromptTemplate, ABC):
@classmethod @classmethod
def from_messages( def from_messages(
cls, messages: Sequence[BaseMessagePromptTemplate] cls, messages: Sequence[Union[BaseMessagePromptTemplate, BaseMessage]]
) -> ChatPromptTemplate: ) -> ChatPromptTemplate:
input_vars = set() input_vars = set()
for message in messages: for message in messages:
input_vars.update(message.input_variables) if isinstance(message, BaseMessagePromptTemplate):
input_vars.update(message.input_variables)
return cls(input_variables=list(input_vars), messages=messages) return cls(input_variables=list(input_vars), messages=messages)
def format(self, **kwargs: Any) -> str: def format(self, **kwargs: Any) -> str:
@ -159,11 +160,18 @@ class ChatPromptTemplate(BasePromptTemplate, ABC):
def format_prompt(self, **kwargs: Any) -> PromptValue: def format_prompt(self, **kwargs: Any) -> PromptValue:
result = [] result = []
for message_template in self.messages: for message_template in self.messages:
rel_params = { if isinstance(message_template, BaseMessage):
k: v for k, v in kwargs.items() if k in message_template.input_variables result.extend([message_template])
} elif isinstance(message_template, BaseMessagePromptTemplate):
message = message_template.format_messages(**rel_params) rel_params = {
result.extend(message) k: v
for k, v in kwargs.items()
if k in message_template.input_variables
}
message = message_template.format_messages(**rel_params)
result.extend(message)
else:
raise ValueError(f"Unexpected input: {message_template}")
return ChatPromptValue(messages=result) return ChatPromptValue(messages=result)
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate: def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:

View File

@ -10,6 +10,7 @@ from langchain.prompts.chat import (
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
SystemMessagePromptTemplate, SystemMessagePromptTemplate,
) )
from langchain.schema import HumanMessage
def create_messages() -> List[BaseMessagePromptTemplate]: def create_messages() -> List[BaseMessagePromptTemplate]:
@ -89,3 +90,17 @@ def test_chat_prompt_template_from_messages() -> None:
["context", "foo", "bar"] ["context", "foo", "bar"]
) )
assert len(chat_prompt_template.messages) == 4 assert len(chat_prompt_template.messages) == 4
def test_chat_prompt_template_with_messages() -> None:
messages = create_messages() + [HumanMessage(content="foo")]
chat_prompt_template = ChatPromptTemplate.from_messages(messages)
assert sorted(chat_prompt_template.input_variables) == sorted(
["context", "foo", "bar"]
)
assert len(chat_prompt_template.messages) == 5
prompt_value = chat_prompt_template.format_prompt(
context="see", foo="this", bar="magic"
)
prompt_value_messages = prompt_value.to_messages()
assert prompt_value_messages[-1] == HumanMessage(content="foo")