mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-11 05:45:01 +00:00
core[minor]: Implement aformat_messages for _StringImageMessagePromptTemplate (#20036)
This commit is contained in:
parent
19001e6cb9
commit
f43b48aebc
@ -506,6 +506,9 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
|||||||
"""
|
"""
|
||||||
return [self.format(**kwargs)]
|
return [self.format(**kwargs)]
|
||||||
|
|
||||||
|
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
|
return [await self.aformat(**kwargs)]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_variables(self) -> List[str]:
|
def input_variables(self) -> List[str]:
|
||||||
"""
|
"""
|
||||||
@ -546,6 +549,34 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
|||||||
content=content, additional_kwargs=self.additional_kwargs
|
content=content, additional_kwargs=self.additional_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def aformat(self, **kwargs: Any) -> BaseMessage:
|
||||||
|
"""Format the prompt template.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Keyword arguments to use for formatting.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted message.
|
||||||
|
"""
|
||||||
|
if isinstance(self.prompt, StringPromptTemplate):
|
||||||
|
text = await self.prompt.aformat(**kwargs)
|
||||||
|
return self._msg_class(
|
||||||
|
content=text, additional_kwargs=self.additional_kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
content: List = []
|
||||||
|
for prompt in self.prompt:
|
||||||
|
inputs = {var: kwargs[var] for var in prompt.input_variables}
|
||||||
|
if isinstance(prompt, StringPromptTemplate):
|
||||||
|
formatted: Union[str, ImageURL] = await prompt.aformat(**inputs)
|
||||||
|
content.append({"type": "text", "text": formatted})
|
||||||
|
elif isinstance(prompt, ImagePromptTemplate):
|
||||||
|
formatted = await prompt.aformat(**inputs)
|
||||||
|
content.append({"type": "image_url", "image_url": formatted})
|
||||||
|
return self._msg_class(
|
||||||
|
content=content, additional_kwargs=self.additional_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def pretty_repr(self, html: bool = False) -> str:
|
def pretty_repr(self, html: bool = False) -> str:
|
||||||
# TODO: Handle partials
|
# TODO: Handle partials
|
||||||
title = self.__class__.__name__.replace("MessagePromptTemplate", " Message")
|
title = self.__class__.__name__.replace("MessagePromptTemplate", " Message")
|
||||||
|
@ -143,6 +143,9 @@ async def test_chat_prompt_template(chat_prompt_template: ChatPromptTemplate) ->
|
|||||||
string = chat_prompt_template.format(foo="foo", bar="bar", context="context")
|
string = chat_prompt_template.format(foo="foo", bar="bar", context="context")
|
||||||
assert string == expected
|
assert string == expected
|
||||||
|
|
||||||
|
string = await chat_prompt_template.aformat(foo="foo", bar="bar", context="context")
|
||||||
|
assert string == expected
|
||||||
|
|
||||||
|
|
||||||
def test_chat_prompt_template_from_messages(
|
def test_chat_prompt_template_from_messages(
|
||||||
messages: List[BaseMessagePromptTemplate],
|
messages: List[BaseMessagePromptTemplate],
|
||||||
@ -155,7 +158,7 @@ def test_chat_prompt_template_from_messages(
|
|||||||
assert len(chat_prompt_template.messages) == 4
|
assert len(chat_prompt_template.messages) == 4
|
||||||
|
|
||||||
|
|
||||||
def test_chat_prompt_template_from_messages_using_role_strings() -> None:
|
async def test_chat_prompt_template_from_messages_using_role_strings() -> None:
|
||||||
"""Test creating a chat prompt template from role string messages."""
|
"""Test creating a chat prompt template from role string messages."""
|
||||||
template = ChatPromptTemplate.from_messages(
|
template = ChatPromptTemplate.from_messages(
|
||||||
[
|
[
|
||||||
@ -166,9 +169,7 @@ def test_chat_prompt_template_from_messages_using_role_strings() -> None:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = template.format_messages(name="Bob", user_input="What is your name?")
|
expected = [
|
||||||
|
|
||||||
assert messages == [
|
|
||||||
SystemMessage(
|
SystemMessage(
|
||||||
content="You are a helpful AI bot. Your name is Bob.", additional_kwargs={}
|
content="You are a helpful AI bot. Your name is Bob.", additional_kwargs={}
|
||||||
),
|
),
|
||||||
@ -181,6 +182,14 @@ def test_chat_prompt_template_from_messages_using_role_strings() -> None:
|
|||||||
HumanMessage(content="What is your name?", additional_kwargs={}, example=False),
|
HumanMessage(content="What is your name?", additional_kwargs={}, example=False),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
messages = template.format_messages(name="Bob", user_input="What is your name?")
|
||||||
|
assert messages == expected
|
||||||
|
|
||||||
|
messages = await template.aformat_messages(
|
||||||
|
name="Bob", user_input="What is your name?"
|
||||||
|
)
|
||||||
|
assert messages == expected
|
||||||
|
|
||||||
|
|
||||||
def test_chat_prompt_template_with_messages(
|
def test_chat_prompt_template_with_messages(
|
||||||
messages: List[BaseMessagePromptTemplate],
|
messages: List[BaseMessagePromptTemplate],
|
||||||
@ -262,7 +271,7 @@ def test_chat_valid_infer_variables() -> None:
|
|||||||
assert prompt.partial_variables == {"formatins": "some structure"}
|
assert prompt.partial_variables == {"formatins": "some structure"}
|
||||||
|
|
||||||
|
|
||||||
def test_chat_from_role_strings() -> None:
|
async def test_chat_from_role_strings() -> None:
|
||||||
"""Test instantiation of chat template from role strings."""
|
"""Test instantiation of chat template from role strings."""
|
||||||
with pytest.warns(LangChainPendingDeprecationWarning):
|
with pytest.warns(LangChainPendingDeprecationWarning):
|
||||||
template = ChatPromptTemplate.from_role_strings(
|
template = ChatPromptTemplate.from_role_strings(
|
||||||
@ -274,14 +283,19 @@ def test_chat_from_role_strings() -> None:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = template.format_messages(question="How are you?", quack="duck")
|
expected = [
|
||||||
assert messages == [
|
|
||||||
ChatMessage(content="You are a bot.", role="system"),
|
ChatMessage(content="You are a bot.", role="system"),
|
||||||
ChatMessage(content="hello!", role="assistant"),
|
ChatMessage(content="hello!", role="assistant"),
|
||||||
ChatMessage(content="How are you?", role="human"),
|
ChatMessage(content="How are you?", role="human"),
|
||||||
ChatMessage(content="duck", role="other"),
|
ChatMessage(content="duck", role="other"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
messages = template.format_messages(question="How are you?", quack="duck")
|
||||||
|
assert messages == expected
|
||||||
|
|
||||||
|
messages = await template.aformat_messages(question="How are you?", quack="duck")
|
||||||
|
assert messages == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"args,expected",
|
"args,expected",
|
||||||
@ -385,7 +399,7 @@ def test_chat_message_partial() -> None:
|
|||||||
assert template2.format(input="hello") == get_buffer_string(expected)
|
assert template2.format(input="hello") == get_buffer_string(expected)
|
||||||
|
|
||||||
|
|
||||||
def test_chat_tmpl_from_messages_multipart_text() -> None:
|
async def test_chat_tmpl_from_messages_multipart_text() -> None:
|
||||||
template = ChatPromptTemplate.from_messages(
|
template = ChatPromptTemplate.from_messages(
|
||||||
[
|
[
|
||||||
("system", "You are an AI assistant named {name}."),
|
("system", "You are an AI assistant named {name}."),
|
||||||
@ -398,7 +412,6 @@ def test_chat_tmpl_from_messages_multipart_text() -> None:
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
messages = template.format_messages(name="R2D2")
|
|
||||||
expected = [
|
expected = [
|
||||||
SystemMessage(content="You are an AI assistant named R2D2."),
|
SystemMessage(content="You are an AI assistant named R2D2."),
|
||||||
HumanMessage(
|
HumanMessage(
|
||||||
@ -408,10 +421,14 @@ def test_chat_tmpl_from_messages_multipart_text() -> None:
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
messages = template.format_messages(name="R2D2")
|
||||||
|
assert messages == expected
|
||||||
|
|
||||||
|
messages = await template.aformat_messages(name="R2D2")
|
||||||
assert messages == expected
|
assert messages == expected
|
||||||
|
|
||||||
|
|
||||||
def test_chat_tmpl_from_messages_multipart_text_with_template() -> None:
|
async def test_chat_tmpl_from_messages_multipart_text_with_template() -> None:
|
||||||
template = ChatPromptTemplate.from_messages(
|
template = ChatPromptTemplate.from_messages(
|
||||||
[
|
[
|
||||||
("system", "You are an AI assistant named {name}."),
|
("system", "You are an AI assistant named {name}."),
|
||||||
@ -424,7 +441,6 @@ def test_chat_tmpl_from_messages_multipart_text_with_template() -> None:
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
messages = template.format_messages(name="R2D2", object_name="image")
|
|
||||||
expected = [
|
expected = [
|
||||||
SystemMessage(content="You are an AI assistant named R2D2."),
|
SystemMessage(content="You are an AI assistant named R2D2."),
|
||||||
HumanMessage(
|
HumanMessage(
|
||||||
@ -434,10 +450,14 @@ def test_chat_tmpl_from_messages_multipart_text_with_template() -> None:
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
messages = template.format_messages(name="R2D2", object_name="image")
|
||||||
|
assert messages == expected
|
||||||
|
|
||||||
|
messages = await template.aformat_messages(name="R2D2", object_name="image")
|
||||||
assert messages == expected
|
assert messages == expected
|
||||||
|
|
||||||
|
|
||||||
def test_chat_tmpl_from_messages_multipart_image() -> None:
|
async def test_chat_tmpl_from_messages_multipart_image() -> None:
|
||||||
base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA"
|
base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA"
|
||||||
other_base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA"
|
other_base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA"
|
||||||
template = ChatPromptTemplate.from_messages(
|
template = ChatPromptTemplate.from_messages(
|
||||||
@ -472,9 +492,6 @@ def test_chat_tmpl_from_messages_multipart_image() -> None:
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
messages = template.format_messages(
|
|
||||||
name="R2D2", my_image=base64_image, my_other_image=other_base64_image
|
|
||||||
)
|
|
||||||
expected = [
|
expected = [
|
||||||
SystemMessage(content="You are an AI assistant named R2D2."),
|
SystemMessage(content="You are an AI assistant named R2D2."),
|
||||||
HumanMessage(
|
HumanMessage(
|
||||||
@ -512,6 +529,14 @@ def test_chat_tmpl_from_messages_multipart_image() -> None:
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
messages = template.format_messages(
|
||||||
|
name="R2D2", my_image=base64_image, my_other_image=other_base64_image
|
||||||
|
)
|
||||||
|
assert messages == expected
|
||||||
|
|
||||||
|
messages = await template.aformat_messages(
|
||||||
|
name="R2D2", my_image=base64_image, my_other_image=other_base64_image
|
||||||
|
)
|
||||||
assert messages == expected
|
assert messages == expected
|
||||||
|
|
||||||
|
|
||||||
@ -566,14 +591,20 @@ def test_chat_prompt_message_placeholder_tuple() -> None:
|
|||||||
assert optional_prompt.format_messages() == []
|
assert optional_prompt.format_messages() == []
|
||||||
|
|
||||||
|
|
||||||
def test_messages_prompt_accepts_list() -> None:
|
async def test_messages_prompt_accepts_list() -> None:
|
||||||
prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")])
|
prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")])
|
||||||
value = prompt.invoke([("user", "Hi there")]) # type: ignore
|
value = prompt.invoke([("user", "Hi there")]) # type: ignore
|
||||||
assert value.to_messages() == [HumanMessage(content="Hi there")]
|
assert value.to_messages() == [HumanMessage(content="Hi there")]
|
||||||
|
|
||||||
|
value = await prompt.ainvoke([("user", "Hi there")]) # type: ignore
|
||||||
|
assert value.to_messages() == [HumanMessage(content="Hi there")]
|
||||||
|
|
||||||
# Assert still raises a nice error
|
# Assert still raises a nice error
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
[("system", "You are a {foo}"), MessagesPlaceholder("history")]
|
[("system", "You are a {foo}"), MessagesPlaceholder("history")]
|
||||||
)
|
)
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
prompt.invoke([("user", "Hi there")]) # type: ignore
|
prompt.invoke([("user", "Hi there")]) # type: ignore
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
await prompt.ainvoke([("user", "Hi there")]) # type: ignore
|
||||||
|
Loading…
Reference in New Issue
Block a user