diff --git a/docs/docs/modules/chains/document/map_reduce.ipynb b/docs/docs/modules/chains/document/map_reduce.ipynb index 218c844b143..46d58e2e624 100644 --- a/docs/docs/modules/chains/document/map_reduce.ipynb +++ b/docs/docs/modules/chains/document/map_reduce.ipynb @@ -32,7 +32,7 @@ "from functools import partial\n", "\n", "from langchain.callbacks.manager import CallbackManagerForChainRun\n", - "from langchain.chains.combine_documents.reduce import _collapse_docs, _split_list_of_docs\n", + "from langchain.chains.combine_documents import collapse_docs, split_list_of_docs\n", "from langchain.chat_models import ChatAnthropic\n", "from langchain.prompts import PromptTemplate\n", "from langchain.schema import StrOutputParser\n", @@ -109,8 +109,8 @@ " while get_num_tokens(docs) > token_max:\n", " config[\"run_name\"] = f\"Collapse {collapse_ct}\"\n", " invoke = partial(collapse_chain.invoke, config=config)\n", - " split_docs = _split_list_of_docs(docs, get_num_tokens, token_max)\n", - " docs = [_collapse_docs(_docs, invoke) for _docs in split_docs]\n", + " split_docs = split_list_of_docs(docs, get_num_tokens, token_max)\n", + " docs = [collapse_docs(_docs, invoke) for _docs in split_docs]\n", " collapse_ct += 1\n", " return docs" ] diff --git a/libs/langchain/langchain/chains/combine_documents/__init__.py b/libs/langchain/langchain/chains/combine_documents/__init__.py index f22b1ccbe84..9c66d934324 100644 --- a/libs/langchain/langchain/chains/combine_documents/__init__.py +++ b/libs/langchain/langchain/chains/combine_documents/__init__.py @@ -1 +1,9 @@ """Different ways to combine documents.""" + +from langchain.chains.combine_documents.reduce import ( + acollapse_docs, + collapse_docs, + split_list_of_docs, +) + +__all__ = ["acollapse_docs", "collapse_docs", "split_list_of_docs"] diff --git a/libs/langchain/langchain/chains/combine_documents/reduce.py b/libs/langchain/langchain/chains/combine_documents/reduce.py index 9ad2415b12a..95081704594 100644 --- 
a/libs/langchain/langchain/chains/combine_documents/reduce.py +++ b/libs/langchain/langchain/chains/combine_documents/reduce.py @@ -24,9 +24,21 @@ class AsyncCombineDocsProtocol(Protocol): """Async interface for the combine_docs method.""" -def _split_list_of_docs( +def split_list_of_docs( docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any ) -> List[List[Document]]: + """Split Documents into subsets that each meet a cumulative length constraint. + + Args: + docs: The full list of Documents. + length_func: Function for computing the cumulative length of a set of Documents. + token_max: The maximum cumulative length of any subset of Documents. + **kwargs: Arbitrary additional keyword params to pass to each call of the + length_func. + + Returns: + A List[List[Document]]. + """ new_result_doc_list = [] _sub_result_docs = [] for doc in docs: @@ -44,11 +56,27 @@ def _split_list_of_docs( return new_result_doc_list -def _collapse_docs( +def collapse_docs( docs: List[Document], combine_document_func: CombineDocsProtocol, **kwargs: Any, ) -> Document: + """Execute a collapse function on a set of documents and merge their metadatas. + + Args: + docs: A list of Documents to combine. + combine_document_func: A function that takes in a list of Documents and + optionally additional keyword parameters and combines them into a single + string. + **kwargs: Arbitrary additional keyword params to pass to the + combine_document_func. + + Returns: + A single Document with the output of combine_document_func for the page content + and the combined metadata of all the input documents. All metadata values + are strings, and where there are overlapping keys across documents the + values are joined by ", ". 
+ """ result = combine_document_func(docs, **kwargs) combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()} for doc in docs[1:]: @@ -60,11 +88,27 @@ def _collapse_docs( return Document(page_content=result, metadata=combined_metadata) -async def _acollapse_docs( +async def acollapse_docs( docs: List[Document], combine_document_func: AsyncCombineDocsProtocol, **kwargs: Any, ) -> Document: + """Execute a collapse function on a set of documents and merge their metadatas. + + Args: + docs: A list of Documents to combine. + combine_document_func: A function that takes in a list of Documents and + optionally additional keyword parameters and combines them into a single + string. + **kwargs: Arbitrary additional keyword params to pass to the + combine_document_func. + + Returns: + A single Document with the output of combine_document_func for the page content + and the combined metadata of all the input documents. All metadata values + are strings, and where there are overlapping keys across documents the + values are joined by ", ". 
+ """ result = await combine_document_func(docs, **kwargs) combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()} for doc in docs[1:]: @@ -245,12 +289,12 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain): _token_max = token_max or self.token_max while num_tokens is not None and num_tokens > _token_max: - new_result_doc_list = _split_list_of_docs( + new_result_doc_list = split_list_of_docs( result_docs, length_func, _token_max, **kwargs ) result_docs = [] for docs in new_result_doc_list: - new_doc = _collapse_docs(docs, _collapse_docs_func, **kwargs) + new_doc = collapse_docs(docs, _collapse_docs_func, **kwargs) result_docs.append(new_doc) num_tokens = length_func(result_docs, **kwargs) return result_docs, {} @@ -273,12 +317,12 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain): _token_max = token_max or self.token_max while num_tokens is not None and num_tokens > _token_max: - new_result_doc_list = _split_list_of_docs( + new_result_doc_list = split_list_of_docs( result_docs, length_func, _token_max, **kwargs ) result_docs = [] for docs in new_result_doc_list: - new_doc = await _acollapse_docs(docs, _collapse_docs_func, **kwargs) + new_doc = await acollapse_docs(docs, _collapse_docs_func, **kwargs) result_docs.append(new_doc) num_tokens = length_func(result_docs, **kwargs) return result_docs, {} diff --git a/libs/langchain/langchain/schema/prompt_template.py b/libs/langchain/langchain/schema/prompt_template.py index 58d7a386603..224e579bcc1 100644 --- a/libs/langchain/langchain/schema/prompt_template.py +++ b/libs/langchain/langchain/schema/prompt_template.py @@ -205,6 +205,7 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str: from langchain.schema import Document from langchain.prompts import PromptTemplate + doc = Document(page_content="This is a joke", metadata={"page": "1"}) prompt = PromptTemplate.from_template("Page {page}: {page_content}") format_document(doc, prompt) diff --git 
a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py index 7acbdeba631..9bd9baf5211 100644 --- a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py +++ b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py @@ -5,8 +5,8 @@ from typing import Any, List import pytest from langchain.chains.combine_documents.reduce import ( - _collapse_docs, - _split_list_of_docs, + collapse_docs, + split_list_of_docs, ) from langchain.chains.qa_with_sources import load_qa_with_sources_chain from langchain.docstore.document import Document @@ -32,20 +32,20 @@ def test__split_list_long_single_doc() -> None: """Test splitting of a long single doc.""" docs = [Document(page_content="foo" * 100)] with pytest.raises(ValueError): - _split_list_of_docs(docs, _fake_docs_len_func, 100) + split_list_of_docs(docs, _fake_docs_len_func, 100) def test__split_list_single_doc() -> None: """Test splitting works with just a single doc.""" docs = [Document(page_content="foo")] - doc_list = _split_list_of_docs(docs, _fake_docs_len_func, 100) + doc_list = split_list_of_docs(docs, _fake_docs_len_func, 100) assert doc_list == [docs] def test__split_list_double_doc() -> None: """Test splitting works with just two docs.""" docs = [Document(page_content="foo"), Document(page_content="bar")] - doc_list = _split_list_of_docs(docs, _fake_docs_len_func, 100) + doc_list = split_list_of_docs(docs, _fake_docs_len_func, 100) assert doc_list == [docs] @@ -59,7 +59,7 @@ def test__split_list_works_correctly() -> None: Document(page_content="bar"), Document(page_content="baz"), ] - doc_list = _split_list_of_docs(docs, _fake_docs_len_func, 10) + doc_list = split_list_of_docs(docs, _fake_docs_len_func, 10) expected_result = [ # Test a group of three. 
[ @@ -82,7 +82,7 @@ def test__collapse_docs_no_metadata() -> None: Document(page_content="bar"), Document(page_content="baz"), ] - output = _collapse_docs(docs, _fake_combine_docs_func) + output = collapse_docs(docs, _fake_combine_docs_func) expected_output = Document(page_content="foobarbaz") assert output == expected_output @@ -91,12 +91,12 @@ def test__collapse_docs_one_doc() -> None: """Test collapse documents functionality when only one document present.""" # Test with no metadata. docs = [Document(page_content="foo")] - output = _collapse_docs(docs, _fake_combine_docs_func) + output = collapse_docs(docs, _fake_combine_docs_func) assert output == docs[0] # Test with metadata. docs = [Document(page_content="foo", metadata={"source": "a"})] - output = _collapse_docs(docs, _fake_combine_docs_func) + output = collapse_docs(docs, _fake_combine_docs_func) assert output == docs[0] @@ -108,7 +108,7 @@ def test__collapse_docs_metadata() -> None: Document(page_content="foo", metadata=metadata1), Document(page_content="bar", metadata=metadata2), ] - output = _collapse_docs(docs, _fake_combine_docs_func) + output = collapse_docs(docs, _fake_combine_docs_func) expected_metadata = { "source": "a, b", "foo": "2, 3",