mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-02 11:39:18 +00:00
core: allow passing message dicts into ChatPromptTemplate (#29363)
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -828,6 +828,7 @@ MessageLikeRepresentation = Union[
|
|||||||
Union[str, list[dict], list[object]],
|
Union[str, list[dict], list[object]],
|
||||||
],
|
],
|
||||||
str,
|
str,
|
||||||
|
dict,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -1461,7 +1462,15 @@ def _convert_to_message(
|
|||||||
_message = _create_template_from_message_type(
|
_message = _create_template_from_message_type(
|
||||||
"human", message, template_format=template_format
|
"human", message, template_format=template_format
|
||||||
)
|
)
|
||||||
elif isinstance(message, tuple):
|
elif isinstance(message, (tuple, dict)):
|
||||||
|
if isinstance(message, dict):
|
||||||
|
if set(message.keys()) != {"content", "role"}:
|
||||||
|
msg = (
|
||||||
|
"Expected dict to have exact keys 'role' and 'content'."
|
||||||
|
f" Got: {message}"
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
message = (message["role"], message["content"])
|
||||||
if len(message) != 2:
|
if len(message) != 2:
|
||||||
msg = f"Expected 2-tuple of (role, template), got {message}"
|
msg = f"Expected 2-tuple of (role, template), got {message}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
@@ -824,6 +824,41 @@ def test_chat_prompt_message_placeholder_tuple() -> None:
|
|||||||
assert optional_prompt.format_messages() == []
|
assert optional_prompt.format_messages() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_prompt_message_placeholder_dict() -> None:
|
||||||
|
prompt = ChatPromptTemplate([{"role": "placeholder", "content": "{convo}"}])
|
||||||
|
assert prompt.format_messages(convo=[("user", "foo")]) == [
|
||||||
|
HumanMessage(content="foo")
|
||||||
|
]
|
||||||
|
|
||||||
|
assert prompt.format_messages() == []
|
||||||
|
|
||||||
|
# Is optional = True
|
||||||
|
optional_prompt = ChatPromptTemplate(
|
||||||
|
[{"role": "placeholder", "content": ["{convo}", False]}]
|
||||||
|
)
|
||||||
|
assert optional_prompt.format_messages(convo=[("user", "foo")]) == [
|
||||||
|
HumanMessage(content="foo")
|
||||||
|
]
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
assert optional_prompt.format_messages() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_prompt_message_dict() -> None:
|
||||||
|
prompt = ChatPromptTemplate(
|
||||||
|
[{"role": "system", "content": "foo"}, {"role": "user", "content": "bar"}]
|
||||||
|
)
|
||||||
|
assert prompt.format_messages() == [
|
||||||
|
SystemMessage(content="foo"),
|
||||||
|
HumanMessage(content="bar"),
|
||||||
|
]
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ChatPromptTemplate([{"role": "system", "content": False}])
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ChatPromptTemplate([{"role": "foo", "content": "foo"}])
|
||||||
|
|
||||||
|
|
||||||
async def test_messages_prompt_accepts_list() -> None:
|
async def test_messages_prompt_accepts_list() -> None:
|
||||||
prompt = ChatPromptTemplate([MessagesPlaceholder("history")])
|
prompt = ChatPromptTemplate([MessagesPlaceholder("history")])
|
||||||
value = prompt.invoke([("user", "Hi there")]) # type: ignore
|
value = prompt.invoke([("user", "Hi there")]) # type: ignore
|
||||||
|
Reference in New Issue
Block a user