core[minor]: Implement aformat_prompt and ainvoke in BasePromptTemplate (#20035)

Author: Christophe Bornet
Date: 2024-04-05 16:36:43 +02:00
Committed by: GitHub
Parent: 7e5c1905b1
Commit: 4d8a6a27a3

8 changed files with 113 additions and 18 deletions
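Only the test-side hunks of the commit are reproduced below; the hunks that add the new methods to BasePromptTemplate itself are not shown. As a minimal sketch, assuming the commit follows langchain_core's usual pattern of defaulting async methods to the existing sync implementation via an executor, the base-class addition plausibly looks like the following (class name and body are illustrative, not the commit's actual code):

# Illustrative sketch only, not the commit's actual diff.
from typing import Any

from langchain_core.prompt_values import PromptValue
from langchain_core.runnables.config import run_in_executor


class AsyncPromptDefaultsSketch:
    """Stand-in for BasePromptTemplate, showing only the delegation pattern."""

    def format_prompt(self, **kwargs: Any) -> PromptValue:
        raise NotImplementedError  # provided by concrete prompt templates

    async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
        # Default async variant: run the sync formatter in an executor so every
        # subclass gets `await prompt.aformat_prompt(...)` without extra work.
        return await run_in_executor(None, self.format_prompt, **kwargs)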


@@ -28,7 +28,8 @@ from langchain_core.prompts.chat import (
 )
 
 
-def create_messages() -> List[BaseMessagePromptTemplate]:
+@pytest.fixture
+def messages() -> List[BaseMessagePromptTemplate]:
     """Create messages."""
     system_message_prompt = SystemMessagePromptTemplate(
         prompt=PromptTemplate(
@@ -63,11 +64,14 @@ def create_messages() -> List[BaseMessagePromptTemplate]:
     ]
 
 
-def create_chat_prompt_template() -> ChatPromptTemplate:
+@pytest.fixture
+def chat_prompt_template(
+    messages: List[BaseMessagePromptTemplate],
+) -> ChatPromptTemplate:
     """Create a chat prompt template."""
     return ChatPromptTemplate(
         input_variables=["foo", "bar", "context"],
-        messages=create_messages(),  # type: ignore[arg-type]
+        messages=messages,  # type: ignore[arg-type]
     )
 
 
@@ -110,10 +114,9 @@ def test_message_prompt_template_from_template_file() -> None:
     assert expected == actual
 
 
-def test_chat_prompt_template() -> None:
+async def test_chat_prompt_template(chat_prompt_template: ChatPromptTemplate) -> None:
     """Test chat prompt template."""
-    prompt_template = create_chat_prompt_template()
-    prompt = prompt_template.format_prompt(foo="foo", bar="bar", context="context")
+    prompt = chat_prompt_template.format_prompt(foo="foo", bar="bar", context="context")
     assert isinstance(prompt, ChatPromptValue)
     messages = prompt.to_messages()
     assert len(messages) == 4
@@ -122,6 +125,12 @@ def test_chat_prompt_template() -> None:
     assert messages[2].content == "I'm an AI. I'm foo. I'm bar."
     assert messages[3].content == "I'm a generic message. I'm foo. I'm bar."
 
+    async_prompt = await chat_prompt_template.aformat_prompt(
+        foo="foo", bar="bar", context="context"
+    )
+    assert async_prompt.to_messages() == messages
+
     string = prompt.to_string()
     expected = (
         "System: Here's some context: context\n"
@@ -131,13 +140,15 @@ def test_chat_prompt_template() -> None:
     )
     assert string == expected
 
-    string = prompt_template.format(foo="foo", bar="bar", context="context")
+    string = chat_prompt_template.format(foo="foo", bar="bar", context="context")
     assert string == expected
 
 
-def test_chat_prompt_template_from_messages() -> None:
+def test_chat_prompt_template_from_messages(
+    messages: List[BaseMessagePromptTemplate],
+) -> None:
     """Test creating a chat prompt template from messages."""
-    chat_prompt_template = ChatPromptTemplate.from_messages(create_messages())
+    chat_prompt_template = ChatPromptTemplate.from_messages(messages)
     assert sorted(chat_prompt_template.input_variables) == sorted(
         ["context", "foo", "bar"]
     )
@@ -171,11 +182,12 @@ def test_chat_prompt_template_from_messages_using_role_strings() -> None:
     ]
 
 
-def test_chat_prompt_template_with_messages() -> None:
-    messages: List[Union[BaseMessagePromptTemplate, BaseMessage]] = (
-        create_messages() + [HumanMessage(content="foo")]
+def test_chat_prompt_template_with_messages(
+    messages: List[BaseMessagePromptTemplate],
+) -> None:
+    chat_prompt_template = ChatPromptTemplate.from_messages(
+        messages + [HumanMessage(content="foo")]
     )
-    chat_prompt_template = ChatPromptTemplate.from_messages(messages)
     assert sorted(chat_prompt_template.input_variables) == sorted(
         ["context", "foo", "bar"]
     )
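The updated test awaits aformat_prompt on the fixture's template. Outside the test suite, the same call looks roughly like this; the template contents are made up for illustration and are not part of the commit:

import asyncio

from langchain_core.prompts import ChatPromptTemplate


async def main() -> None:
    # Illustrative template, not the one built by the test fixture.
    template = ChatPromptTemplate.from_messages(
        [
            ("system", "Here's some context: {context}"),
            ("human", "Hello {foo}, I'm {bar}."),
        ]
    )
    prompt_value = await template.aformat_prompt(
        context="context", foo="foo", bar="bar"
    )
    print(prompt_value.to_messages())


asyncio.run(main())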


@@ -32,7 +32,7 @@ def test_multi_variable_pipeline() -> None:
     assert output == "okay jim deep"
 
 
-def test_partial_with_chat_prompts() -> None:
+async def test_partial_with_chat_prompts() -> None:
     prompt_a = ChatPromptTemplate(
         input_variables=["foo"], messages=[MessagesPlaceholder(variable_name="foo")]
     )
@@ -43,3 +43,5 @@ def test_partial_with_chat_prompts() -> None:
     assert pipeline_prompt.input_variables == ["bar"]
     output = pipeline_prompt.format_prompt(bar="okay")
     assert output.to_messages()[0].content == "jim okay"
+    output = await pipeline_prompt.aformat_prompt(bar="okay")
+    assert output.to_messages()[0].content == "jim okay"


@@ -351,3 +351,22 @@ def test_prompt_invoke_with_metadata() -> None:
     assert len(tracer.traced_runs) == 1
     assert tracer.traced_runs[0].extra["metadata"] == {"version": "1", "foo": "bar"}  # type: ignore
     assert tracer.traced_runs[0].tags == ["tag1", "tag2"]  # type: ignore
+
+
+async def test_prompt_ainvoke_with_metadata() -> None:
+    """Test prompt can be invoked with metadata."""
+    template = "This is a {foo} test."
+    prompt = PromptTemplate(
+        input_variables=["foo"],
+        template=template,
+        metadata={"version": "1"},
+        tags=["tag1", "tag2"],
+    )
+    tracer = RunCollectorCallbackHandler()
+    result = await prompt.ainvoke(
+        {"foo": "bar"}, {"metadata": {"foo": "bar"}, "callbacks": [tracer]}
+    )
+    assert result.to_string() == "This is a bar test."
+    assert len(tracer.traced_runs) == 1
+    assert tracer.traced_runs[0].extra["metadata"] == {"version": "1", "foo": "bar"}  # type: ignore
+    assert tracer.traced_runs[0].tags == ["tag1", "tag2"]  # type: ignore
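The new async entry point covered by this test can be driven the same way from application code. The standalone snippet below mirrors the test but is an illustrative example, not part of the commit:

import asyncio

from langchain_core.prompts import PromptTemplate
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler


async def main() -> None:
    prompt = PromptTemplate.from_template("This is a {foo} test.")
    tracer = RunCollectorCallbackHandler()
    # Per-call config (callbacks, metadata, tags) goes in the second argument.
    result = await prompt.ainvoke({"foo": "bar"}, {"callbacks": [tracer]})
    print(result.to_string())  # This is a bar test.
    print(len(tracer.traced_runs))  # 1


asyncio.run(main())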