From 372dc7f991eb34cd38dbbb5475821c73731503f6 Mon Sep 17 00:00:00 2001 From: Thommy257 <37929273+Thommy257@users.noreply.github.com> Date: Fri, 28 Mar 2025 21:41:57 +0100 Subject: [PATCH] core[patch]: fix loss of partially initialized variables during prompt composition (#30096) **Description:** This PR addresses the loss of partially initialised variables when composing different prompts. I.e. it allows the following snippet to run: ```python from langchain_core.prompts import ChatPromptTemplate prompt = ChatPromptTemplate.from_messages([('system', 'Prompt {x} {y}')]).partial(x='1') appendix = ChatPromptTemplate.from_messages([('system', 'Appendix {z}')]) (prompt + appendix).invoke({'y': '2', 'z': '3'}) ``` Previously, this would have raised a `KeyError`, stating that variable `x` remains undefined. **Issue** References issue #30049 **Todo** - [x] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ --------- Co-authored-by: Eugene Yurtsev --- libs/core/langchain_core/prompts/chat.py | 23 +++++++++++++++---- .../tests/unit_tests/prompts/test_chat.py | 17 ++++++++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 23daa02716b..a7ef3d36c2e 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -1040,19 +1040,34 @@ class ChatPromptTemplate(BaseChatPromptTemplate): Returns: Combined prompt template. """ + partials = {**self.partial_variables} + + # Need to check that other has partial variables since it may not be + # a ChatPromptTemplate. + if hasattr(other, "partial_variables") and other.partial_variables: + partials.update(other.partial_variables) + # Allow for easy combining if isinstance(other, ChatPromptTemplate): - return ChatPromptTemplate(messages=self.messages + other.messages) # type: ignore[call-arg] + return ChatPromptTemplate(messages=self.messages + other.messages).partial( + **partials + ) # type: ignore[call-arg] elif isinstance( other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate) ): - return ChatPromptTemplate(messages=self.messages + [other]) # type: ignore[call-arg] + return ChatPromptTemplate(messages=self.messages + [other]).partial( + **partials + ) # type: ignore[call-arg] elif isinstance(other, (list, tuple)): _other = ChatPromptTemplate.from_messages(other) - return ChatPromptTemplate(messages=self.messages + _other.messages) # type: ignore[call-arg] + return ChatPromptTemplate(messages=self.messages + _other.messages).partial( + **partials + ) # type: ignore[call-arg] elif isinstance(other, str): prompt = HumanMessagePromptTemplate.from_template(other) - return ChatPromptTemplate(messages=self.messages + [prompt]) # type: ignore[call-arg] + return ChatPromptTemplate(messages=self.messages + [prompt]).partial( + **partials + ) # type: ignore[call-arg] else: msg = f"Unsupported operand type for +: {type(other)}" raise NotImplementedError(msg) diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 020c959778c..3331014a271 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -582,6 +582,23 @@ def test_chat_message_partial() -> None: assert template2.format(input="hello") == get_buffer_string(expected) +def test_chat_message_partial_composition() -> None: + """Test composition of partially initialized messages.""" + prompt = ChatPromptTemplate.from_messages([("system", "Prompt {x} {y}")]).partial( + x="1" + ) + + appendix = ChatPromptTemplate.from_messages([("system", "Appendix {z}")]) + + res = (prompt + appendix).format_messages(y="2", z="3") + expected = [ + SystemMessage(content="Prompt 1 2"), + SystemMessage(content="Appendix 3"), + ] + + assert res == expected + + async def test_chat_tmpl_from_messages_multipart_text() -> None: template = ChatPromptTemplate.from_messages( [