diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index da2a351cac8..381ce854d3e 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -87,7 +87,7 @@ class BasePromptTemplate( **{k: (self.input_types.get(k, str), None) for k in self.input_variables}, ) - def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue: + def _validate_input(self, inner_input: Dict) -> Dict: if not isinstance(inner_input, dict): if len(self.input_variables) == 1: var_name = self.input_variables[0] @@ -105,7 +105,17 @@ class BasePromptTemplate( f" Expected: {self.input_variables}" f" Received: {list(inner_input.keys())}" ) - return self.format_prompt(**inner_input) + return inner_input + + def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue: + _inner_input = self._validate_input(inner_input) + return self.format_prompt(**_inner_input) + + async def _aformat_prompt_with_error_handling( + self, inner_input: Dict + ) -> PromptValue: + _inner_input = self._validate_input(inner_input) + return await self.aformat_prompt(**_inner_input) def invoke( self, input: Dict, config: Optional[RunnableConfig] = None @@ -122,10 +132,29 @@ class BasePromptTemplate( run_type="prompt", ) + async def ainvoke( + self, input: Dict, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> PromptValue: + config = ensure_config(config) + if self.metadata: + config["metadata"].update(self.metadata) + if self.tags: + config["tags"].extend(self.tags) + return await self._acall_with_config( + self._aformat_prompt_with_error_handling, + input, + config, + run_type="prompt", + ) + @abstractmethod def format_prompt(self, **kwargs: Any) -> PromptValue: """Create Prompt Value.""" + async def aformat_prompt(self, **kwargs: Any) -> PromptValue: + """Create Prompt Value.""" + return self.format_prompt(**kwargs) + @root_validator() def validate_variable_names(cls, values: Dict) -> Dict: """Validate variable names do not include restricted names.""" diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 0df0a4bb146..d03461e7ffa 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -609,6 +609,18 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC): """ return self.format_prompt(**kwargs).to_string() + async def aformat(self, **kwargs: Any) -> str: + """Format the chat template into a string. + + Args: + **kwargs: keyword arguments to use for filling in template variables + in all the template messages in this chat template. + + Returns: + formatted string + """ + return (await self.aformat_prompt(**kwargs)).to_string() + def format_prompt(self, **kwargs: Any) -> PromptValue: """ Format prompt. Should return a PromptValue. @@ -621,6 +633,10 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC): messages = self.format_messages(**kwargs) return ChatPromptValue(messages=messages) + async def aformat_prompt(self, **kwargs: Any) -> PromptValue: + messages = await self.aformat_messages(**kwargs) + return ChatPromptValue(messages=messages) + @abstractmethod def format_messages(self, **kwargs: Any) -> List[BaseMessage]: """Format kwargs into a list of messages.""" diff --git a/libs/core/langchain_core/prompts/image.py b/libs/core/langchain_core/prompts/image.py index 477c30e8120..d4f47779a3a 100644 --- a/libs/core/langchain_core/prompts/image.py +++ b/libs/core/langchain_core/prompts/image.py @@ -37,9 +37,11 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]): return ["langchain", "prompts", "image"] def format_prompt(self, **kwargs: Any) -> PromptValue: - """Create Chat Messages.""" return ImagePromptValue(image_url=self.format(**kwargs)) + async def aformat_prompt(self, **kwargs: Any) -> PromptValue: + return ImagePromptValue(image_url=await self.aformat(**kwargs)) + def format( self, **kwargs: Any, diff --git a/libs/core/langchain_core/prompts/pipeline.py b/libs/core/langchain_core/prompts/pipeline.py index 5c0cc00402f..f89c341d2f3 100644 --- a/libs/core/langchain_core/prompts/pipeline.py +++ b/libs/core/langchain_core/prompts/pipeline.py @@ -54,9 +54,22 @@ class PipelinePromptTemplate(BasePromptTemplate): _inputs = _get_inputs(kwargs, self.final_prompt.input_variables) return self.final_prompt.format_prompt(**_inputs) + async def aformat_prompt(self, **kwargs: Any) -> PromptValue: + for k, prompt in self.pipeline_prompts: + _inputs = _get_inputs(kwargs, prompt.input_variables) + if isinstance(prompt, BaseChatPromptTemplate): + kwargs[k] = await prompt.aformat_messages(**_inputs) + else: + kwargs[k] = await prompt.aformat(**_inputs) + _inputs = _get_inputs(kwargs, self.final_prompt.input_variables) + return await self.final_prompt.aformat_prompt(**_inputs) + def format(self, **kwargs: Any) -> str: return self.format_prompt(**kwargs).to_string() + async def aformat(self, **kwargs: Any) -> str: + return (await self.aformat_prompt(**kwargs)).to_string() + @property def _prompt_type(self) -> str: raise ValueError diff --git a/libs/core/langchain_core/prompts/string.py b/libs/core/langchain_core/prompts/string.py index d95a3ce9d2f..b324871da56 100644 --- a/libs/core/langchain_core/prompts/string.py +++ b/libs/core/langchain_core/prompts/string.py @@ -160,9 +160,11 @@ class StringPromptTemplate(BasePromptTemplate, ABC): return ["langchain", "prompts", "base"] def format_prompt(self, **kwargs: Any) -> PromptValue: - """Create Chat Messages.""" return StringPromptValue(text=self.format(**kwargs)) + async def aformat_prompt(self, **kwargs: Any) -> PromptValue: + return StringPromptValue(text=await self.aformat(**kwargs)) + def pretty_repr(self, html: bool = False) -> str: # TODO: handle partials dummy_vars = { diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 305e981dfe9..152c612fe1d 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -28,7 +28,8 @@ from langchain_core.prompts.chat import ( ) -def create_messages() -> List[BaseMessagePromptTemplate]: +@pytest.fixture +def messages() -> List[BaseMessagePromptTemplate]: """Create messages.""" system_message_prompt = SystemMessagePromptTemplate( prompt=PromptTemplate( @@ -63,11 +64,14 @@ def create_messages() -> List[BaseMessagePromptTemplate]: ] -def create_chat_prompt_template() -> ChatPromptTemplate: +@pytest.fixture +def chat_prompt_template( + messages: List[BaseMessagePromptTemplate], +) -> ChatPromptTemplate: """Create a chat prompt template.""" return ChatPromptTemplate( input_variables=["foo", "bar", "context"], - messages=create_messages(), # type: ignore[arg-type] + messages=messages, # type: ignore[arg-type] ) @@ -110,10 +114,9 @@ def test_message_prompt_template_from_template_file() -> None: assert expected == actual -def test_chat_prompt_template() -> None: +async def test_chat_prompt_template(chat_prompt_template: ChatPromptTemplate) -> None: """Test chat prompt template.""" - prompt_template = create_chat_prompt_template() - prompt = prompt_template.format_prompt(foo="foo", bar="bar", context="context") + prompt = chat_prompt_template.format_prompt(foo="foo", bar="bar", context="context") assert isinstance(prompt, ChatPromptValue) messages = prompt.to_messages() assert len(messages) == 4 @@ -122,6 +125,12 @@ def test_chat_prompt_template() -> None: assert messages[2].content == "I'm an AI. I'm foo. I'm bar." assert messages[3].content == "I'm a generic message. I'm foo. I'm bar." + async_prompt = await chat_prompt_template.aformat_prompt( + foo="foo", bar="bar", context="context" + ) + + assert async_prompt.to_messages() == messages + string = prompt.to_string() expected = ( "System: Here's some context: context\n" @@ -131,13 +140,15 @@ def test_chat_prompt_template() -> None: ) assert string == expected - string = prompt_template.format(foo="foo", bar="bar", context="context") + string = chat_prompt_template.format(foo="foo", bar="bar", context="context") assert string == expected -def test_chat_prompt_template_from_messages() -> None: +def test_chat_prompt_template_from_messages( + messages: List[BaseMessagePromptTemplate], +) -> None: """Test creating a chat prompt template from messages.""" - chat_prompt_template = ChatPromptTemplate.from_messages(create_messages()) + chat_prompt_template = ChatPromptTemplate.from_messages(messages) assert sorted(chat_prompt_template.input_variables) == sorted( ["context", "foo", "bar"] ) @@ -171,11 +182,12 @@ def test_chat_prompt_template_from_messages_using_role_strings() -> None: ] -def test_chat_prompt_template_with_messages() -> None: - messages: List[Union[BaseMessagePromptTemplate, BaseMessage]] = ( - create_messages() + [HumanMessage(content="foo")] +def test_chat_prompt_template_with_messages( + messages: List[BaseMessagePromptTemplate], +) -> None: + chat_prompt_template = ChatPromptTemplate.from_messages( + messages + [HumanMessage(content="foo")] ) - chat_prompt_template = ChatPromptTemplate.from_messages(messages) assert sorted(chat_prompt_template.input_variables) == sorted( ["context", "foo", "bar"] ) diff --git a/libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py b/libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py index ece234a9d5b..f62af4c7577 100644 --- a/libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_pipeline_prompt.py @@ -32,7 +32,7 @@ def test_multi_variable_pipeline() -> None: assert output == "okay jim deep" -def test_partial_with_chat_prompts() -> None: +async def test_partial_with_chat_prompts() -> None: prompt_a = ChatPromptTemplate( input_variables=["foo"], messages=[MessagesPlaceholder(variable_name="foo")] ) @@ -43,3 +43,5 @@ def test_partial_with_chat_prompts() -> None: assert pipeline_prompt.input_variables == ["bar"] output = pipeline_prompt.format_prompt(bar="okay") assert output.to_messages()[0].content == "jim okay" + output = await pipeline_prompt.aformat_prompt(bar="okay") + assert output.to_messages()[0].content == "jim okay" diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index 476666fa7cd..2a62872446d 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -351,3 +351,22 @@ def test_prompt_invoke_with_metadata() -> None: assert len(tracer.traced_runs) == 1 assert tracer.traced_runs[0].extra["metadata"] == {"version": "1", "foo": "bar"} # type: ignore assert tracer.traced_runs[0].tags == ["tag1", "tag2"] # type: ignore + + +async def test_prompt_ainvoke_with_metadata() -> None: + """Test prompt can be invoked with metadata.""" + template = "This is a {foo} test." + prompt = PromptTemplate( + input_variables=["foo"], + template=template, + metadata={"version": "1"}, + tags=["tag1", "tag2"], + ) + tracer = RunCollectorCallbackHandler() + result = await prompt.ainvoke( + {"foo": "bar"}, {"metadata": {"foo": "bar"}, "callbacks": [tracer]} + ) + assert result.to_string() == "This is a bar test." + assert len(tracer.traced_runs) == 1 + assert tracer.traced_runs[0].extra["metadata"] == {"version": "1", "foo": "bar"} # type: ignore + assert tracer.traced_runs[0].tags == ["tag1", "tag2"] # type: ignore