From d5533b70811f8330cc3504cc9edd58ea66218984 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 21 Dec 2023 12:36:37 -0800 Subject: [PATCH] Add option to make messages placeholder optional (#15031) --- libs/core/langchain_core/prompts/chat.py | 14 +++++++++++--- libs/core/tests/unit_tests/prompts/test_chat.py | 9 +++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 2b24bf3d334..e2978d383bf 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -87,13 +87,17 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): variable_name: str """Name of variable to use as messages.""" + optional: bool = False + @classmethod def get_lc_namespace(cls) -> List[str]: """Get the namespace of the langchain object.""" return ["langchain", "prompts", "chat"] - def __init__(self, variable_name: str, **kwargs: Any): - return super().__init__(variable_name=variable_name, **kwargs) + def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any): + return super().__init__( + variable_name=variable_name, optional=optional, **kwargs + ) def format_messages(self, **kwargs: Any) -> List[BaseMessage]: """Format messages from kwargs. @@ -104,7 +108,11 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): Returns: List of BaseMessage. """ - value = kwargs[self.variable_name] + value = ( + kwargs.get(self.variable_name, []) + if self.optional + else kwargs[self.variable_name] + ) if not isinstance(value, list): raise ValueError( f"variable {self.variable_name} should be a list of base messages, " diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 3719573a6c6..2765d030d56 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -19,6 +19,7 @@ from langchain_core.prompts.chat import ( ChatMessagePromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate, + MessagesPlaceholder, SystemMessagePromptTemplate, _convert_to_message, ) @@ -360,3 +361,11 @@ def test_chat_message_partial() -> None: ] assert res == expected assert template2.format(input="hello") == get_buffer_string(expected) + + +def test_messages_placeholder() -> None: + prompt = MessagesPlaceholder("history") + with pytest.raises(KeyError): + prompt.format_messages() + prompt = MessagesPlaceholder("history", optional=True) + assert prompt.format_messages() == []