core[patch]: don't serialize BasePromptTemplate.input_types (#24516)

Candidate fix for #24513
This commit is contained in:
Bagatur 2024-07-22 13:30:16 -07:00 committed by GitHub
parent df357f82ca
commit 8a140ee77c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 230 additions and 4 deletions

View File

@@ -47,7 +47,7 @@ class BasePromptTemplate(
     prompt."""
     optional_variables: List[str] = Field(default=[])
     """A list of the names of the variables that are optional in the prompt."""
-    input_types: Dict[str, Any] = Field(default_factory=dict)
+    input_types: Dict[str, Any] = Field(default_factory=dict, exclude=True)
     """A dictionary of the types of the variables the prompt template expects.
     If not provided, all variables are assumed to be strings."""
     output_parser: Optional[BaseOutputParser] = None

View File

@ -1216,3 +1216,220 @@
'type': 'object', 'type': 'object',
}) })
# --- # ---
# name: test_chat_prompt_w_msgs_placeholder_ser_des[chat_prompt]
dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'chat',
'ChatPromptTemplate',
]),
'name': 'ChatPromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'ChatPromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'chat',
'ChatPromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
'bar',
]),
'messages': list([
dict({
'id': list([
'langchain',
'prompts',
'chat',
'SystemMessagePromptTemplate',
]),
'kwargs': dict({
'prompt': dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'name': 'PromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'PromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
]),
'template': 'foo',
'template_format': 'f-string',
}),
'lc': 1,
'name': 'PromptTemplate',
'type': 'constructor',
}),
}),
'lc': 1,
'type': 'constructor',
}),
dict({
'id': list([
'langchain',
'prompts',
'chat',
'MessagesPlaceholder',
]),
'kwargs': dict({
'variable_name': 'bar',
}),
'lc': 1,
'type': 'constructor',
}),
dict({
'id': list([
'langchain',
'prompts',
'chat',
'HumanMessagePromptTemplate',
]),
'kwargs': dict({
'prompt': dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'name': 'PromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'PromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
]),
'template': 'baz',
'template_format': 'f-string',
}),
'lc': 1,
'name': 'PromptTemplate',
'type': 'constructor',
}),
}),
'lc': 1,
'type': 'constructor',
}),
]),
}),
'lc': 1,
'name': 'ChatPromptTemplate',
'type': 'constructor',
})
# ---
# name: test_chat_prompt_w_msgs_placeholder_ser_des[placholder]
dict({
'id': list([
'langchain',
'prompts',
'chat',
'MessagesPlaceholder',
]),
'kwargs': dict({
'variable_name': 'bar',
}),
'lc': 1,
'type': 'constructor',
})
# ---

View File

@@ -6,9 +6,8 @@ from typing import Any, List, Union
 import pytest
 from syrupy import SnapshotAssertion
-from langchain_core._api.deprecation import (
-    LangChainPendingDeprecationWarning,
-)
+from langchain_core._api.deprecation import LangChainPendingDeprecationWarning
+from langchain_core.load import dumpd, load
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
@@ -806,3 +805,13 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None:
     assert set(prompt_optional.input_variables) == {"input"}
     prompt_optional.input_schema(input="")  # won't raise error
     assert prompt_optional.input_schema.schema() == snapshot(name="partial")
def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) -> None:
    """Check that a MessagesPlaceholder, alone and inside a ChatPromptTemplate,
    serializes to the expected snapshot and survives a dumpd/load round trip.
    """
    # Bare placeholder: snapshot its serialized form, then round-trip it.
    placeholder = MessagesPlaceholder("bar")
    assert dumpd(placeholder) == snapshot(name="placholder")
    assert load(dumpd(placeholder)) == MessagesPlaceholder("bar")

    # Chat prompt containing the placeholder between two message templates.
    prompt = ChatPromptTemplate.from_messages(
        [("system", "foo"), MessagesPlaceholder("bar"), ("human", "baz")]
    )
    assert dumpd(prompt) == snapshot(name="chat_prompt")
    assert load(dumpd(prompt)) == prompt