Add option to make messages placeholder optional (#15031)

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos 2023-12-21 12:36:37 -08:00 committed by GitHub
parent 40f42b8947
commit d5533b7081
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 3 deletions

View File

@ -87,13 +87,17 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
variable_name: str
"""Name of variable to use as messages."""
optional: bool = False
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
def __init__(self, variable_name: str, **kwargs: Any):
return super().__init__(variable_name=variable_name, **kwargs)
def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any):
return super().__init__(
variable_name=variable_name, optional=optional, **kwargs
)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs.
@ -104,7 +108,11 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
Returns:
List of BaseMessage.
"""
value = kwargs[self.variable_name]
value = (
kwargs.get(self.variable_name, [])
if self.optional
else kwargs[self.variable_name]
)
if not isinstance(value, list):
raise ValueError(
f"variable {self.variable_name} should be a list of base messages, "

View File

@ -19,6 +19,7 @@ from langchain_core.prompts.chat import (
ChatMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
_convert_to_message,
)
@ -360,3 +361,11 @@ def test_chat_message_partial() -> None:
]
assert res == expected
assert template2.format(input="hello") == get_buffer_string(expected)
def test_messages_placeholder() -> None:
prompt = MessagesPlaceholder("history")
with pytest.raises(KeyError):
prompt.format_messages()
prompt = MessagesPlaceholder("history", optional=True)
assert prompt.format_messages() == []