mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
core[minor]: Implement aformat_prompt and ainvoke in BasePromptTemplate (#20035)
This commit is contained in:
parent
7e5c1905b1
commit
4d8a6a27a3
@ -87,7 +87,7 @@ class BasePromptTemplate(
|
|||||||
**{k: (self.input_types.get(k, str), None) for k in self.input_variables},
|
**{k: (self.input_types.get(k, str), None) for k in self.input_variables},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
|
def _validate_input(self, inner_input: Dict) -> Dict:
|
||||||
if not isinstance(inner_input, dict):
|
if not isinstance(inner_input, dict):
|
||||||
if len(self.input_variables) == 1:
|
if len(self.input_variables) == 1:
|
||||||
var_name = self.input_variables[0]
|
var_name = self.input_variables[0]
|
||||||
@ -105,7 +105,17 @@ class BasePromptTemplate(
|
|||||||
f" Expected: {self.input_variables}"
|
f" Expected: {self.input_variables}"
|
||||||
f" Received: {list(inner_input.keys())}"
|
f" Received: {list(inner_input.keys())}"
|
||||||
)
|
)
|
||||||
return self.format_prompt(**inner_input)
|
return inner_input
|
||||||
|
|
||||||
|
def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
|
||||||
|
_inner_input = self._validate_input(inner_input)
|
||||||
|
return self.format_prompt(**_inner_input)
|
||||||
|
|
||||||
|
async def _aformat_prompt_with_error_handling(
|
||||||
|
self, inner_input: Dict
|
||||||
|
) -> PromptValue:
|
||||||
|
_inner_input = self._validate_input(inner_input)
|
||||||
|
return await self.aformat_prompt(**_inner_input)
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: Dict, config: Optional[RunnableConfig] = None
|
self, input: Dict, config: Optional[RunnableConfig] = None
|
||||||
@ -122,10 +132,29 @@ class BasePromptTemplate(
|
|||||||
run_type="prompt",
|
run_type="prompt",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
self, input: Dict, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
|
) -> PromptValue:
|
||||||
|
config = ensure_config(config)
|
||||||
|
if self.metadata:
|
||||||
|
config["metadata"].update(self.metadata)
|
||||||
|
if self.tags:
|
||||||
|
config["tags"].extend(self.tags)
|
||||||
|
return await self._acall_with_config(
|
||||||
|
self._aformat_prompt_with_error_handling,
|
||||||
|
input,
|
||||||
|
config,
|
||||||
|
run_type="prompt",
|
||||||
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||||
"""Create Prompt Value."""
|
"""Create Prompt Value."""
|
||||||
|
|
||||||
|
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
|
||||||
|
"""Create Prompt Value."""
|
||||||
|
return self.format_prompt(**kwargs)
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_variable_names(cls, values: Dict) -> Dict:
|
def validate_variable_names(cls, values: Dict) -> Dict:
|
||||||
"""Validate variable names do not include restricted names."""
|
"""Validate variable names do not include restricted names."""
|
||||||
|
@ -609,6 +609,18 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
|||||||
"""
|
"""
|
||||||
return self.format_prompt(**kwargs).to_string()
|
return self.format_prompt(**kwargs).to_string()
|
||||||
|
|
||||||
|
async def aformat(self, **kwargs: Any) -> str:
|
||||||
|
"""Format the chat template into a string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: keyword arguments to use for filling in template variables
|
||||||
|
in all the template messages in this chat template.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
formatted string
|
||||||
|
"""
|
||||||
|
return (await self.aformat_prompt(**kwargs)).to_string()
|
||||||
|
|
||||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||||
"""
|
"""
|
||||||
Format prompt. Should return a PromptValue.
|
Format prompt. Should return a PromptValue.
|
||||||
@ -621,6 +633,10 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
|||||||
messages = self.format_messages(**kwargs)
|
messages = self.format_messages(**kwargs)
|
||||||
return ChatPromptValue(messages=messages)
|
return ChatPromptValue(messages=messages)
|
||||||
|
|
||||||
|
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
|
||||||
|
messages = await self.aformat_messages(**kwargs)
|
||||||
|
return ChatPromptValue(messages=messages)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
"""Format kwargs into a list of messages."""
|
"""Format kwargs into a list of messages."""
|
||||||
|
@ -37,9 +37,11 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
|
|||||||
return ["langchain", "prompts", "image"]
|
return ["langchain", "prompts", "image"]
|
||||||
|
|
||||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||||
"""Create Chat Messages."""
|
|
||||||
return ImagePromptValue(image_url=self.format(**kwargs))
|
return ImagePromptValue(image_url=self.format(**kwargs))
|
||||||
|
|
||||||
|
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
|
||||||
|
return ImagePromptValue(image_url=await self.aformat(**kwargs))
|
||||||
|
|
||||||
def format(
|
def format(
|
||||||
self,
|
self,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
|
@ -54,9 +54,22 @@ class PipelinePromptTemplate(BasePromptTemplate):
|
|||||||
_inputs = _get_inputs(kwargs, self.final_prompt.input_variables)
|
_inputs = _get_inputs(kwargs, self.final_prompt.input_variables)
|
||||||
return self.final_prompt.format_prompt(**_inputs)
|
return self.final_prompt.format_prompt(**_inputs)
|
||||||
|
|
||||||
|
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
|
||||||
|
for k, prompt in self.pipeline_prompts:
|
||||||
|
_inputs = _get_inputs(kwargs, prompt.input_variables)
|
||||||
|
if isinstance(prompt, BaseChatPromptTemplate):
|
||||||
|
kwargs[k] = await prompt.aformat_messages(**_inputs)
|
||||||
|
else:
|
||||||
|
kwargs[k] = await prompt.aformat(**_inputs)
|
||||||
|
_inputs = _get_inputs(kwargs, self.final_prompt.input_variables)
|
||||||
|
return await self.final_prompt.aformat_prompt(**_inputs)
|
||||||
|
|
||||||
def format(self, **kwargs: Any) -> str:
|
def format(self, **kwargs: Any) -> str:
|
||||||
return self.format_prompt(**kwargs).to_string()
|
return self.format_prompt(**kwargs).to_string()
|
||||||
|
|
||||||
|
async def aformat(self, **kwargs: Any) -> str:
|
||||||
|
return (await self.aformat_prompt(**kwargs)).to_string()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _prompt_type(self) -> str:
|
def _prompt_type(self) -> str:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
@ -160,9 +160,11 @@ class StringPromptTemplate(BasePromptTemplate, ABC):
|
|||||||
return ["langchain", "prompts", "base"]
|
return ["langchain", "prompts", "base"]
|
||||||
|
|
||||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||||
"""Create Chat Messages."""
|
|
||||||
return StringPromptValue(text=self.format(**kwargs))
|
return StringPromptValue(text=self.format(**kwargs))
|
||||||
|
|
||||||
|
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
|
||||||
|
return StringPromptValue(text=await self.aformat(**kwargs))
|
||||||
|
|
||||||
def pretty_repr(self, html: bool = False) -> str:
|
def pretty_repr(self, html: bool = False) -> str:
|
||||||
# TODO: handle partials
|
# TODO: handle partials
|
||||||
dummy_vars = {
|
dummy_vars = {
|
||||||
|
@ -28,7 +28,8 @@ from langchain_core.prompts.chat import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_messages() -> List[BaseMessagePromptTemplate]:
|
@pytest.fixture
|
||||||
|
def messages() -> List[BaseMessagePromptTemplate]:
|
||||||
"""Create messages."""
|
"""Create messages."""
|
||||||
system_message_prompt = SystemMessagePromptTemplate(
|
system_message_prompt = SystemMessagePromptTemplate(
|
||||||
prompt=PromptTemplate(
|
prompt=PromptTemplate(
|
||||||
@ -63,11 +64,14 @@ def create_messages() -> List[BaseMessagePromptTemplate]:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def create_chat_prompt_template() -> ChatPromptTemplate:
|
@pytest.fixture
|
||||||
|
def chat_prompt_template(
|
||||||
|
messages: List[BaseMessagePromptTemplate],
|
||||||
|
) -> ChatPromptTemplate:
|
||||||
"""Create a chat prompt template."""
|
"""Create a chat prompt template."""
|
||||||
return ChatPromptTemplate(
|
return ChatPromptTemplate(
|
||||||
input_variables=["foo", "bar", "context"],
|
input_variables=["foo", "bar", "context"],
|
||||||
messages=create_messages(), # type: ignore[arg-type]
|
messages=messages, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -110,10 +114,9 @@ def test_message_prompt_template_from_template_file() -> None:
|
|||||||
assert expected == actual
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
def test_chat_prompt_template() -> None:
|
async def test_chat_prompt_template(chat_prompt_template: ChatPromptTemplate) -> None:
|
||||||
"""Test chat prompt template."""
|
"""Test chat prompt template."""
|
||||||
prompt_template = create_chat_prompt_template()
|
prompt = chat_prompt_template.format_prompt(foo="foo", bar="bar", context="context")
|
||||||
prompt = prompt_template.format_prompt(foo="foo", bar="bar", context="context")
|
|
||||||
assert isinstance(prompt, ChatPromptValue)
|
assert isinstance(prompt, ChatPromptValue)
|
||||||
messages = prompt.to_messages()
|
messages = prompt.to_messages()
|
||||||
assert len(messages) == 4
|
assert len(messages) == 4
|
||||||
@ -122,6 +125,12 @@ def test_chat_prompt_template() -> None:
|
|||||||
assert messages[2].content == "I'm an AI. I'm foo. I'm bar."
|
assert messages[2].content == "I'm an AI. I'm foo. I'm bar."
|
||||||
assert messages[3].content == "I'm a generic message. I'm foo. I'm bar."
|
assert messages[3].content == "I'm a generic message. I'm foo. I'm bar."
|
||||||
|
|
||||||
|
async_prompt = await chat_prompt_template.aformat_prompt(
|
||||||
|
foo="foo", bar="bar", context="context"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert async_prompt.to_messages() == messages
|
||||||
|
|
||||||
string = prompt.to_string()
|
string = prompt.to_string()
|
||||||
expected = (
|
expected = (
|
||||||
"System: Here's some context: context\n"
|
"System: Here's some context: context\n"
|
||||||
@ -131,13 +140,15 @@ def test_chat_prompt_template() -> None:
|
|||||||
)
|
)
|
||||||
assert string == expected
|
assert string == expected
|
||||||
|
|
||||||
string = prompt_template.format(foo="foo", bar="bar", context="context")
|
string = chat_prompt_template.format(foo="foo", bar="bar", context="context")
|
||||||
assert string == expected
|
assert string == expected
|
||||||
|
|
||||||
|
|
||||||
def test_chat_prompt_template_from_messages() -> None:
|
def test_chat_prompt_template_from_messages(
|
||||||
|
messages: List[BaseMessagePromptTemplate],
|
||||||
|
) -> None:
|
||||||
"""Test creating a chat prompt template from messages."""
|
"""Test creating a chat prompt template from messages."""
|
||||||
chat_prompt_template = ChatPromptTemplate.from_messages(create_messages())
|
chat_prompt_template = ChatPromptTemplate.from_messages(messages)
|
||||||
assert sorted(chat_prompt_template.input_variables) == sorted(
|
assert sorted(chat_prompt_template.input_variables) == sorted(
|
||||||
["context", "foo", "bar"]
|
["context", "foo", "bar"]
|
||||||
)
|
)
|
||||||
@ -171,11 +182,12 @@ def test_chat_prompt_template_from_messages_using_role_strings() -> None:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_chat_prompt_template_with_messages() -> None:
|
def test_chat_prompt_template_with_messages(
|
||||||
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]] = (
|
messages: List[BaseMessagePromptTemplate],
|
||||||
create_messages() + [HumanMessage(content="foo")]
|
) -> None:
|
||||||
|
chat_prompt_template = ChatPromptTemplate.from_messages(
|
||||||
|
messages + [HumanMessage(content="foo")]
|
||||||
)
|
)
|
||||||
chat_prompt_template = ChatPromptTemplate.from_messages(messages)
|
|
||||||
assert sorted(chat_prompt_template.input_variables) == sorted(
|
assert sorted(chat_prompt_template.input_variables) == sorted(
|
||||||
["context", "foo", "bar"]
|
["context", "foo", "bar"]
|
||||||
)
|
)
|
||||||
|
@ -32,7 +32,7 @@ def test_multi_variable_pipeline() -> None:
|
|||||||
assert output == "okay jim deep"
|
assert output == "okay jim deep"
|
||||||
|
|
||||||
|
|
||||||
def test_partial_with_chat_prompts() -> None:
|
async def test_partial_with_chat_prompts() -> None:
|
||||||
prompt_a = ChatPromptTemplate(
|
prompt_a = ChatPromptTemplate(
|
||||||
input_variables=["foo"], messages=[MessagesPlaceholder(variable_name="foo")]
|
input_variables=["foo"], messages=[MessagesPlaceholder(variable_name="foo")]
|
||||||
)
|
)
|
||||||
@ -43,3 +43,5 @@ def test_partial_with_chat_prompts() -> None:
|
|||||||
assert pipeline_prompt.input_variables == ["bar"]
|
assert pipeline_prompt.input_variables == ["bar"]
|
||||||
output = pipeline_prompt.format_prompt(bar="okay")
|
output = pipeline_prompt.format_prompt(bar="okay")
|
||||||
assert output.to_messages()[0].content == "jim okay"
|
assert output.to_messages()[0].content == "jim okay"
|
||||||
|
output = await pipeline_prompt.aformat_prompt(bar="okay")
|
||||||
|
assert output.to_messages()[0].content == "jim okay"
|
||||||
|
@ -351,3 +351,22 @@ def test_prompt_invoke_with_metadata() -> None:
|
|||||||
assert len(tracer.traced_runs) == 1
|
assert len(tracer.traced_runs) == 1
|
||||||
assert tracer.traced_runs[0].extra["metadata"] == {"version": "1", "foo": "bar"} # type: ignore
|
assert tracer.traced_runs[0].extra["metadata"] == {"version": "1", "foo": "bar"} # type: ignore
|
||||||
assert tracer.traced_runs[0].tags == ["tag1", "tag2"] # type: ignore
|
assert tracer.traced_runs[0].tags == ["tag1", "tag2"] # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
async def test_prompt_ainvoke_with_metadata() -> None:
|
||||||
|
"""Test prompt can be invoked with metadata."""
|
||||||
|
template = "This is a {foo} test."
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
input_variables=["foo"],
|
||||||
|
template=template,
|
||||||
|
metadata={"version": "1"},
|
||||||
|
tags=["tag1", "tag2"],
|
||||||
|
)
|
||||||
|
tracer = RunCollectorCallbackHandler()
|
||||||
|
result = await prompt.ainvoke(
|
||||||
|
{"foo": "bar"}, {"metadata": {"foo": "bar"}, "callbacks": [tracer]}
|
||||||
|
)
|
||||||
|
assert result.to_string() == "This is a bar test."
|
||||||
|
assert len(tracer.traced_runs) == 1
|
||||||
|
assert tracer.traced_runs[0].extra["metadata"] == {"version": "1", "foo": "bar"} # type: ignore
|
||||||
|
assert tracer.traced_runs[0].tags == ["tag1", "tag2"] # type: ignore
|
||||||
|
Loading…
Reference in New Issue
Block a user