mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-23 19:39:58 +00:00
core[minor]: Add aformat_messages to FewShotChatMessagePromptTemplate and ChatPromptTemplate (#19648)
Needed since the example selector may use a vector store.
This commit is contained in:
committed by
GitHub
parent
5f814820f6
commit
6b2b511f68
@@ -66,6 +66,17 @@ class BaseMessagePromptTemplate(Serializable, ABC):
|
||||
List of BaseMessages.
|
||||
"""
|
||||
|
||||
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format messages from kwargs. Should return a list of BaseMessages.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
List of BaseMessages.
|
||||
"""
|
||||
return self.format_messages(**kwargs)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_variables(self) -> List[str]:
|
||||
@@ -594,6 +605,10 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format kwargs into a list of messages."""
|
||||
|
||||
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format kwargs into a list of messages."""
|
||||
return self.format_messages(**kwargs)
|
||||
|
||||
def pretty_repr(self, html: bool = False) -> str:
|
||||
"""Human-readable representation."""
|
||||
raise NotImplementedError
|
||||
@@ -901,18 +916,6 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
partial_variables=partial_vars,
|
||||
)
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the chat template into a string.
|
||||
|
||||
Args:
|
||||
**kwargs: keyword arguments to use for filling in template variables
|
||||
in all the template messages in this chat template.
|
||||
|
||||
Returns:
|
||||
formatted string
|
||||
"""
|
||||
return self.format_prompt(**kwargs).to_string()
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format the chat template into a list of finalized messages.
|
||||
|
||||
@@ -937,6 +940,30 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
raise ValueError(f"Unexpected input: {message_template}")
|
||||
return result
|
||||
|
||||
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format the chat template into a list of finalized messages.
|
||||
|
||||
Args:
|
||||
**kwargs: keyword arguments to use for filling in template variables
|
||||
in all the template messages in this chat template.
|
||||
|
||||
Returns:
|
||||
list of formatted messages
|
||||
"""
|
||||
kwargs = self._merge_partial_and_user_variables(**kwargs)
|
||||
result = []
|
||||
for message_template in self.messages:
|
||||
if isinstance(message_template, BaseMessage):
|
||||
result.extend([message_template])
|
||||
elif isinstance(
|
||||
message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
|
||||
):
|
||||
message = await message_template.aformat_messages(**kwargs)
|
||||
result.extend(message)
|
||||
else:
|
||||
raise ValueError(f"Unexpected input: {message_template}")
|
||||
return result
|
||||
|
||||
def partial(self, **kwargs: Any) -> ChatPromptTemplate:
|
||||
"""Get a new ChatPromptTemplate with some input variables already filled in.
|
||||
|
||||
|
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from langchain_core.example_selectors import BaseExampleSelector
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.prompts.chat import (
|
||||
BaseChatPromptTemplate,
|
||||
@@ -27,7 +28,7 @@ class _FewShotPromptTemplateMixin(BaseModel):
|
||||
"""Examples to format into the prompt.
|
||||
Either this or example_selector should be provided."""
|
||||
|
||||
example_selector: Any = None
|
||||
example_selector: Optional[BaseExampleSelector] = None
|
||||
"""ExampleSelector to choose the examples to format into the prompt.
|
||||
Either this or examples should be provided."""
|
||||
|
||||
@@ -72,6 +73,24 @@ class _FewShotPromptTemplateMixin(BaseModel):
|
||||
"One of 'examples' and 'example_selector' should be provided"
|
||||
)
|
||||
|
||||
async def _aget_examples(self, **kwargs: Any) -> List[dict]:
|
||||
"""Get the examples to use for formatting the prompt.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to be passed to the example selector.
|
||||
|
||||
Returns:
|
||||
List of examples.
|
||||
"""
|
||||
if self.examples is not None:
|
||||
return self.examples
|
||||
elif self.example_selector is not None:
|
||||
return await self.example_selector.aselect_examples(kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
"One of 'examples' and 'example_selector' should be provided"
|
||||
)
|
||||
|
||||
|
||||
class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
||||
"""Prompt template that contains few shot examples."""
|
||||
@@ -325,6 +344,28 @@ class FewShotChatMessagePromptTemplate(
|
||||
]
|
||||
return messages
|
||||
|
||||
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format kwargs into a list of messages.
|
||||
|
||||
Args:
|
||||
**kwargs: keyword arguments to use for filling in templates in messages.
|
||||
|
||||
Returns:
|
||||
A list of formatted messages with all template variables filled in.
|
||||
"""
|
||||
# Get the examples to use.
|
||||
examples = await self._aget_examples(**kwargs)
|
||||
examples = [
|
||||
{k: e[k] for k in self.example_prompt.input_variables} for e in examples
|
||||
]
|
||||
# Format the examples.
|
||||
messages = [
|
||||
message
|
||||
for example in examples
|
||||
for message in await self.example_prompt.aformat_messages(**example)
|
||||
]
|
||||
return messages
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the prompt with inputs generating a string.
|
||||
|
||||
|
Reference in New Issue
Block a user