core[patch]: fix structured prompt template format (#27003)

template_format is an init argument on ChatPromptTemplate but not an
attribute on the object so was getting shoved into
StructuredPrompt.structured_ouptut_kwargs
This commit is contained in:
Bagatur 2024-09-30 11:47:46 -07:00 committed by GitHub
parent 0078493a80
commit 248be02259
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 3 deletions

View File

@ -2,6 +2,7 @@ from collections.abc import Iterator, Mapping, Sequence
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Literal,
Optional, Optional,
Union, Union,
) )
@ -37,6 +38,7 @@ class StructuredPrompt(ChatPromptTemplate):
schema_: Optional[Union[dict, type[BaseModel]]] = None, schema_: Optional[Union[dict, type[BaseModel]]] = None,
*, *,
structured_output_kwargs: Optional[dict[str, Any]] = None, structured_output_kwargs: Optional[dict[str, Any]] = None,
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
schema_ = schema_ or kwargs.pop("schema") schema_ = schema_ or kwargs.pop("schema")
@ -47,6 +49,7 @@ class StructuredPrompt(ChatPromptTemplate):
messages=messages, messages=messages,
schema_=schema_, schema_=schema_,
structured_output_kwargs=structured_output_kwargs, structured_output_kwargs=structured_output_kwargs,
template_format=template_format,
**kwargs, **kwargs,
) )

View File

@ -1,13 +1,13 @@
from functools import partial from functools import partial
from inspect import isclass from inspect import isclass
from typing import Any, Union, cast from typing import Any, Union, cast
from typing import Optional as Optional
from pydantic import BaseModel from pydantic import BaseModel
from langchain_core.language_models import FakeListChatModel from langchain_core.language_models import FakeListChatModel
from langchain_core.load.dump import dumps from langchain_core.load.dump import dumps
from langchain_core.load.load import loads from langchain_core.load.load import loads
from langchain_core.messages import HumanMessage
from langchain_core.prompts.structured import StructuredPrompt from langchain_core.prompts.structured import StructuredPrompt
from langchain_core.runnables.base import Runnable, RunnableLambda from langchain_core.runnables.base import Runnable, RunnableLambda
from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_core.utils.pydantic import is_basemodel_subclass
@ -121,3 +121,14 @@ def test_structured_prompt_kwargs() -> None:
chain = prompt | model chain = prompt | model
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7) assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7)
def test_structured_prompt_template_format() -> None:
prompt = StructuredPrompt(
[("human", "hi {{person.name}}")], schema={}, template_format="mustache"
)
assert prompt.messages[0].prompt.template_format == "mustache" # type: ignore[union-attr, union-attr]
assert prompt.input_variables == ["person"]
assert prompt.invoke({"person": {"name": "foo"}}).to_messages() == [
HumanMessage("hi foo")
]

View File

@ -126,8 +126,6 @@ def test_configurable() -> None:
"tiktoken_model_name": None, "tiktoken_model_name": None,
"default_headers": None, "default_headers": None,
"default_query": None, "default_query": None,
"http_client": None,
"http_async_client": None,
"stop": None, "stop": None,
"extra_body": None, "extra_body": None,
"include_response_headers": False, "include_response_headers": False,