From 7e5c1905b148ba86e06c616ca85ad9b038902f53 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Fri, 5 Apr 2024 16:29:53 +0200 Subject: [PATCH] core[minor]: Add async aformat_document method (#20037) --- libs/core/langchain_core/prompts/__init__.py | 7 ++- libs/core/langchain_core/prompts/base.py | 55 ++++++++++++++----- .../tests/unit_tests/prompts/test_imports.py | 1 + libs/langchain/langchain/tools/retriever.py | 9 ++- .../chains/test_combine_documents.py | 10 +++- 5 files changed, 63 insertions(+), 19 deletions(-) diff --git a/libs/core/langchain_core/prompts/__init__.py b/libs/core/langchain_core/prompts/__init__.py index 05827d9224f..1c9545dca17 100644 --- a/libs/core/langchain_core/prompts/__init__.py +++ b/libs/core/langchain_core/prompts/__init__.py @@ -24,7 +24,11 @@ from multiple components and prompt values. Prompt classes and functions make co SystemMessagePromptTemplate """ # noqa: E501 -from langchain_core.prompts.base import BasePromptTemplate, format_document +from langchain_core.prompts.base import ( + BasePromptTemplate, + aformat_document, + format_document, +) from langchain_core.prompts.chat import ( AIMessagePromptTemplate, BaseChatPromptTemplate, @@ -67,6 +71,7 @@ __all__ = [ "SystemMessagePromptTemplate", "load_prompt", "format_document", + "aformat_document", "check_valid_template", "get_template_variables", "jinja2_formatter", diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index b7275f7181a..da2a351cac8 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -251,6 +251,21 @@ class BasePromptTemplate( raise ValueError(f"{save_path} must be json or yaml") +def _get_document_info(doc: Document, prompt: BasePromptTemplate[str]) -> Dict: + base_info = {"page_content": doc.page_content, **doc.metadata} + missing_metadata = set(prompt.input_variables).difference(base_info) + if len(missing_metadata) > 0: + required_metadata = [ + iv for iv in prompt.input_variables if iv != "page_content" + ] + raise ValueError( + f"Document prompt requires documents to have metadata variables: " + f"{required_metadata}. Received document with missing metadata: " + f"{list(missing_metadata)}." + ) + return {k: base_info[k] for k in prompt.input_variables} + + def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str: """Format a document into a string based on a prompt template. @@ -285,16 +300,30 @@ def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str: format_document(doc, prompt) >>> "Page 1: This is a joke" """ - base_info = {"page_content": doc.page_content, **doc.metadata} - missing_metadata = set(prompt.input_variables).difference(base_info) - if len(missing_metadata) > 0: - required_metadata = [ - iv for iv in prompt.input_variables if iv != "page_content" - ] - raise ValueError( - f"Document prompt requires documents to have metadata variables: " - f"{required_metadata}. Received document with missing metadata: " - f"{list(missing_metadata)}." - ) - document_info = {k: base_info[k] for k in prompt.input_variables} - return prompt.format(**document_info) + return prompt.format(**_get_document_info(doc, prompt)) + + +async def aformat_document(doc: Document, prompt: BasePromptTemplate[str]) -> str: + """Format a document into a string based on a prompt template. + + First, this pulls information from the document from two sources: + + 1. `page_content`: + This takes the information from the `document.page_content` + and assigns it to a variable named `page_content`. + 2. metadata: + This takes information from `document.metadata` and assigns + it to variables of the same name. + + Those variables are then passed into the `prompt` to produce a formatted string. + + Args: + doc: Document, the page_content and metadata will be used to create + the final string. + prompt: BasePromptTemplate, will be used to format the page_content + and metadata into the final string. + + Returns: + string of the document formatted. + """ + return await prompt.aformat(**_get_document_info(doc, prompt)) diff --git a/libs/core/tests/unit_tests/prompts/test_imports.py b/libs/core/tests/unit_tests/prompts/test_imports.py index a9cacf91570..a3a43f8957b 100644 --- a/libs/core/tests/unit_tests/prompts/test_imports.py +++ b/libs/core/tests/unit_tests/prompts/test_imports.py @@ -10,6 +10,7 @@ EXPECTED_ALL = [ "FewShotPromptWithTemplates", "FewShotChatMessagePromptTemplate", "format_document", + "aformat_document", "HumanMessagePromptTemplate", "MessagesPlaceholder", "PipelinePromptTemplate", diff --git a/libs/langchain/langchain/tools/retriever.py b/libs/langchain/langchain/tools/retriever.py index b999c74b6cb..5feeab6e04b 100644 --- a/libs/langchain/langchain/tools/retriever.py +++ b/libs/langchain/langchain/tools/retriever.py @@ -4,7 +4,12 @@ from typing import Optional from langchain_core.callbacks.manager import ( Callbacks, ) -from langchain_core.prompts import BasePromptTemplate, PromptTemplate, format_document +from langchain_core.prompts import ( + BasePromptTemplate, + PromptTemplate, + aformat_document, + format_document, +) from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.retrievers import BaseRetriever @@ -39,7 +44,7 @@ async def _aget_relevant_documents( ) -> str: docs = await retriever.aget_relevant_documents(query, callbacks=callbacks) return document_separator.join( - format_document(doc, document_prompt) for doc in docs + [await aformat_document(doc, document_prompt) for doc in docs] ) diff --git a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py index 4dbb4b745ee..8de556bb8b9 100644 --- a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py +++ b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py @@ -4,7 +4,7 @@ from typing import Any, List import pytest from langchain_core.documents import Document -from langchain_core.prompts import PromptTemplate, format_document +from langchain_core.prompts import PromptTemplate, aformat_document, format_document from langchain.chains.combine_documents.reduce import ( collapse_docs, @@ -119,7 +119,7 @@ def test__collapse_docs_metadata() -> None: assert output == expected_output -def test_format_doc_with_metadata() -> None: +async def test_format_doc_with_metadata() -> None: """Test format doc on a valid document.""" doc = Document(page_content="foo", metadata={"bar": "baz"}) prompt = PromptTemplate( @@ -128,9 +128,11 @@ def test_format_doc_with_metadata() -> None: expected_output = "foo, baz" output = format_document(doc, prompt) assert output == expected_output + output = await aformat_document(doc, prompt) + assert output == expected_output -def test_format_doc_missing_metadata() -> None: +async def test_format_doc_missing_metadata() -> None: """Test format doc on a document with missing metadata.""" doc = Document(page_content="foo") prompt = PromptTemplate( @@ -138,3 +140,5 @@ def test_format_doc_missing_metadata() -> None: ) with pytest.raises(ValueError): format_document(doc, prompt) + with pytest.raises(ValueError): + await aformat_document(doc, prompt)