core[minor]: Implement aformat_prompt and ainvoke in BasePromptTemplate (#20035)

This commit is contained in:
Christophe Bornet
2024-04-05 16:36:43 +02:00
committed by GitHub
parent 7e5c1905b1
commit 4d8a6a27a3
8 changed files with 113 additions and 18 deletions

View File

@@ -87,7 +87,7 @@ class BasePromptTemplate(
**{k: (self.input_types.get(k, str), None) for k in self.input_variables},
)
def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
def _validate_input(self, inner_input: Dict) -> Dict:
if not isinstance(inner_input, dict):
if len(self.input_variables) == 1:
var_name = self.input_variables[0]
@@ -105,7 +105,17 @@ class BasePromptTemplate(
f" Expected: {self.input_variables}"
f" Received: {list(inner_input.keys())}"
)
return self.format_prompt(**inner_input)
return inner_input
def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
_inner_input = self._validate_input(inner_input)
return self.format_prompt(**_inner_input)
async def _aformat_prompt_with_error_handling(
self, inner_input: Dict
) -> PromptValue:
_inner_input = self._validate_input(inner_input)
return await self.aformat_prompt(**_inner_input)
def invoke(
self, input: Dict, config: Optional[RunnableConfig] = None
@@ -122,10 +132,29 @@ class BasePromptTemplate(
run_type="prompt",
)
async def ainvoke(
self, input: Dict, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> PromptValue:
config = ensure_config(config)
if self.metadata:
config["metadata"].update(self.metadata)
if self.tags:
config["tags"].extend(self.tags)
return await self._acall_with_config(
self._aformat_prompt_with_error_handling,
input,
config,
run_type="prompt",
)
@abstractmethod
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Prompt Value."""
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Prompt Value."""
return self.format_prompt(**kwargs)
@root_validator()
def validate_variable_names(cls, values: Dict) -> Dict:
"""Validate variable names do not include restricted names."""

View File

@@ -609,6 +609,18 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
"""
return self.format_prompt(**kwargs).to_string()
async def aformat(self, **kwargs: Any) -> str:
"""Format the chat template into a string.
Args:
**kwargs: keyword arguments to use for filling in template variables
in all the template messages in this chat template.
Returns:
formatted string
"""
return (await self.aformat_prompt(**kwargs)).to_string()
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""
Format prompt. Should return a PromptValue.
@@ -621,6 +633,10 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
messages = self.format_messages(**kwargs)
return ChatPromptValue(messages=messages)
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
messages = await self.aformat_messages(**kwargs)
return ChatPromptValue(messages=messages)
@abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format kwargs into a list of messages."""

View File

@@ -37,9 +37,11 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
return ["langchain", "prompts", "image"]
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
return ImagePromptValue(image_url=self.format(**kwargs))
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
return ImagePromptValue(image_url=await self.aformat(**kwargs))
def format(
self,
**kwargs: Any,

View File

@@ -54,9 +54,22 @@ class PipelinePromptTemplate(BasePromptTemplate):
_inputs = _get_inputs(kwargs, self.final_prompt.input_variables)
return self.final_prompt.format_prompt(**_inputs)
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
for k, prompt in self.pipeline_prompts:
_inputs = _get_inputs(kwargs, prompt.input_variables)
if isinstance(prompt, BaseChatPromptTemplate):
kwargs[k] = await prompt.aformat_messages(**_inputs)
else:
kwargs[k] = await prompt.aformat(**_inputs)
_inputs = _get_inputs(kwargs, self.final_prompt.input_variables)
return await self.final_prompt.aformat_prompt(**_inputs)
def format(self, **kwargs: Any) -> str:
return self.format_prompt(**kwargs).to_string()
async def aformat(self, **kwargs: Any) -> str:
return (await self.aformat_prompt(**kwargs)).to_string()
@property
def _prompt_type(self) -> str:
raise ValueError

View File

@@ -160,9 +160,11 @@ class StringPromptTemplate(BasePromptTemplate, ABC):
return ["langchain", "prompts", "base"]
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
return StringPromptValue(text=self.format(**kwargs))
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
return StringPromptValue(text=await self.aformat(**kwargs))
def pretty_repr(self, html: bool = False) -> str:
# TODO: handle partials
dummy_vars = {