mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-22 11:00:37 +00:00
core[minor]: Implement aformat_prompt and ainvoke in BasePromptTemplate (#20035)
This commit is contained in:
committed by
GitHub
parent
7e5c1905b1
commit
4d8a6a27a3
@@ -87,7 +87,7 @@ class BasePromptTemplate(
|
||||
**{k: (self.input_types.get(k, str), None) for k in self.input_variables},
|
||||
)
|
||||
|
||||
def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
|
||||
def _validate_input(self, inner_input: Dict) -> Dict:
|
||||
if not isinstance(inner_input, dict):
|
||||
if len(self.input_variables) == 1:
|
||||
var_name = self.input_variables[0]
|
||||
@@ -105,7 +105,17 @@ class BasePromptTemplate(
|
||||
f" Expected: {self.input_variables}"
|
||||
f" Received: {list(inner_input.keys())}"
|
||||
)
|
||||
return self.format_prompt(**inner_input)
|
||||
return inner_input
|
||||
|
||||
def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
|
||||
_inner_input = self._validate_input(inner_input)
|
||||
return self.format_prompt(**_inner_input)
|
||||
|
||||
async def _aformat_prompt_with_error_handling(
|
||||
self, inner_input: Dict
|
||||
) -> PromptValue:
|
||||
_inner_input = self._validate_input(inner_input)
|
||||
return await self.aformat_prompt(**_inner_input)
|
||||
|
||||
def invoke(
|
||||
self, input: Dict, config: Optional[RunnableConfig] = None
|
||||
@@ -122,10 +132,29 @@ class BasePromptTemplate(
|
||||
run_type="prompt",
|
||||
)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Dict, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> PromptValue:
|
||||
config = ensure_config(config)
|
||||
if self.metadata:
|
||||
config["metadata"].update(self.metadata)
|
||||
if self.tags:
|
||||
config["tags"].extend(self.tags)
|
||||
return await self._acall_with_config(
|
||||
self._aformat_prompt_with_error_handling,
|
||||
input,
|
||||
config,
|
||||
run_type="prompt",
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
"""Create Prompt Value."""
|
||||
|
||||
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
"""Create Prompt Value."""
|
||||
return self.format_prompt(**kwargs)
|
||||
|
||||
@root_validator()
|
||||
def validate_variable_names(cls, values: Dict) -> Dict:
|
||||
"""Validate variable names do not include restricted names."""
|
||||
|
@@ -609,6 +609,18 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||
"""
|
||||
return self.format_prompt(**kwargs).to_string()
|
||||
|
||||
async def aformat(self, **kwargs: Any) -> str:
|
||||
"""Format the chat template into a string.
|
||||
|
||||
Args:
|
||||
**kwargs: keyword arguments to use for filling in template variables
|
||||
in all the template messages in this chat template.
|
||||
|
||||
Returns:
|
||||
formatted string
|
||||
"""
|
||||
return (await self.aformat_prompt(**kwargs)).to_string()
|
||||
|
||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
"""
|
||||
Format prompt. Should return a PromptValue.
|
||||
@@ -621,6 +633,10 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||
messages = self.format_messages(**kwargs)
|
||||
return ChatPromptValue(messages=messages)
|
||||
|
||||
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
messages = await self.aformat_messages(**kwargs)
|
||||
return ChatPromptValue(messages=messages)
|
||||
|
||||
@abstractmethod
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format kwargs into a list of messages."""
|
||||
|
@@ -37,9 +37,11 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
|
||||
return ["langchain", "prompts", "image"]
|
||||
|
||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
"""Create Chat Messages."""
|
||||
return ImagePromptValue(image_url=self.format(**kwargs))
|
||||
|
||||
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
return ImagePromptValue(image_url=await self.aformat(**kwargs))
|
||||
|
||||
def format(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
|
@@ -54,9 +54,22 @@ class PipelinePromptTemplate(BasePromptTemplate):
|
||||
_inputs = _get_inputs(kwargs, self.final_prompt.input_variables)
|
||||
return self.final_prompt.format_prompt(**_inputs)
|
||||
|
||||
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
for k, prompt in self.pipeline_prompts:
|
||||
_inputs = _get_inputs(kwargs, prompt.input_variables)
|
||||
if isinstance(prompt, BaseChatPromptTemplate):
|
||||
kwargs[k] = await prompt.aformat_messages(**_inputs)
|
||||
else:
|
||||
kwargs[k] = await prompt.aformat(**_inputs)
|
||||
_inputs = _get_inputs(kwargs, self.final_prompt.input_variables)
|
||||
return await self.final_prompt.aformat_prompt(**_inputs)
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
return self.format_prompt(**kwargs).to_string()
|
||||
|
||||
async def aformat(self, **kwargs: Any) -> str:
|
||||
return (await self.aformat_prompt(**kwargs)).to_string()
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
raise ValueError
|
||||
|
@@ -160,9 +160,11 @@ class StringPromptTemplate(BasePromptTemplate, ABC):
|
||||
return ["langchain", "prompts", "base"]
|
||||
|
||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
"""Create Chat Messages."""
|
||||
return StringPromptValue(text=self.format(**kwargs))
|
||||
|
||||
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
return StringPromptValue(text=await self.aformat(**kwargs))
|
||||
|
||||
def pretty_repr(self, html: bool = False) -> str:
|
||||
# TODO: handle partials
|
||||
dummy_vars = {
|
||||
|
Reference in New Issue
Block a user