mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-31 12:09:58 +00:00
core[patch]: fix chat prompt partial messages placeholder var (#16918)
This commit is contained in:
parent
3b0fa9079d
commit
c29e9b6412
@ -47,9 +47,7 @@ class BasePromptTemplate(
|
||||
If not provided, all variables are assumed to be strings."""
|
||||
output_parser: Optional[BaseOutputParser] = None
|
||||
"""How to parse the output of calling an LLM on this formatted prompt."""
|
||||
partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
|
||||
default_factory=dict
|
||||
)
|
||||
partial_variables: Mapping[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
@ -143,8 +141,7 @@ class BasePromptTemplate(
|
||||
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
# Get partial params:
|
||||
partial_kwargs = {
|
||||
k: v if isinstance(v, str) else v()
|
||||
for k, v in self.partial_variables.items()
|
||||
k: v if not callable(v) else v() for k, v in self.partial_variables.items()
|
||||
}
|
||||
return {**partial_kwargs, **kwargs}
|
||||
|
||||
|
@ -5,7 +5,6 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
@ -130,13 +129,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
f"variable {self.variable_name} should be a list of base messages, "
|
||||
f"got {value}"
|
||||
)
|
||||
for v in convert_to_messages(value):
|
||||
if not isinstance(v, BaseMessage):
|
||||
raise ValueError(
|
||||
f"variable {self.variable_name} should be a list of base messages,"
|
||||
f" got {value}"
|
||||
)
|
||||
return value
|
||||
return convert_to_messages(value)
|
||||
|
||||
@property
|
||||
def input_variables(self) -> List[str]:
|
||||
@ -755,13 +748,20 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
|
||||
# Automatically infer input variables from messages
|
||||
input_vars: Set[str] = set()
|
||||
partial_vars: Dict[str, Any] = {}
|
||||
for _message in _messages:
|
||||
if isinstance(
|
||||
if isinstance(_message, MessagesPlaceholder) and _message.optional:
|
||||
partial_vars[_message.variable_name] = []
|
||||
elif isinstance(
|
||||
_message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
|
||||
):
|
||||
input_vars.update(_message.input_variables)
|
||||
|
||||
return cls(input_variables=sorted(input_vars), messages=_messages)
|
||||
return cls(
|
||||
input_variables=sorted(input_vars),
|
||||
messages=_messages,
|
||||
partial_variables=partial_vars,
|
||||
)
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the chat template into a string.
|
||||
@ -799,7 +799,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
raise ValueError(f"Unexpected input: {message_template}")
|
||||
return result
|
||||
|
||||
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate:
|
||||
def partial(self, **kwargs: Any) -> ChatPromptTemplate:
|
||||
"""Get a new ChatPromptTemplate with some input variables already filled in.
|
||||
|
||||
Args:
|
||||
|
@ -503,9 +503,27 @@ def test_messages_placeholder() -> None:
|
||||
prompt.format_messages()
|
||||
prompt = MessagesPlaceholder("history", optional=True)
|
||||
assert prompt.format_messages() == []
|
||||
prompt.format_messages(
|
||||
assert prompt.format_messages(
|
||||
history=[("system", "You are an AI assistant."), "Hello!"]
|
||||
) == [
|
||||
SystemMessage(content="You are an AI assistant."),
|
||||
HumanMessage(content="Hello!"),
|
||||
]
|
||||
|
||||
|
||||
def test_chat_prompt_message_placeholder_partial() -> None:
|
||||
prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")])
|
||||
prompt = prompt.partial(history=[("system", "foo")])
|
||||
assert prompt.format_messages() == [SystemMessage(content="foo")]
|
||||
assert prompt.format_messages(history=[("system", "bar")]) == [
|
||||
SystemMessage(content="bar")
|
||||
]
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
MessagesPlaceholder("history", optional=True),
|
||||
]
|
||||
)
|
||||
assert prompt.format_messages() == []
|
||||
prompt = prompt.partial(history=[("system", "foo")])
|
||||
assert prompt.format_messages() == [SystemMessage(content="foo")]
|
||||
|
@ -1544,7 +1544,8 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
],
|
||||
"partial_variables": {}
|
||||
}
|
||||
},
|
||||
"middle": [],
|
||||
@ -1617,7 +1618,8 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
],
|
||||
"partial_variables": {}
|
||||
}
|
||||
},
|
||||
"middle": [],
|
||||
|
@ -191,7 +191,8 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
],
|
||||
"partial_variables": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user