Add option to make messages placeholder optional (#15031)

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos 2023-12-21 12:36:37 -08:00 committed by GitHub
parent 40f42b8947
commit d5533b7081
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 3 deletions

View File

@ -87,13 +87,17 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
variable_name: str variable_name: str
"""Name of variable to use as messages.""" """Name of variable to use as messages."""
optional: bool = False
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"] return ["langchain", "prompts", "chat"]
def __init__(self, variable_name: str, **kwargs: Any): def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any):
return super().__init__(variable_name=variable_name, **kwargs) return super().__init__(
variable_name=variable_name, optional=optional, **kwargs
)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]: def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs. """Format messages from kwargs.
@ -104,7 +108,11 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
Returns: Returns:
List of BaseMessage. List of BaseMessage.
""" """
value = kwargs[self.variable_name] value = (
kwargs.get(self.variable_name, [])
if self.optional
else kwargs[self.variable_name]
)
if not isinstance(value, list): if not isinstance(value, list):
raise ValueError( raise ValueError(
f"variable {self.variable_name} should be a list of base messages, " f"variable {self.variable_name} should be a list of base messages, "

View File

@ -19,6 +19,7 @@ from langchain_core.prompts.chat import (
ChatMessagePromptTemplate, ChatMessagePromptTemplate,
ChatPromptTemplate, ChatPromptTemplate,
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate, SystemMessagePromptTemplate,
_convert_to_message, _convert_to_message,
) )
@ -360,3 +361,11 @@ def test_chat_message_partial() -> None:
] ]
assert res == expected assert res == expected
assert template2.format(input="hello") == get_buffer_string(expected) assert template2.format(input="hello") == get_buffer_string(expected)
def test_messages_placeholder() -> None:
prompt = MessagesPlaceholder("history")
with pytest.raises(KeyError):
prompt.format_messages()
prompt = MessagesPlaceholder("history", optional=True)
assert prompt.format_messages() == []