core[minor]: Implement aformat_prompt and ainvoke in BasePromptTemplate (#20035)

This commit is contained in:
Christophe Bornet 2024-04-05 16:36:43 +02:00 committed by GitHub
parent 7e5c1905b1
commit 4d8a6a27a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 113 additions and 18 deletions

View File

@ -87,7 +87,7 @@ class BasePromptTemplate(
**{k: (self.input_types.get(k, str), None) for k in self.input_variables}, **{k: (self.input_types.get(k, str), None) for k in self.input_variables},
) )
def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue: def _validate_input(self, inner_input: Dict) -> Dict:
if not isinstance(inner_input, dict): if not isinstance(inner_input, dict):
if len(self.input_variables) == 1: if len(self.input_variables) == 1:
var_name = self.input_variables[0] var_name = self.input_variables[0]
@ -105,7 +105,17 @@ class BasePromptTemplate(
f" Expected: {self.input_variables}" f" Expected: {self.input_variables}"
f" Received: {list(inner_input.keys())}" f" Received: {list(inner_input.keys())}"
) )
return self.format_prompt(**inner_input) return inner_input
def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
_inner_input = self._validate_input(inner_input)
return self.format_prompt(**_inner_input)
async def _aformat_prompt_with_error_handling(
self, inner_input: Dict
) -> PromptValue:
_inner_input = self._validate_input(inner_input)
return await self.aformat_prompt(**_inner_input)
def invoke( def invoke(
self, input: Dict, config: Optional[RunnableConfig] = None self, input: Dict, config: Optional[RunnableConfig] = None
@ -122,10 +132,29 @@ class BasePromptTemplate(
run_type="prompt", run_type="prompt",
) )
async def ainvoke(
self, input: Dict, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> PromptValue:
config = ensure_config(config)
if self.metadata:
config["metadata"].update(self.metadata)
if self.tags:
config["tags"].extend(self.tags)
return await self._acall_with_config(
self._aformat_prompt_with_error_handling,
input,
config,
run_type="prompt",
)
@abstractmethod @abstractmethod
def format_prompt(self, **kwargs: Any) -> PromptValue: def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Prompt Value.""" """Create Prompt Value."""
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Prompt Value."""
return self.format_prompt(**kwargs)
@root_validator() @root_validator()
def validate_variable_names(cls, values: Dict) -> Dict: def validate_variable_names(cls, values: Dict) -> Dict:
"""Validate variable names do not include restricted names.""" """Validate variable names do not include restricted names."""

View File

@ -609,6 +609,18 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
""" """
return self.format_prompt(**kwargs).to_string() return self.format_prompt(**kwargs).to_string()
async def aformat(self, **kwargs: Any) -> str:
"""Format the chat template into a string.
Args:
**kwargs: keyword arguments to use for filling in template variables
in all the template messages in this chat template.
Returns:
formatted string
"""
return (await self.aformat_prompt(**kwargs)).to_string()
def format_prompt(self, **kwargs: Any) -> PromptValue: def format_prompt(self, **kwargs: Any) -> PromptValue:
""" """
Format prompt. Should return a PromptValue. Format prompt. Should return a PromptValue.
@ -621,6 +633,10 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
messages = self.format_messages(**kwargs) messages = self.format_messages(**kwargs)
return ChatPromptValue(messages=messages) return ChatPromptValue(messages=messages)
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
messages = await self.aformat_messages(**kwargs)
return ChatPromptValue(messages=messages)
@abstractmethod @abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]: def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format kwargs into a list of messages.""" """Format kwargs into a list of messages."""

View File

@ -37,9 +37,11 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
return ["langchain", "prompts", "image"] return ["langchain", "prompts", "image"]
def format_prompt(self, **kwargs: Any) -> PromptValue: def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
return ImagePromptValue(image_url=self.format(**kwargs)) return ImagePromptValue(image_url=self.format(**kwargs))
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
return ImagePromptValue(image_url=await self.aformat(**kwargs))
def format( def format(
self, self,
**kwargs: Any, **kwargs: Any,

View File

@ -54,9 +54,22 @@ class PipelinePromptTemplate(BasePromptTemplate):
_inputs = _get_inputs(kwargs, self.final_prompt.input_variables) _inputs = _get_inputs(kwargs, self.final_prompt.input_variables)
return self.final_prompt.format_prompt(**_inputs) return self.final_prompt.format_prompt(**_inputs)
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
for k, prompt in self.pipeline_prompts:
_inputs = _get_inputs(kwargs, prompt.input_variables)
if isinstance(prompt, BaseChatPromptTemplate):
kwargs[k] = await prompt.aformat_messages(**_inputs)
else:
kwargs[k] = await prompt.aformat(**_inputs)
_inputs = _get_inputs(kwargs, self.final_prompt.input_variables)
return await self.final_prompt.aformat_prompt(**_inputs)
def format(self, **kwargs: Any) -> str: def format(self, **kwargs: Any) -> str:
return self.format_prompt(**kwargs).to_string() return self.format_prompt(**kwargs).to_string()
async def aformat(self, **kwargs: Any) -> str:
return (await self.aformat_prompt(**kwargs)).to_string()
@property @property
def _prompt_type(self) -> str: def _prompt_type(self) -> str:
raise ValueError raise ValueError

View File

@ -160,9 +160,11 @@ class StringPromptTemplate(BasePromptTemplate, ABC):
return ["langchain", "prompts", "base"] return ["langchain", "prompts", "base"]
def format_prompt(self, **kwargs: Any) -> PromptValue: def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
return StringPromptValue(text=self.format(**kwargs)) return StringPromptValue(text=self.format(**kwargs))
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
return StringPromptValue(text=await self.aformat(**kwargs))
def pretty_repr(self, html: bool = False) -> str: def pretty_repr(self, html: bool = False) -> str:
# TODO: handle partials # TODO: handle partials
dummy_vars = { dummy_vars = {

View File

@ -28,7 +28,8 @@ from langchain_core.prompts.chat import (
) )
def create_messages() -> List[BaseMessagePromptTemplate]: @pytest.fixture
def messages() -> List[BaseMessagePromptTemplate]:
"""Create messages.""" """Create messages."""
system_message_prompt = SystemMessagePromptTemplate( system_message_prompt = SystemMessagePromptTemplate(
prompt=PromptTemplate( prompt=PromptTemplate(
@ -63,11 +64,14 @@ def create_messages() -> List[BaseMessagePromptTemplate]:
] ]
def create_chat_prompt_template() -> ChatPromptTemplate: @pytest.fixture
def chat_prompt_template(
messages: List[BaseMessagePromptTemplate],
) -> ChatPromptTemplate:
"""Create a chat prompt template.""" """Create a chat prompt template."""
return ChatPromptTemplate( return ChatPromptTemplate(
input_variables=["foo", "bar", "context"], input_variables=["foo", "bar", "context"],
messages=create_messages(), # type: ignore[arg-type] messages=messages, # type: ignore[arg-type]
) )
@ -110,10 +114,9 @@ def test_message_prompt_template_from_template_file() -> None:
assert expected == actual assert expected == actual
def test_chat_prompt_template() -> None: async def test_chat_prompt_template(chat_prompt_template: ChatPromptTemplate) -> None:
"""Test chat prompt template.""" """Test chat prompt template."""
prompt_template = create_chat_prompt_template() prompt = chat_prompt_template.format_prompt(foo="foo", bar="bar", context="context")
prompt = prompt_template.format_prompt(foo="foo", bar="bar", context="context")
assert isinstance(prompt, ChatPromptValue) assert isinstance(prompt, ChatPromptValue)
messages = prompt.to_messages() messages = prompt.to_messages()
assert len(messages) == 4 assert len(messages) == 4
@ -122,6 +125,12 @@ def test_chat_prompt_template() -> None:
assert messages[2].content == "I'm an AI. I'm foo. I'm bar." assert messages[2].content == "I'm an AI. I'm foo. I'm bar."
assert messages[3].content == "I'm a generic message. I'm foo. I'm bar." assert messages[3].content == "I'm a generic message. I'm foo. I'm bar."
async_prompt = await chat_prompt_template.aformat_prompt(
foo="foo", bar="bar", context="context"
)
assert async_prompt.to_messages() == messages
string = prompt.to_string() string = prompt.to_string()
expected = ( expected = (
"System: Here's some context: context\n" "System: Here's some context: context\n"
@ -131,13 +140,15 @@ def test_chat_prompt_template() -> None:
) )
assert string == expected assert string == expected
string = prompt_template.format(foo="foo", bar="bar", context="context") string = chat_prompt_template.format(foo="foo", bar="bar", context="context")
assert string == expected assert string == expected
def test_chat_prompt_template_from_messages() -> None: def test_chat_prompt_template_from_messages(
messages: List[BaseMessagePromptTemplate],
) -> None:
"""Test creating a chat prompt template from messages.""" """Test creating a chat prompt template from messages."""
chat_prompt_template = ChatPromptTemplate.from_messages(create_messages()) chat_prompt_template = ChatPromptTemplate.from_messages(messages)
assert sorted(chat_prompt_template.input_variables) == sorted( assert sorted(chat_prompt_template.input_variables) == sorted(
["context", "foo", "bar"] ["context", "foo", "bar"]
) )
@ -171,11 +182,12 @@ def test_chat_prompt_template_from_messages_using_role_strings() -> None:
] ]
def test_chat_prompt_template_with_messages() -> None: def test_chat_prompt_template_with_messages(
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]] = ( messages: List[BaseMessagePromptTemplate],
create_messages() + [HumanMessage(content="foo")] ) -> None:
chat_prompt_template = ChatPromptTemplate.from_messages(
messages + [HumanMessage(content="foo")]
) )
chat_prompt_template = ChatPromptTemplate.from_messages(messages)
assert sorted(chat_prompt_template.input_variables) == sorted( assert sorted(chat_prompt_template.input_variables) == sorted(
["context", "foo", "bar"] ["context", "foo", "bar"]
) )

View File

@ -32,7 +32,7 @@ def test_multi_variable_pipeline() -> None:
assert output == "okay jim deep" assert output == "okay jim deep"
def test_partial_with_chat_prompts() -> None: async def test_partial_with_chat_prompts() -> None:
prompt_a = ChatPromptTemplate( prompt_a = ChatPromptTemplate(
input_variables=["foo"], messages=[MessagesPlaceholder(variable_name="foo")] input_variables=["foo"], messages=[MessagesPlaceholder(variable_name="foo")]
) )
@ -43,3 +43,5 @@ def test_partial_with_chat_prompts() -> None:
assert pipeline_prompt.input_variables == ["bar"] assert pipeline_prompt.input_variables == ["bar"]
output = pipeline_prompt.format_prompt(bar="okay") output = pipeline_prompt.format_prompt(bar="okay")
assert output.to_messages()[0].content == "jim okay" assert output.to_messages()[0].content == "jim okay"
output = await pipeline_prompt.aformat_prompt(bar="okay")
assert output.to_messages()[0].content == "jim okay"

View File

@ -351,3 +351,22 @@ def test_prompt_invoke_with_metadata() -> None:
assert len(tracer.traced_runs) == 1 assert len(tracer.traced_runs) == 1
assert tracer.traced_runs[0].extra["metadata"] == {"version": "1", "foo": "bar"} # type: ignore assert tracer.traced_runs[0].extra["metadata"] == {"version": "1", "foo": "bar"} # type: ignore
assert tracer.traced_runs[0].tags == ["tag1", "tag2"] # type: ignore assert tracer.traced_runs[0].tags == ["tag1", "tag2"] # type: ignore
async def test_prompt_ainvoke_with_metadata() -> None:
"""Test prompt can be invoked with metadata."""
template = "This is a {foo} test."
prompt = PromptTemplate(
input_variables=["foo"],
template=template,
metadata={"version": "1"},
tags=["tag1", "tag2"],
)
tracer = RunCollectorCallbackHandler()
result = await prompt.ainvoke(
{"foo": "bar"}, {"metadata": {"foo": "bar"}, "callbacks": [tracer]}
)
assert result.to_string() == "This is a bar test."
assert len(tracer.traced_runs) == 1
assert tracer.traced_runs[0].extra["metadata"] == {"version": "1", "foo": "bar"} # type: ignore
assert tracer.traced_runs[0].tags == ["tag1", "tag2"] # type: ignore