mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 02:03:44 +00:00
core[patch]: Fix regression requiring input_variables in few chat prompt templates (#24360)
* Fix regression that requires users passing input_variables=[]. * Regression introduced by my own changes to this PR: https://github.com/langchain-ai/langchain/pull/22851
This commit is contained in:
parent
034a8c7c1b
commit
96bac8e20d
@ -18,7 +18,7 @@ from langchain_core.prompts.string import (
|
|||||||
check_valid_template,
|
check_valid_template,
|
||||||
get_template_variables,
|
get_template_variables,
|
||||||
)
|
)
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||||
|
|
||||||
|
|
||||||
class _FewShotPromptTemplateMixin(BaseModel):
|
class _FewShotPromptTemplateMixin(BaseModel):
|
||||||
@ -135,6 +135,12 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
|||||||
template_format: Literal["f-string", "jinja2"] = "f-string"
|
template_format: Literal["f-string", "jinja2"] = "f-string"
|
||||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
|
"""Initialize the few shot prompt template."""
|
||||||
|
if "input_variables" not in kwargs and "example_prompt" in kwargs:
|
||||||
|
kwargs["input_variables"] = kwargs["example_prompt"].input_variables
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
@root_validator(pre=False, skip_on_failure=True)
|
@root_validator(pre=False, skip_on_failure=True)
|
||||||
def template_is_valid(cls, values: Dict) -> Dict:
|
def template_is_valid(cls, values: Dict) -> Dict:
|
||||||
"""Check that prefix, suffix, and input variables are consistent."""
|
"""Check that prefix, suffix, and input variables are consistent."""
|
||||||
@ -351,14 +357,18 @@ class FewShotChatMessagePromptTemplate(
|
|||||||
chain.invoke({"input": "What's 3+3?"})
|
chain.invoke({"input": "What's 3+3?"})
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
input_variables: List[str] = Field(default_factory=list)
|
||||||
|
"""A list of the names of the variables the prompt template will use
|
||||||
|
to pass to the example_selector, if provided."""
|
||||||
|
|
||||||
|
example_prompt: Union[BaseMessagePromptTemplate, BaseChatPromptTemplate]
|
||||||
|
"""The class to format each example."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
"""Return whether or not the class is serializable."""
|
"""Return whether or not the class is serializable."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
example_prompt: Union[BaseMessagePromptTemplate, BaseChatPromptTemplate]
|
|
||||||
"""The class to format each example."""
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
@ -58,6 +58,17 @@ def test_suffix_only() -> None:
|
|||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_infer_input_variables() -> None:
|
||||||
|
"""Test prompt works with just a suffix."""
|
||||||
|
suffix = "This is a {foo} test."
|
||||||
|
prompt = FewShotPromptTemplate(
|
||||||
|
suffix=suffix,
|
||||||
|
examples=[],
|
||||||
|
example_prompt=EXAMPLE_PROMPT,
|
||||||
|
)
|
||||||
|
assert prompt.input_variables == ["foo"]
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_missing_input_variables() -> None:
|
def test_prompt_missing_input_variables() -> None:
|
||||||
"""Test error is raised when input variables are not provided."""
|
"""Test error is raised when input variables are not provided."""
|
||||||
# Test when missing in suffix
|
# Test when missing in suffix
|
||||||
@ -422,6 +433,30 @@ def test_few_shot_chat_message_prompt_template_with_selector() -> None:
|
|||||||
assert messages == expected
|
assert messages == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_few_shot_chat_message_prompt_template_infer_input_variables() -> None:
|
||||||
|
"""Check that it can infer input variables if not provided."""
|
||||||
|
examples = [
|
||||||
|
{"input": "2+2", "output": "4"},
|
||||||
|
{"input": "2+3", "output": "5"},
|
||||||
|
]
|
||||||
|
example_selector = AsIsSelector(examples)
|
||||||
|
example_prompt = ChatPromptTemplate.from_messages(
|
||||||
|
[
|
||||||
|
HumanMessagePromptTemplate.from_template("{input}"),
|
||||||
|
AIMessagePromptTemplate.from_template("{output}"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
few_shot_prompt = FewShotChatMessagePromptTemplate(
|
||||||
|
example_prompt=example_prompt,
|
||||||
|
example_selector=example_selector,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The prompt template does not have any inputs! They
|
||||||
|
# have already been filled in.
|
||||||
|
assert few_shot_prompt.input_variables == []
|
||||||
|
|
||||||
|
|
||||||
class AsyncAsIsSelector(BaseExampleSelector):
|
class AsyncAsIsSelector(BaseExampleSelector):
|
||||||
"""An example selector for testing purposes.
|
"""An example selector for testing purposes.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user