core[patch]: fix structured prompt template format (#27003)

template_format is an init argument on ChatPromptTemplate but not an
attribute on the object, so it was getting shoved into
StructuredPrompt.structured_output_kwargs.
This commit is contained in:
Bagatur 2024-09-30 11:47:46 -07:00 committed by GitHub
parent 0078493a80
commit 248be02259
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 3 deletions

View File

@@ -2,6 +2,7 @@ from collections.abc import Iterator, Mapping, Sequence
from typing import (
Any,
Callable,
Literal,
Optional,
Union,
)
@@ -37,6 +38,7 @@ class StructuredPrompt(ChatPromptTemplate):
schema_: Optional[Union[dict, type[BaseModel]]] = None,
*,
structured_output_kwargs: Optional[dict[str, Any]] = None,
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
**kwargs: Any,
) -> None:
schema_ = schema_ or kwargs.pop("schema")
@@ -47,6 +49,7 @@ class StructuredPrompt(ChatPromptTemplate):
messages=messages,
schema_=schema_,
structured_output_kwargs=structured_output_kwargs,
template_format=template_format,
**kwargs,
)

View File

@@ -1,13 +1,13 @@
from functools import partial
from inspect import isclass
from typing import Any, Union, cast
from typing import Optional as Optional
from pydantic import BaseModel
from langchain_core.language_models import FakeListChatModel
from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.messages import HumanMessage
from langchain_core.prompts.structured import StructuredPrompt
from langchain_core.runnables.base import Runnable, RunnableLambda
from langchain_core.utils.pydantic import is_basemodel_subclass
@@ -121,3 +121,14 @@ def test_structured_prompt_kwargs() -> None:
chain = prompt | model
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7)
def test_structured_prompt_template_format() -> None:
    """A ``template_format`` passed to ``StructuredPrompt`` must reach the
    underlying message prompt templates (regression test for the fix where it
    was swallowed into ``structured_output_kwargs``).
    """
    prompt = StructuredPrompt(
        [("human", "hi {{person.name}}")], schema={}, template_format="mustache"
    )
    # The mustache format must be propagated to the message's prompt template.
    assert prompt.messages[0].prompt.template_format == "mustache"  # type: ignore[union-attr]
    # Mustache dotted access {{person.name}} exposes only the top-level key
    # "person" as an input variable.
    assert prompt.input_variables == ["person"]
    # End-to-end: mustache rendering resolves the nested dict lookup.
    assert prompt.invoke({"person": {"name": "foo"}}).to_messages() == [
        HumanMessage("hi foo")
    ]

View File

@ -126,8 +126,6 @@ def test_configurable() -> None:
"tiktoken_model_name": None,
"default_headers": None,
"default_query": None,
"http_client": None,
"http_async_client": None,
"stop": None,
"extra_body": None,
"include_response_headers": False,