mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-06 07:04:01 +00:00
allow passing of messages into prompt template (#1505)
This commit is contained in:
parent
a4a2d79087
commit
7ade419a0e
@ -120,7 +120,7 @@ class ChatPromptValue(PromptValue):
|
|||||||
|
|
||||||
class ChatPromptTemplate(BasePromptTemplate, ABC):
|
class ChatPromptTemplate(BasePromptTemplate, ABC):
|
||||||
input_variables: List[str]
|
input_variables: List[str]
|
||||||
messages: List[BaseMessagePromptTemplate]
|
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_role_strings(
|
def from_role_strings(
|
||||||
@ -146,11 +146,12 @@ class ChatPromptTemplate(BasePromptTemplate, ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_messages(
|
def from_messages(
|
||||||
cls, messages: Sequence[BaseMessagePromptTemplate]
|
cls, messages: Sequence[Union[BaseMessagePromptTemplate, BaseMessage]]
|
||||||
) -> ChatPromptTemplate:
|
) -> ChatPromptTemplate:
|
||||||
input_vars = set()
|
input_vars = set()
|
||||||
for message in messages:
|
for message in messages:
|
||||||
input_vars.update(message.input_variables)
|
if isinstance(message, BaseMessagePromptTemplate):
|
||||||
|
input_vars.update(message.input_variables)
|
||||||
return cls(input_variables=list(input_vars), messages=messages)
|
return cls(input_variables=list(input_vars), messages=messages)
|
||||||
|
|
||||||
def format(self, **kwargs: Any) -> str:
|
def format(self, **kwargs: Any) -> str:
|
||||||
@ -159,11 +160,18 @@ class ChatPromptTemplate(BasePromptTemplate, ABC):
|
|||||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||||
result = []
|
result = []
|
||||||
for message_template in self.messages:
|
for message_template in self.messages:
|
||||||
rel_params = {
|
if isinstance(message_template, BaseMessage):
|
||||||
k: v for k, v in kwargs.items() if k in message_template.input_variables
|
result.extend([message_template])
|
||||||
}
|
elif isinstance(message_template, BaseMessagePromptTemplate):
|
||||||
message = message_template.format_messages(**rel_params)
|
rel_params = {
|
||||||
result.extend(message)
|
k: v
|
||||||
|
for k, v in kwargs.items()
|
||||||
|
if k in message_template.input_variables
|
||||||
|
}
|
||||||
|
message = message_template.format_messages(**rel_params)
|
||||||
|
result.extend(message)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected input: {message_template}")
|
||||||
return ChatPromptValue(messages=result)
|
return ChatPromptValue(messages=result)
|
||||||
|
|
||||||
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
|
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
|
||||||
|
@ -10,6 +10,7 @@ from langchain.prompts.chat import (
|
|||||||
HumanMessagePromptTemplate,
|
HumanMessagePromptTemplate,
|
||||||
SystemMessagePromptTemplate,
|
SystemMessagePromptTemplate,
|
||||||
)
|
)
|
||||||
|
from langchain.schema import HumanMessage
|
||||||
|
|
||||||
|
|
||||||
def create_messages() -> List[BaseMessagePromptTemplate]:
|
def create_messages() -> List[BaseMessagePromptTemplate]:
|
||||||
@ -89,3 +90,17 @@ def test_chat_prompt_template_from_messages() -> None:
|
|||||||
["context", "foo", "bar"]
|
["context", "foo", "bar"]
|
||||||
)
|
)
|
||||||
assert len(chat_prompt_template.messages) == 4
|
assert len(chat_prompt_template.messages) == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_prompt_template_with_messages() -> None:
|
||||||
|
messages = create_messages() + [HumanMessage(content="foo")]
|
||||||
|
chat_prompt_template = ChatPromptTemplate.from_messages(messages)
|
||||||
|
assert sorted(chat_prompt_template.input_variables) == sorted(
|
||||||
|
["context", "foo", "bar"]
|
||||||
|
)
|
||||||
|
assert len(chat_prompt_template.messages) == 5
|
||||||
|
prompt_value = chat_prompt_template.format_prompt(
|
||||||
|
context="see", foo="this", bar="magic"
|
||||||
|
)
|
||||||
|
prompt_value_messages = prompt_value.to_messages()
|
||||||
|
assert prompt_value_messages[-1] == HumanMessage(content="foo")
|
||||||
|
Loading…
Reference in New Issue
Block a user