mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 08:03:39 +00:00
core[minor]: Add aformat_messages to FewShotChatMessagePromptTemplate and ChatPromptTemplate (#19648)
Needed since the example selector may use a vector store.
This commit is contained in:
parent
5f814820f6
commit
6b2b511f68
@ -66,6 +66,17 @@ class BaseMessagePromptTemplate(Serializable, ABC):
|
|||||||
List of BaseMessages.
|
List of BaseMessages.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
|
"""Format messages from kwargs. Should return a list of BaseMessages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Keyword arguments to use for formatting.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of BaseMessages.
|
||||||
|
"""
|
||||||
|
return self.format_messages(**kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def input_variables(self) -> List[str]:
|
def input_variables(self) -> List[str]:
|
||||||
@ -594,6 +605,10 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
|||||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
"""Format kwargs into a list of messages."""
|
"""Format kwargs into a list of messages."""
|
||||||
|
|
||||||
|
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
|
"""Format kwargs into a list of messages."""
|
||||||
|
return self.format_messages(**kwargs)
|
||||||
|
|
||||||
def pretty_repr(self, html: bool = False) -> str:
|
def pretty_repr(self, html: bool = False) -> str:
|
||||||
"""Human-readable representation."""
|
"""Human-readable representation."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -901,18 +916,6 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
|||||||
partial_variables=partial_vars,
|
partial_variables=partial_vars,
|
||||||
)
|
)
|
||||||
|
|
||||||
def format(self, **kwargs: Any) -> str:
|
|
||||||
"""Format the chat template into a string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
**kwargs: keyword arguments to use for filling in template variables
|
|
||||||
in all the template messages in this chat template.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
formatted string
|
|
||||||
"""
|
|
||||||
return self.format_prompt(**kwargs).to_string()
|
|
||||||
|
|
||||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
"""Format the chat template into a list of finalized messages.
|
"""Format the chat template into a list of finalized messages.
|
||||||
|
|
||||||
@ -937,6 +940,30 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
|||||||
raise ValueError(f"Unexpected input: {message_template}")
|
raise ValueError(f"Unexpected input: {message_template}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
|
"""Format the chat template into a list of finalized messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: keyword arguments to use for filling in template variables
|
||||||
|
in all the template messages in this chat template.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of formatted messages
|
||||||
|
"""
|
||||||
|
kwargs = self._merge_partial_and_user_variables(**kwargs)
|
||||||
|
result = []
|
||||||
|
for message_template in self.messages:
|
||||||
|
if isinstance(message_template, BaseMessage):
|
||||||
|
result.extend([message_template])
|
||||||
|
elif isinstance(
|
||||||
|
message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
|
||||||
|
):
|
||||||
|
message = await message_template.aformat_messages(**kwargs)
|
||||||
|
result.extend(message)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected input: {message_template}")
|
||||||
|
return result
|
||||||
|
|
||||||
def partial(self, **kwargs: Any) -> ChatPromptTemplate:
|
def partial(self, **kwargs: Any) -> ChatPromptTemplate:
|
||||||
"""Get a new ChatPromptTemplate with some input variables already filled in.
|
"""Get a new ChatPromptTemplate with some input variables already filled in.
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Literal, Optional, Union
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
from langchain_core.example_selectors import BaseExampleSelector
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
from langchain_core.prompts.chat import (
|
from langchain_core.prompts.chat import (
|
||||||
BaseChatPromptTemplate,
|
BaseChatPromptTemplate,
|
||||||
@ -27,7 +28,7 @@ class _FewShotPromptTemplateMixin(BaseModel):
|
|||||||
"""Examples to format into the prompt.
|
"""Examples to format into the prompt.
|
||||||
Either this or example_selector should be provided."""
|
Either this or example_selector should be provided."""
|
||||||
|
|
||||||
example_selector: Any = None
|
example_selector: Optional[BaseExampleSelector] = None
|
||||||
"""ExampleSelector to choose the examples to format into the prompt.
|
"""ExampleSelector to choose the examples to format into the prompt.
|
||||||
Either this or examples should be provided."""
|
Either this or examples should be provided."""
|
||||||
|
|
||||||
@ -72,6 +73,24 @@ class _FewShotPromptTemplateMixin(BaseModel):
|
|||||||
"One of 'examples' and 'example_selector' should be provided"
|
"One of 'examples' and 'example_selector' should be provided"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _aget_examples(self, **kwargs: Any) -> List[dict]:
|
||||||
|
"""Get the examples to use for formatting the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Keyword arguments to be passed to the example selector.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of examples.
|
||||||
|
"""
|
||||||
|
if self.examples is not None:
|
||||||
|
return self.examples
|
||||||
|
elif self.example_selector is not None:
|
||||||
|
return await self.example_selector.aselect_examples(kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"One of 'examples' and 'example_selector' should be provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
||||||
"""Prompt template that contains few shot examples."""
|
"""Prompt template that contains few shot examples."""
|
||||||
@ -325,6 +344,28 @@ class FewShotChatMessagePromptTemplate(
|
|||||||
]
|
]
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
|
"""Format kwargs into a list of messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: keyword arguments to use for filling in templates in messages.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of formatted messages with all template variables filled in.
|
||||||
|
"""
|
||||||
|
# Get the examples to use.
|
||||||
|
examples = await self._aget_examples(**kwargs)
|
||||||
|
examples = [
|
||||||
|
{k: e[k] for k in self.example_prompt.input_variables} for e in examples
|
||||||
|
]
|
||||||
|
# Format the examples.
|
||||||
|
messages = [
|
||||||
|
message
|
||||||
|
for example in examples
|
||||||
|
for message in await self.example_prompt.aformat_messages(**example)
|
||||||
|
]
|
||||||
|
return messages
|
||||||
|
|
||||||
def format(self, **kwargs: Any) -> str:
|
def format(self, **kwargs: Any) -> str:
|
||||||
"""Format the prompt with inputs generating a string.
|
"""Format the prompt with inputs generating a string.
|
||||||
|
|
||||||
|
@ -308,7 +308,7 @@ def test_prompt_jinja2_extra_input_variables(
|
|||||||
).input_variables == ["bar", "foo"]
|
).input_variables == ["bar", "foo"]
|
||||||
|
|
||||||
|
|
||||||
def test_few_shot_chat_message_prompt_template() -> None:
|
async def test_few_shot_chat_message_prompt_template() -> None:
|
||||||
"""Tests for few shot chat message template."""
|
"""Tests for few shot chat message template."""
|
||||||
examples = [
|
examples = [
|
||||||
{"input": "2+2", "output": "4"},
|
{"input": "2+2", "output": "4"},
|
||||||
@ -333,8 +333,7 @@ def test_few_shot_chat_message_prompt_template() -> None:
|
|||||||
+ HumanMessagePromptTemplate.from_template("{input}")
|
+ HumanMessagePromptTemplate.from_template("{input}")
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = final_prompt.format_messages(input="100 + 1")
|
expected = [
|
||||||
assert messages == [
|
|
||||||
SystemMessage(content="You are a helpful AI Assistant", additional_kwargs={}),
|
SystemMessage(content="You are a helpful AI Assistant", additional_kwargs={}),
|
||||||
HumanMessage(content="2+2", additional_kwargs={}, example=False),
|
HumanMessage(content="2+2", additional_kwargs={}, example=False),
|
||||||
AIMessage(content="4", additional_kwargs={}, example=False),
|
AIMessage(content="4", additional_kwargs={}, example=False),
|
||||||
@ -343,6 +342,11 @@ def test_few_shot_chat_message_prompt_template() -> None:
|
|||||||
HumanMessage(content="100 + 1", additional_kwargs={}, example=False),
|
HumanMessage(content="100 + 1", additional_kwargs={}, example=False),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
messages = final_prompt.format_messages(input="100 + 1")
|
||||||
|
assert messages == expected
|
||||||
|
messages = await final_prompt.aformat_messages(input="100 + 1")
|
||||||
|
assert messages == expected
|
||||||
|
|
||||||
|
|
||||||
class AsIsSelector(BaseExampleSelector):
|
class AsIsSelector(BaseExampleSelector):
|
||||||
"""An example selector for testing purposes.
|
"""An example selector for testing purposes.
|
||||||
@ -355,11 +359,9 @@ class AsIsSelector(BaseExampleSelector):
|
|||||||
self.examples = examples
|
self.examples = examples
|
||||||
|
|
||||||
def add_example(self, example: Dict[str, str]) -> Any:
|
def add_example(self, example: Dict[str, str]) -> Any:
|
||||||
"""Adds an example to the selector."""
|
raise NotImplementedError
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||||
"""Select which examples to use based on the inputs."""
|
|
||||||
return list(self.examples)
|
return list(self.examples)
|
||||||
|
|
||||||
|
|
||||||
@ -387,8 +389,7 @@ def test_few_shot_chat_message_prompt_template_with_selector() -> None:
|
|||||||
+ few_shot_prompt
|
+ few_shot_prompt
|
||||||
+ HumanMessagePromptTemplate.from_template("{input}")
|
+ HumanMessagePromptTemplate.from_template("{input}")
|
||||||
)
|
)
|
||||||
messages = final_prompt.format_messages(input="100 + 1")
|
expected = [
|
||||||
assert messages == [
|
|
||||||
SystemMessage(content="You are a helpful AI Assistant", additional_kwargs={}),
|
SystemMessage(content="You are a helpful AI Assistant", additional_kwargs={}),
|
||||||
HumanMessage(content="2+2", additional_kwargs={}, example=False),
|
HumanMessage(content="2+2", additional_kwargs={}, example=False),
|
||||||
AIMessage(content="4", additional_kwargs={}, example=False),
|
AIMessage(content="4", additional_kwargs={}, example=False),
|
||||||
@ -396,3 +397,61 @@ def test_few_shot_chat_message_prompt_template_with_selector() -> None:
|
|||||||
AIMessage(content="5", additional_kwargs={}, example=False),
|
AIMessage(content="5", additional_kwargs={}, example=False),
|
||||||
HumanMessage(content="100 + 1", additional_kwargs={}, example=False),
|
HumanMessage(content="100 + 1", additional_kwargs={}, example=False),
|
||||||
]
|
]
|
||||||
|
messages = final_prompt.format_messages(input="100 + 1")
|
||||||
|
assert messages == expected
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncAsIsSelector(BaseExampleSelector):
|
||||||
|
"""An example selector for testing purposes.
|
||||||
|
|
||||||
|
This selector returns the examples as-is.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, examples: Sequence[Dict[str, str]]) -> None:
|
||||||
|
"""Initializes the selector."""
|
||||||
|
self.examples = examples
|
||||||
|
|
||||||
|
def add_example(self, example: Dict[str, str]) -> Any:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||||
|
return list(self.examples)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_few_shot_chat_message_prompt_template_with_selector_async() -> None:
|
||||||
|
"""Tests for few shot chat message template with an async example selector."""
|
||||||
|
examples = [
|
||||||
|
{"input": "2+2", "output": "4"},
|
||||||
|
{"input": "2+3", "output": "5"},
|
||||||
|
]
|
||||||
|
example_selector = AsyncAsIsSelector(examples)
|
||||||
|
example_prompt = ChatPromptTemplate.from_messages(
|
||||||
|
[
|
||||||
|
HumanMessagePromptTemplate.from_template("{input}"),
|
||||||
|
AIMessagePromptTemplate.from_template("{output}"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
few_shot_prompt = FewShotChatMessagePromptTemplate(
|
||||||
|
input_variables=["input"],
|
||||||
|
example_prompt=example_prompt,
|
||||||
|
example_selector=example_selector,
|
||||||
|
)
|
||||||
|
final_prompt: ChatPromptTemplate = (
|
||||||
|
SystemMessagePromptTemplate.from_template("You are a helpful AI Assistant")
|
||||||
|
+ few_shot_prompt
|
||||||
|
+ HumanMessagePromptTemplate.from_template("{input}")
|
||||||
|
)
|
||||||
|
expected = [
|
||||||
|
SystemMessage(content="You are a helpful AI Assistant", additional_kwargs={}),
|
||||||
|
HumanMessage(content="2+2", additional_kwargs={}, example=False),
|
||||||
|
AIMessage(content="4", additional_kwargs={}, example=False),
|
||||||
|
HumanMessage(content="2+3", additional_kwargs={}, example=False),
|
||||||
|
AIMessage(content="5", additional_kwargs={}, example=False),
|
||||||
|
HumanMessage(content="100 + 1", additional_kwargs={}, example=False),
|
||||||
|
]
|
||||||
|
messages = await final_prompt.aformat_messages(input="100 + 1")
|
||||||
|
assert messages == expected
|
||||||
|
Loading…
Reference in New Issue
Block a user