mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-15 09:23:57 +00:00
Add option to make messages placeholder optional (#15031)
<!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes if applicable, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
40f42b8947
commit
d5533b7081
@ -87,13 +87,17 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
variable_name: str
|
||||
"""Name of variable to use as messages."""
|
||||
|
||||
optional: bool = False
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "prompts", "chat"]
|
||||
|
||||
def __init__(self, variable_name: str, **kwargs: Any):
|
||||
return super().__init__(variable_name=variable_name, **kwargs)
|
||||
def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any):
|
||||
return super().__init__(
|
||||
variable_name=variable_name, optional=optional, **kwargs
|
||||
)
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format messages from kwargs.
|
||||
@ -104,7 +108,11 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
Returns:
|
||||
List of BaseMessage.
|
||||
"""
|
||||
value = kwargs[self.variable_name]
|
||||
value = (
|
||||
kwargs.get(self.variable_name, [])
|
||||
if self.optional
|
||||
else kwargs[self.variable_name]
|
||||
)
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(
|
||||
f"variable {self.variable_name} should be a list of base messages, "
|
||||
|
@ -19,6 +19,7 @@ from langchain_core.prompts.chat import (
|
||||
ChatMessagePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
SystemMessagePromptTemplate,
|
||||
_convert_to_message,
|
||||
)
|
||||
@ -360,3 +361,11 @@ def test_chat_message_partial() -> None:
|
||||
]
|
||||
assert res == expected
|
||||
assert template2.format(input="hello") == get_buffer_string(expected)
|
||||
|
||||
|
||||
def test_messages_placeholder() -> None:
|
||||
prompt = MessagesPlaceholder("history")
|
||||
with pytest.raises(KeyError):
|
||||
prompt.format_messages()
|
||||
prompt = MessagesPlaceholder("history", optional=True)
|
||||
assert prompt.format_messages() == []
|
||||
|
Loading…
Reference in New Issue
Block a user