core[minor]: Implement aformat_messages for _StringImageMessagePromptTemplate (#20036)

This commit is contained in:
Christophe Bornet
2024-04-09 21:59:39 +02:00
committed by GitHub
parent 19001e6cb9
commit f43b48aebc
2 changed files with 78 additions and 16 deletions

View File

@@ -143,6 +143,9 @@ async def test_chat_prompt_template(chat_prompt_template: ChatPromptTemplate) ->
string = chat_prompt_template.format(foo="foo", bar="bar", context="context")
assert string == expected
string = await chat_prompt_template.aformat(foo="foo", bar="bar", context="context")
assert string == expected
def test_chat_prompt_template_from_messages(
messages: List[BaseMessagePromptTemplate],
@@ -155,7 +158,7 @@ def test_chat_prompt_template_from_messages(
assert len(chat_prompt_template.messages) == 4
def test_chat_prompt_template_from_messages_using_role_strings() -> None:
async def test_chat_prompt_template_from_messages_using_role_strings() -> None:
"""Test creating a chat prompt template from role string messages."""
template = ChatPromptTemplate.from_messages(
[
@@ -166,9 +169,7 @@ def test_chat_prompt_template_from_messages_using_role_strings() -> None:
]
)
messages = template.format_messages(name="Bob", user_input="What is your name?")
assert messages == [
expected = [
SystemMessage(
content="You are a helpful AI bot. Your name is Bob.", additional_kwargs={}
),
@@ -181,6 +182,14 @@ def test_chat_prompt_template_from_messages_using_role_strings() -> None:
HumanMessage(content="What is your name?", additional_kwargs={}, example=False),
]
messages = template.format_messages(name="Bob", user_input="What is your name?")
assert messages == expected
messages = await template.aformat_messages(
name="Bob", user_input="What is your name?"
)
assert messages == expected
def test_chat_prompt_template_with_messages(
messages: List[BaseMessagePromptTemplate],
@@ -262,7 +271,7 @@ def test_chat_valid_infer_variables() -> None:
assert prompt.partial_variables == {"formatins": "some structure"}
def test_chat_from_role_strings() -> None:
async def test_chat_from_role_strings() -> None:
"""Test instantiation of chat template from role strings."""
with pytest.warns(LangChainPendingDeprecationWarning):
template = ChatPromptTemplate.from_role_strings(
@@ -274,14 +283,19 @@ def test_chat_from_role_strings() -> None:
]
)
messages = template.format_messages(question="How are you?", quack="duck")
assert messages == [
expected = [
ChatMessage(content="You are a bot.", role="system"),
ChatMessage(content="hello!", role="assistant"),
ChatMessage(content="How are you?", role="human"),
ChatMessage(content="duck", role="other"),
]
messages = template.format_messages(question="How are you?", quack="duck")
assert messages == expected
messages = await template.aformat_messages(question="How are you?", quack="duck")
assert messages == expected
@pytest.mark.parametrize(
"args,expected",
@@ -385,7 +399,7 @@ def test_chat_message_partial() -> None:
assert template2.format(input="hello") == get_buffer_string(expected)
def test_chat_tmpl_from_messages_multipart_text() -> None:
async def test_chat_tmpl_from_messages_multipart_text() -> None:
template = ChatPromptTemplate.from_messages(
[
("system", "You are an AI assistant named {name}."),
@@ -398,7 +412,6 @@ def test_chat_tmpl_from_messages_multipart_text() -> None:
),
]
)
messages = template.format_messages(name="R2D2")
expected = [
SystemMessage(content="You are an AI assistant named R2D2."),
HumanMessage(
@@ -408,10 +421,14 @@ def test_chat_tmpl_from_messages_multipart_text() -> None:
]
),
]
messages = template.format_messages(name="R2D2")
assert messages == expected
messages = await template.aformat_messages(name="R2D2")
assert messages == expected
def test_chat_tmpl_from_messages_multipart_text_with_template() -> None:
async def test_chat_tmpl_from_messages_multipart_text_with_template() -> None:
template = ChatPromptTemplate.from_messages(
[
("system", "You are an AI assistant named {name}."),
@@ -424,7 +441,6 @@ def test_chat_tmpl_from_messages_multipart_text_with_template() -> None:
),
]
)
messages = template.format_messages(name="R2D2", object_name="image")
expected = [
SystemMessage(content="You are an AI assistant named R2D2."),
HumanMessage(
@@ -434,10 +450,14 @@ def test_chat_tmpl_from_messages_multipart_text_with_template() -> None:
]
),
]
messages = template.format_messages(name="R2D2", object_name="image")
assert messages == expected
messages = await template.aformat_messages(name="R2D2", object_name="image")
assert messages == expected
def test_chat_tmpl_from_messages_multipart_image() -> None:
async def test_chat_tmpl_from_messages_multipart_image() -> None:
base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA"
other_base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA"
template = ChatPromptTemplate.from_messages(
@@ -472,9 +492,6 @@ def test_chat_tmpl_from_messages_multipart_image() -> None:
),
]
)
messages = template.format_messages(
name="R2D2", my_image=base64_image, my_other_image=other_base64_image
)
expected = [
SystemMessage(content="You are an AI assistant named R2D2."),
HumanMessage(
@@ -512,6 +529,14 @@ def test_chat_tmpl_from_messages_multipart_image() -> None:
]
),
]
messages = template.format_messages(
name="R2D2", my_image=base64_image, my_other_image=other_base64_image
)
assert messages == expected
messages = await template.aformat_messages(
name="R2D2", my_image=base64_image, my_other_image=other_base64_image
)
assert messages == expected
@@ -566,14 +591,20 @@ def test_chat_prompt_message_placeholder_tuple() -> None:
assert optional_prompt.format_messages() == []
def test_messages_prompt_accepts_list() -> None:
async def test_messages_prompt_accepts_list() -> None:
prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")])
value = prompt.invoke([("user", "Hi there")]) # type: ignore
assert value.to_messages() == [HumanMessage(content="Hi there")]
value = await prompt.ainvoke([("user", "Hi there")]) # type: ignore
assert value.to_messages() == [HumanMessage(content="Hi there")]
# Assert still raises a nice error
prompt = ChatPromptTemplate.from_messages(
[("system", "You are a {foo}"), MessagesPlaceholder("history")]
)
with pytest.raises(TypeError):
prompt.invoke([("user", "Hi there")]) # type: ignore
with pytest.raises(TypeError):
await prompt.ainvoke([("user", "Hi there")]) # type: ignore