mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 01:37:59 +00:00
Add option to make messages placeholder optional (#15031)
<!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes if applicable, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
40f42b8947
commit
d5533b7081
@ -87,13 +87,17 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
|||||||
variable_name: str
|
variable_name: str
|
||||||
"""Name of variable to use as messages."""
|
"""Name of variable to use as messages."""
|
||||||
|
|
||||||
|
optional: bool = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
"""Get the namespace of the langchain object."""
|
"""Get the namespace of the langchain object."""
|
||||||
return ["langchain", "prompts", "chat"]
|
return ["langchain", "prompts", "chat"]
|
||||||
|
|
||||||
def __init__(self, variable_name: str, **kwargs: Any):
|
def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any):
|
||||||
return super().__init__(variable_name=variable_name, **kwargs)
|
return super().__init__(
|
||||||
|
variable_name=variable_name, optional=optional, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
"""Format messages from kwargs.
|
"""Format messages from kwargs.
|
||||||
@ -104,7 +108,11 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
|||||||
Returns:
|
Returns:
|
||||||
List of BaseMessage.
|
List of BaseMessage.
|
||||||
"""
|
"""
|
||||||
value = kwargs[self.variable_name]
|
value = (
|
||||||
|
kwargs.get(self.variable_name, [])
|
||||||
|
if self.optional
|
||||||
|
else kwargs[self.variable_name]
|
||||||
|
)
|
||||||
if not isinstance(value, list):
|
if not isinstance(value, list):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"variable {self.variable_name} should be a list of base messages, "
|
f"variable {self.variable_name} should be a list of base messages, "
|
||||||
|
@ -19,6 +19,7 @@ from langchain_core.prompts.chat import (
|
|||||||
ChatMessagePromptTemplate,
|
ChatMessagePromptTemplate,
|
||||||
ChatPromptTemplate,
|
ChatPromptTemplate,
|
||||||
HumanMessagePromptTemplate,
|
HumanMessagePromptTemplate,
|
||||||
|
MessagesPlaceholder,
|
||||||
SystemMessagePromptTemplate,
|
SystemMessagePromptTemplate,
|
||||||
_convert_to_message,
|
_convert_to_message,
|
||||||
)
|
)
|
||||||
@ -360,3 +361,11 @@ def test_chat_message_partial() -> None:
|
|||||||
]
|
]
|
||||||
assert res == expected
|
assert res == expected
|
||||||
assert template2.format(input="hello") == get_buffer_string(expected)
|
assert template2.format(input="hello") == get_buffer_string(expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_messages_placeholder() -> None:
|
||||||
|
prompt = MessagesPlaceholder("history")
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
prompt.format_messages()
|
||||||
|
prompt = MessagesPlaceholder("history", optional=True)
|
||||||
|
assert prompt.format_messages() == []
|
||||||
|
Loading…
Reference in New Issue
Block a user