From 8a140ee77cf72ef0f45e982693e14355d885ad8a Mon Sep 17 00:00:00 2001
From: Bagatur <22008038+baskaryan@users.noreply.github.com>
Date: Mon, 22 Jul 2024 13:30:16 -0700
Subject: [PATCH] core[patch]: don't serialize BasePromptTemplate.input_types
 (#24516)

Candidate fix for #24513
---
 libs/core/langchain_core/prompts/base.py  |   2 +-
 .../prompts/__snapshots__/test_chat.ambr  | 217 ++++++++++++++++++
 .../tests/unit_tests/prompts/test_chat.py |  15 +-
 3 files changed, 230 insertions(+), 4 deletions(-)

diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py
index 11ccffe17ad..6b0db3598f0 100644
--- a/libs/core/langchain_core/prompts/base.py
+++ b/libs/core/langchain_core/prompts/base.py
@@ -47,7 +47,7 @@ class BasePromptTemplate(
     prompt."""
     optional_variables: List[str] = Field(default=[])
     """A list of the names of the variables that are optional in the prompt."""
-    input_types: Dict[str, Any] = Field(default_factory=dict)
+    input_types: Dict[str, Any] = Field(default_factory=dict, exclude=True)
     """A dictionary of the types of the variables the prompt template expects.
     If not provided, all variables are assumed to be strings."""
     output_parser: Optional[BaseOutputParser] = None
diff --git a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr
index 3f59bb44292..4a7ae82b270 100644
--- a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr
+++ b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr
@@ -1216,3 +1216,220 @@
     'type': 'object',
   })
 # ---
+# name: test_chat_prompt_w_msgs_placeholder_ser_des[chat_prompt]
+  dict({
+    'graph': dict({
+      'edges': list([
+        dict({
+          'source': 0,
+          'target': 1,
+        }),
+        dict({
+          'source': 1,
+          'target': 2,
+        }),
+      ]),
+      'nodes': list([
+        dict({
+          'data': 'PromptInput',
+          'id': 0,
+          'type': 'schema',
+        }),
+        dict({
+          'data': dict({
+            'id': list([
+              'langchain',
+              'prompts',
+              'chat',
+              'ChatPromptTemplate',
+            ]),
+            'name': 'ChatPromptTemplate',
+          }),
+          'id': 1,
+          'type': 'runnable',
+        }),
+        dict({
+          'data': 'ChatPromptTemplateOutput',
+          'id': 2,
+          'type': 'schema',
+        }),
+      ]),
+    }),
+    'id': list([
+      'langchain',
+      'prompts',
+      'chat',
+      'ChatPromptTemplate',
+    ]),
+    'kwargs': dict({
+      'input_variables': list([
+        'bar',
+      ]),
+      'messages': list([
+        dict({
+          'id': list([
+            'langchain',
+            'prompts',
+            'chat',
+            'SystemMessagePromptTemplate',
+          ]),
+          'kwargs': dict({
+            'prompt': dict({
+              'graph': dict({
+                'edges': list([
+                  dict({
+                    'source': 0,
+                    'target': 1,
+                  }),
+                  dict({
+                    'source': 1,
+                    'target': 2,
+                  }),
+                ]),
+                'nodes': list([
+                  dict({
+                    'data': 'PromptInput',
+                    'id': 0,
+                    'type': 'schema',
+                  }),
+                  dict({
+                    'data': dict({
+                      'id': list([
+                        'langchain',
+                        'prompts',
+                        'prompt',
+                        'PromptTemplate',
+                      ]),
+                      'name': 'PromptTemplate',
+                    }),
+                    'id': 1,
+                    'type': 'runnable',
+                  }),
+                  dict({
+                    'data': 'PromptTemplateOutput',
+                    'id': 2,
+                    'type': 'schema',
+                  }),
+                ]),
+              }),
+              'id': list([
+                'langchain',
+                'prompts',
+                'prompt',
+                'PromptTemplate',
+              ]),
+              'kwargs': dict({
+                'input_variables': list([
+                ]),
+                'template': 'foo',
+                'template_format': 'f-string',
+              }),
+              'lc': 1,
+              'name': 'PromptTemplate',
+              'type': 'constructor',
+            }),
+          }),
+          'lc': 1,
+          'type': 'constructor',
+        }),
+        dict({
+          'id': list([
+            'langchain',
+            'prompts',
+            'chat',
+            'MessagesPlaceholder',
+          ]),
+          'kwargs': dict({
+            'variable_name': 'bar',
+          }),
+          'lc': 1,
+          'type': 'constructor',
+        }),
+        dict({
+          'id': list([
+            'langchain',
+            'prompts',
+            'chat',
+            'HumanMessagePromptTemplate',
+          ]),
+          'kwargs': dict({
+            'prompt': dict({
+              'graph': dict({
+                'edges': list([
+                  dict({
+                    'source': 0,
+                    'target': 1,
+                  }),
+                  dict({
+                    'source': 1,
+                    'target': 2,
+                  }),
+                ]),
+                'nodes': list([
+                  dict({
+                    'data': 'PromptInput',
+                    'id': 0,
+                    'type': 'schema',
+                  }),
+                  dict({
+                    'data': dict({
+                      'id': list([
+                        'langchain',
+                        'prompts',
+                        'prompt',
+                        'PromptTemplate',
+                      ]),
+                      'name': 'PromptTemplate',
+                    }),
+                    'id': 1,
+                    'type': 'runnable',
+                  }),
+                  dict({
+                    'data': 'PromptTemplateOutput',
+                    'id': 2,
+                    'type': 'schema',
+                  }),
+                ]),
+              }),
+              'id': list([
+                'langchain',
+                'prompts',
+                'prompt',
+                'PromptTemplate',
+              ]),
+              'kwargs': dict({
+                'input_variables': list([
+                ]),
+                'template': 'baz',
+                'template_format': 'f-string',
+              }),
+              'lc': 1,
+              'name': 'PromptTemplate',
+              'type': 'constructor',
+            }),
+          }),
+          'lc': 1,
+          'type': 'constructor',
+        }),
+      ]),
+    }),
+    'lc': 1,
+    'name': 'ChatPromptTemplate',
+    'type': 'constructor',
+  })
+# ---
+# name: test_chat_prompt_w_msgs_placeholder_ser_des[placholder]
+  dict({
+    'id': list([
+      'langchain',
+      'prompts',
+      'chat',
+      'MessagesPlaceholder',
+    ]),
+    'kwargs': dict({
+      'variable_name': 'bar',
+    }),
+    'lc': 1,
+    'type': 'constructor',
+  })
+# ---
diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py
index 9c01d2a1e80..4a8fd4ede3d 100644
--- a/libs/core/tests/unit_tests/prompts/test_chat.py
+++ b/libs/core/tests/unit_tests/prompts/test_chat.py
@@ -6,9 +6,8 @@ from typing import Any, List, Union
 import pytest
 from syrupy import SnapshotAssertion
 
-from langchain_core._api.deprecation import (
-    LangChainPendingDeprecationWarning,
-)
+from langchain_core._api.deprecation import LangChainPendingDeprecationWarning
+from langchain_core.load import dumpd, load
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
@@ -806,3 +805,13 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None:
     assert set(prompt_optional.input_variables) == {"input"}
     prompt_optional.input_schema(input="")  # won't raise error
     assert prompt_optional.input_schema.schema() == snapshot(name="partial")
+
+
+def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) -> None:
+    prompt = ChatPromptTemplate.from_messages(
+        [("system", "foo"), MessagesPlaceholder("bar"), ("human", "baz")]
+    )
+    assert dumpd(MessagesPlaceholder("bar")) == snapshot(name="placholder")
+    assert load(dumpd(MessagesPlaceholder("bar"))) == MessagesPlaceholder("bar")
+    assert dumpd(prompt) == snapshot(name="chat_prompt")
+    assert load(dumpd(prompt)) == prompt
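
For context, a minimal usage sketch (not part of the patch): it mirrors the round-trip that the new test exercises and assumes a langchain-core build that includes this change.

from langchain_core.load import dumpd, load
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

# Build a chat prompt that embeds a MessagesPlaceholder, as in the new test.
prompt = ChatPromptTemplate.from_messages(
    [("system", "foo"), MessagesPlaceholder("bar"), ("human", "baz")]
)

serialized = dumpd(prompt)
# With exclude=True on the field, input_types no longer appears in the dump,
# which is what the updated snapshot captures (kwargs holds only
# input_variables and messages).
assert "input_types" not in serialized["kwargs"]
# The dump now loads back into an equivalent prompt template.
assert load(serialized) == prompt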