From f97db8cc7bb85484e026758ca1ba464d7d3e0e9a Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 26 Dec 2022 09:14:40 -0500 Subject: [PATCH] return intermediate steps in combine document chains --- .../data_augmented_generation/summarize.ipynb | 40 +++++++++++++++---- langchain/chains/combine_documents/base.py | 9 +++-- .../chains/combine_documents/map_reduce.py | 25 ++++++++++-- langchain/chains/combine_documents/refine.py | 25 ++++++++++-- langchain/chains/combine_documents/stuff.py | 6 +-- langchain/chains/mapreduce.py | 2 +- langchain/chains/qa_with_sources/base.py | 2 +- langchain/chains/vector_db_qa/base.py | 2 +- 8 files changed, 86 insertions(+), 25 deletions(-) diff --git a/docs/examples/data_augmented_generation/summarize.ipynb b/docs/examples/data_augmented_generation/summarize.ipynb index ee6b6572b31..5469c71d892 100644 --- a/docs/examples/data_augmented_generation/summarize.ipynb +++ b/docs/examples/data_augmented_generation/summarize.ipynb @@ -60,7 +60,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "id": "8dff4f43", "metadata": {}, "outputs": [], @@ -70,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "id": "27989fc4", "metadata": {}, "outputs": [], @@ -131,33 +131,57 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "id": "ef28e1d4", "metadata": {}, "outputs": [], "source": [ - "chain = load_summarize_chain(llm, chain_type=\"map_reduce\")" + "chain = load_summarize_chain(llm, chain_type=\"map_reduce\", verbose=True, return_map_steps=True)" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "f82c5f9f", "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new MapReduceDocumentsChain chain...\u001b[0m\n", + "\n", + "\u001b[1m> Finished MapReduceDocumentsChain chain.\u001b[0m\n" + ] + } + ], + "source": [ + "res = chain(docs)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "f5a2b653", + "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "\" In response to Vladimir Putin's aggression in Ukraine, the US and its allies have taken action to hold him accountable, including economic sanctions, cutting off access to technology, and seizing the assets of Russian oligarchs. They are also providing military, economic, and humanitarian assistance to the Ukrainians, and releasing 60 million barrels of oil from reserves around the world. President Biden has passed several laws to provide economic relief to Americans and create jobs, and is making sure taxpayer dollars support American jobs and businesses.\"" + "[{'text': \" In response to Russia's aggression in Ukraine, the United States has united with other freedom-loving nations to impose economic sanctions and hold Putin accountable. The U.S. Department of Justice is also assembling a task force to go after the crimes of Russian oligarchs and seize their ill-gotten gains.\"},\n", + " {'text': ' The United States and its European allies are taking action to punish Russia for its invasion of Ukraine, including seizing assets, closing off airspace, and providing economic and military assistance to Ukraine. The US is also mobilizing forces to protect NATO countries and has released 30 million barrels of oil from its Strategic Petroleum Reserve to help blunt gas prices. The world is uniting in support of Ukraine and democracy, and the US stands with its Ukrainian American citizens.'},\n", + " {'text': ' President Biden and Vice President Harris ran for office with a new economic vision for America, and have since passed the American Rescue Plan and the Bipartisan Infrastructure Law to help working people and rebuild America. These plans will create jobs, modernize roads, airports, ports, and waterways, and provide clean water and high-speed internet for all Americans. The government will also be investing in American products to support American jobs.'}]" ] }, - "execution_count": 9, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "chain.run(docs)" + "res['map_steps']" ] }, { diff --git a/langchain/chains/combine_documents/base.py b/langchain/chains/combine_documents/base.py index 7d5574caaea..944440e94a0 100644 --- a/langchain/chains/combine_documents/base.py +++ b/langchain/chains/combine_documents/base.py @@ -1,7 +1,7 @@ """Base interface for chains combining documents.""" from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from pydantic import BaseModel @@ -39,12 +39,13 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC): return None @abstractmethod - def combine_docs(self, docs: List[Document], **kwargs: Any) -> str: + def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: """Combine documents into a single string.""" def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: docs = inputs[self.input_key] # Other keys are assumed to be needed for LLM prediction other_keys = {k: v for k, v in inputs.items() if k != self.input_key} - output = self.combine_docs(docs, **other_keys) - return {self.output_key: output} + output, extra_return_dict = self.combine_docs(docs, **other_keys) + extra_return_dict[self.output_key] = output + return extra_return_dict diff --git a/langchain/chains/combine_documents/map_reduce.py b/langchain/chains/combine_documents/map_reduce.py index 8653ef1a7f8..dd192062f3d 100644 --- a/langchain/chains/combine_documents/map_reduce.py +++ b/langchain/chains/combine_documents/map_reduce.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple from pydantic import BaseModel, Extra, root_validator @@ -65,6 +65,19 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel): document_variable_name: str """The variable name in the llm_chain to put the documents in. If only one variable in the llm_chain, this need not be provided.""" + return_map_steps: bool = False + """Return the results of the map steps in the output.""" + + @property + def output_keys(self) -> List[str]: + """Expect input key. + + :meta private: + """ + _output_keys = super().output_keys + if self.return_map_steps: + _output_keys = _output_keys + ["map_steps"] + return _output_keys class Config: """Configuration for this pydantic object.""" @@ -102,7 +115,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel): def combine_docs( self, docs: List[Document], token_max: int = 3000, **kwargs: Any - ) -> str: + ) -> Tuple[str, dict]: """Combine documents in a map reduce manner. Combine by mapping first chain over all documents, then reducing the results. @@ -133,5 +146,9 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel): num_tokens = self.combine_document_chain.prompt_length( result_docs, **kwargs ) - output = self.combine_document_chain.combine_docs(result_docs, **kwargs) - return output + if self.return_map_steps: + extra_return_dict = {"map_steps": results} + else: + extra_return_dict = {} + output, _ = self.combine_document_chain.combine_docs(result_docs, **kwargs) + return output, extra_return_dict diff --git a/langchain/chains/combine_documents/refine.py b/langchain/chains/combine_documents/refine.py index 105df6f4887..c91bf07089f 100644 --- a/langchain/chains/combine_documents/refine.py +++ b/langchain/chains/combine_documents/refine.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple from pydantic import BaseModel, Extra, Field, root_validator @@ -33,6 +33,19 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel): default_factory=_get_default_document_prompt ) """Prompt to use to format each document.""" + return_refine_steps: bool = False + """Return the results of the refine steps in the output.""" + + @property + def output_keys(self) -> List[str]: + """Expect input key. + + :meta private: + """ + _output_keys = super().output_keys + if self.return_refine_steps: + _output_keys = _output_keys + ["refine_steps"] + return _output_keys class Config: """Configuration for this pydantic object.""" @@ -61,7 +74,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel): ) return values - def combine_docs(self, docs: List[Document], **kwargs: Any) -> str: + def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: """Combine by mapping first chain over all, then stuffing into final chain.""" base_info = {"page_content": docs[0].page_content} base_info.update(docs[0].metadata) @@ -71,6 +84,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel): } inputs = {**base_inputs, **kwargs} res = self.initial_llm_chain.predict(**inputs) + refine_steps = [res] for doc in docs[1:]: base_info = {"page_content": doc.page_content} base_info.update(doc.metadata) @@ -85,4 +99,9 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel): } inputs = {**base_inputs, **kwargs} res = self.refine_llm_chain.predict(**inputs) - return res + refine_steps.append(res) + if self.return_refine_steps: + extra_return_dict = {"refine_steps": refine_steps} + else: + extra_return_dict = {} + return res, extra_return_dict diff --git a/langchain/chains/combine_documents/stuff.py b/langchain/chains/combine_documents/stuff.py index 796de39e37f..67bdfa7512b 100644 --- a/langchain/chains/combine_documents/stuff.py +++ b/langchain/chains/combine_documents/stuff.py @@ -1,6 +1,6 @@ """Chain that combines documents by stuffing into context.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from pydantic import BaseModel, Extra, Field, root_validator @@ -78,8 +78,8 @@ class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel): prompt = self.llm_chain.prompt.format(**inputs) return self.llm_chain.llm.get_num_tokens(prompt) - def combine_docs(self, docs: List[Document], **kwargs: Any) -> str: + def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: """Stuff all documents into one prompt and pass to LLM.""" inputs = self._get_inputs(docs, **kwargs) # Call predict on the LLM. - return self.llm_chain.predict(**inputs) + return self.llm_chain.predict(**inputs), {} diff --git a/langchain/chains/mapreduce.py b/langchain/chains/mapreduce.py index ea01ab54283..583e484badd 100644 --- a/langchain/chains/mapreduce.py +++ b/langchain/chains/mapreduce.py @@ -70,5 +70,5 @@ class MapReduceChain(Chain, BaseModel): # Split the larger text into smaller chunks. texts = self.text_splitter.split_text(inputs[self.input_key]) docs = [Document(page_content=text) for text in texts] - outputs = self.combine_documents_chain.combine_docs(docs) + outputs, _ = self.combine_documents_chain.combine_docs(docs) return {self.output_key: outputs} diff --git a/langchain/chains/qa_with_sources/base.py b/langchain/chains/qa_with_sources/base.py index 48a3d017271..9eabb9f56e3 100644 --- a/langchain/chains/qa_with_sources/base.py +++ b/langchain/chains/qa_with_sources/base.py @@ -106,7 +106,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC): def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: docs = self._get_docs(inputs) - answer = self.combine_document_chain.combine_docs(docs, **inputs) + answer, _ = self.combine_document_chain.combine_docs(docs, **inputs) if "\nSOURCES: " in answer: answer, sources = answer.split("\nSOURCES: ") else: diff --git a/langchain/chains/vector_db_qa/base.py b/langchain/chains/vector_db_qa/base.py index 1788d962a72..2067dc34006 100644 --- a/langchain/chains/vector_db_qa/base.py +++ b/langchain/chains/vector_db_qa/base.py @@ -101,5 +101,5 @@ class VectorDBQA(Chain, BaseModel): def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: question = inputs[self.input_key] docs = self.vectorstore.similarity_search(question, k=self.k) - answer = self.combine_documents_chain.combine_docs(docs, question=question) + answer, _ = self.combine_documents_chain.combine_docs(docs, question=question) return {self.output_key: answer}