mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-21 14:18:52 +00:00)
return intermediate steps in combine document chains
parent 9ae1d75318
commit f97db8cc7b
@@ -60,7 +60,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 5,
    "id": "8dff4f43",
    "metadata": {},
    "outputs": [],
@@ -70,7 +70,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 6,
    "id": "27989fc4",
    "metadata": {},
    "outputs": [],
@@ -131,33 +131,57 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 7,
    "id": "ef28e1d4",
    "metadata": {},
    "outputs": [],
    "source": [
-    "chain = load_summarize_chain(llm, chain_type=\"map_reduce\")"
+    "chain = load_summarize_chain(llm, chain_type=\"map_reduce\", verbose=True, return_map_steps=True)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 8,
    "id": "f82c5f9f",
    "metadata": {},
    "outputs": [
     {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new MapReduceDocumentsChain chain...\u001b[0m\n",
+      "\n",
+      "\u001b[1m> Finished MapReduceDocumentsChain chain.\u001b[0m\n"
+     ]
+    }
+   ],
+   "source": [
+    "res = chain(docs)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "f5a2b653",
+   "metadata": {},
+   "outputs": [
+    {
     "data": {
      "text/plain": [
-      "\" In response to Vladimir Putin's aggression in Ukraine, the US and its allies have taken action to hold him accountable, including economic sanctions, cutting off access to technology, and seizing the assets of Russian oligarchs. They are also providing military, economic, and humanitarian assistance to the Ukrainians, and releasing 60 million barrels of oil from reserves around the world. President Biden has passed several laws to provide economic relief to Americans and create jobs, and is making sure taxpayer dollars support American jobs and businesses.\""
+      "[{'text': \" In response to Russia's aggression in Ukraine, the United States has united with other freedom-loving nations to impose economic sanctions and hold Putin accountable. The U.S. Department of Justice is also assembling a task force to go after the crimes of Russian oligarchs and seize their ill-gotten gains.\"},\n",
+      " {'text': ' The United States and its European allies are taking action to punish Russia for its invasion of Ukraine, including seizing assets, closing off airspace, and providing economic and military assistance to Ukraine. The US is also mobilizing forces to protect NATO countries and has released 30 million barrels of oil from its Strategic Petroleum Reserve to help blunt gas prices. The world is uniting in support of Ukraine and democracy, and the US stands with its Ukrainian American citizens.'},\n",
+      " {'text': ' President Biden and Vice President Harris ran for office with a new economic vision for America, and have since passed the American Rescue Plan and the Bipartisan Infrastructure Law to help working people and rebuild America. These plans will create jobs, modernize roads, airports, ports, and waterways, and provide clean water and high-speed internet for all Americans. The government will also be investing in American products to support American jobs.'}]"
      ]
     },
-    "execution_count": 9,
+    "execution_count": 11,
     "metadata": {},
    "output_type": "execute_result"
    }
   ],
   "source": [
-    "chain.run(docs)"
+    "res['map_steps']"
   ]
  },
  {
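The notebook hunks above amount to a new usage pattern: load the summarize chain with return_map_steps=True, call the chain itself (rather than chain.run) so that every output key comes back, and read the per-document map results from res['map_steps']. Below is a minimal sketch of that flow, not part of the commit: the OpenAI model, the text splitter settings, the state_of_the_union.txt path, and the import paths (as of this era of the library) are illustrative assumptions.

# Sketch of the notebook flow shown in the diff above; model, splitter, and
# input file are assumptions made for illustration, not part of this commit.
from langchain import OpenAI
from langchain.chains.summarize import load_summarize_chain
from langchain.docstore.document import Document
from langchain.text_splitter import CharacterTextSplitter

llm = OpenAI(temperature=0)

with open("state_of_the_union.txt") as f:
    texts = CharacterTextSplitter().split_text(f.read())
docs = [Document(page_content=t) for t in texts[:3]]

# return_map_steps is the flag introduced by this commit.
chain = load_summarize_chain(llm, chain_type="map_reduce", return_map_steps=True)

# Calling the chain returns all output keys, not just the final text.
res = chain(docs)
print(res["map_steps"])    # one {'text': ...} dict per input document
print(res["output_text"])  # final summary (default output key of the combine chains)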
@@ -1,7 +1,7 @@
 """Base interface for chains combining documents."""
 
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from pydantic import BaseModel
 
@@ -39,12 +39,13 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC):
         return None
 
     @abstractmethod
-    def combine_docs(self, docs: List[Document], **kwargs: Any) -> str:
+    def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
         """Combine documents into a single string."""
 
     def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
         docs = inputs[self.input_key]
         # Other keys are assumed to be needed for LLM prediction
         other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
-        output = self.combine_docs(docs, **other_keys)
-        return {self.output_key: output}
+        output, extra_return_dict = self.combine_docs(docs, **other_keys)
+        extra_return_dict[self.output_key] = output
+        return extra_return_dict
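The interface change in BaseCombineDocumentsChain is the heart of the commit: combine_docs now returns a (text, extra_dict) pair, and _call folds the final text into whatever extra keys the subclass chose to return. A self-contained toy sketch of that contract follows; the names are illustrative only, not the actual langchain classes.

# Toy sketch of the new combine_docs contract; illustrative names only.
from typing import Any, Dict, List, Tuple


class ToyCombineChain:
    input_key = "input_documents"
    output_key = "output_text"

    def combine_docs(self, docs: List[str], **kwargs: Any) -> Tuple[str, dict]:
        # Each subclass decides what, if anything, goes in the extra dict.
        mapped = [d.upper() for d in docs]
        return " ".join(mapped), {"map_steps": mapped}

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        docs = inputs[self.input_key]
        other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
        # Mirrors the hunk above: unpack the tuple, then fold the final
        # text into the extra dict so every key reaches the caller.
        output, extra_return_dict = self.combine_docs(docs, **other_keys)
        extra_return_dict[self.output_key] = output
        return extra_return_dict


print(ToyCombineChain()._call({"input_documents": ["a", "b"]}))
# {'map_steps': ['A', 'B'], 'output_text': 'A B'}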
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 from pydantic import BaseModel, Extra, root_validator
 
@@ -65,6 +65,19 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
     document_variable_name: str
     """The variable name in the llm_chain to put the documents in.
     If only one variable in the llm_chain, this need not be provided."""
+    return_map_steps: bool = False
+    """Return the results of the map steps in the output."""
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Expect input key.
+
+        :meta private:
+        """
+        _output_keys = super().output_keys
+        if self.return_map_steps:
+            _output_keys = _output_keys + ["map_steps"]
+        return _output_keys
 
     class Config:
         """Configuration for this pydantic object."""
@@ -102,7 +115,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
 
     def combine_docs(
         self, docs: List[Document], token_max: int = 3000, **kwargs: Any
-    ) -> str:
+    ) -> Tuple[str, dict]:
         """Combine documents in a map reduce manner.
 
         Combine by mapping first chain over all documents, then reducing the results.
@@ -133,5 +146,9 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
             num_tokens = self.combine_document_chain.prompt_length(
                 result_docs, **kwargs
             )
-        output = self.combine_document_chain.combine_docs(result_docs, **kwargs)
-        return output
+        if self.return_map_steps:
+            extra_return_dict = {"map_steps": results}
+        else:
+            extra_return_dict = {}
+        output, _ = self.combine_document_chain.combine_docs(result_docs, **kwargs)
+        return output, extra_return_dict
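In MapReduceDocumentsChain the change has two halves: output_keys advertises a map_steps key only when return_map_steps is set, and combine_docs keeps the raw map results around before the reduce step. A stripped-down sketch of that pattern, with toy stand-ins for the LLM calls (not the library class itself):

# Illustrative sketch of the two MapReduceDocumentsChain additions.
from typing import Any, List, Tuple


class ToyMapReduce:
    def __init__(self, return_map_steps: bool = False):
        self.return_map_steps = return_map_steps

    @property
    def output_keys(self) -> List[str]:
        _output_keys = ["output_text"]          # what the base class exposes
        if self.return_map_steps:
            _output_keys = _output_keys + ["map_steps"]
        return _output_keys

    def combine_docs(self, docs: List[str], **kwargs: Any) -> Tuple[str, dict]:
        # "Map" step: handle each document independently (here, just truncate).
        results = [{"text": d[:20]} for d in docs]
        # Keep the per-document results only when asked to.
        extra_return_dict = {"map_steps": results} if self.return_map_steps else {}
        # "Reduce" step: combine the mapped texts into one output.
        output = " / ".join(r["text"] for r in results)
        return output, extra_return_dict


chain = ToyMapReduce(return_map_steps=True)
print(chain.output_keys)   # ['output_text', 'map_steps']
print(chain.combine_docs(["first document", "second document"]))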
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Tuple
 
 from pydantic import BaseModel, Extra, Field, root_validator
 
@@ -33,6 +33,19 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):
         default_factory=_get_default_document_prompt
     )
     """Prompt to use to format each document."""
+    return_refine_steps: bool = False
+    """Return the results of the refine steps in the output."""
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Expect input key.
+
+        :meta private:
+        """
+        _output_keys = super().output_keys
+        if self.return_refine_steps:
+            _output_keys = _output_keys + ["refine_steps"]
+        return _output_keys
 
     class Config:
         """Configuration for this pydantic object."""
@@ -61,7 +74,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):
         )
         return values
 
-    def combine_docs(self, docs: List[Document], **kwargs: Any) -> str:
+    def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
         """Combine by mapping first chain over all, then stuffing into final chain."""
         base_info = {"page_content": docs[0].page_content}
         base_info.update(docs[0].metadata)
@@ -71,6 +84,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):
         }
         inputs = {**base_inputs, **kwargs}
         res = self.initial_llm_chain.predict(**inputs)
+        refine_steps = [res]
         for doc in docs[1:]:
             base_info = {"page_content": doc.page_content}
             base_info.update(doc.metadata)
@@ -85,4 +99,9 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):
             }
             inputs = {**base_inputs, **kwargs}
             res = self.refine_llm_chain.predict(**inputs)
-        return res
+            refine_steps.append(res)
+        if self.return_refine_steps:
+            extra_return_dict = {"refine_steps": refine_steps}
+        else:
+            extra_return_dict = {}
+        return res, extra_return_dict
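RefineDocumentsChain gets the same option under the name return_refine_steps; here the intermediate values are the successive answers produced as each additional document refines the previous one. A toy sketch of the accumulation logic, with f-strings standing in for the initial and refine LLM chains (not the library code):

# Illustrative sketch of how the refine steps are recorded.
from typing import Tuple


def refine_combine(docs, return_refine_steps: bool = False) -> Tuple[str, dict]:
    # Initial answer from the first document (stand-in for initial_llm_chain).
    res = f"summary of {docs[0]}"
    refine_steps = [res]
    # Each later document refines the running answer (stand-in for refine_llm_chain).
    for doc in docs[1:]:
        res = f"{res}, refined with {doc}"
        refine_steps.append(res)
    extra_return_dict = {"refine_steps": refine_steps} if return_refine_steps else {}
    return res, extra_return_dict


answer, extra = refine_combine(["doc1", "doc2", "doc3"], return_refine_steps=True)
print(answer)
print(extra["refine_steps"])  # one entry per step, including the initial answer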
@@ -1,6 +1,6 @@
 """Chain that combines documents by stuffing into context."""
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from pydantic import BaseModel, Extra, Field, root_validator
 
@@ -78,8 +78,8 @@ class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel):
         prompt = self.llm_chain.prompt.format(**inputs)
         return self.llm_chain.llm.get_num_tokens(prompt)
 
-    def combine_docs(self, docs: List[Document], **kwargs: Any) -> str:
+    def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
         """Stuff all documents into one prompt and pass to LLM."""
         inputs = self._get_inputs(docs, **kwargs)
         # Call predict on the LLM.
-        return self.llm_chain.predict(**inputs)
+        return self.llm_chain.predict(**inputs), {}
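StuffDocumentsChain has no intermediate steps to report, so it satisfies the new Tuple[str, dict] signature by returning an empty dict alongside the predicted text. A one-function sketch of that shape (toy stand-in, not the library code):

from typing import Any, List, Tuple


def stuff_combine(docs: List[str], **kwargs: Any) -> Tuple[str, dict]:
    # Stand-in for llm_chain.predict over one stuffed prompt; the extra dict
    # stays empty because there is nothing intermediate to expose.
    return " ".join(docs), {}


output, extra = stuff_combine(["first doc", "second doc"])
assert extra == {}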
@@ -70,5 +70,5 @@ class MapReduceChain(Chain, BaseModel):
         # Split the larger text into smaller chunks.
         texts = self.text_splitter.split_text(inputs[self.input_key])
         docs = [Document(page_content=text) for text in texts]
-        outputs = self.combine_documents_chain.combine_docs(docs)
+        outputs, _ = self.combine_documents_chain.combine_docs(docs)
         return {self.output_key: outputs}
@@ -106,7 +106,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
 
     def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
         docs = self._get_docs(inputs)
-        answer = self.combine_document_chain.combine_docs(docs, **inputs)
+        answer, _ = self.combine_document_chain.combine_docs(docs, **inputs)
         if "\nSOURCES: " in answer:
             answer, sources = answer.split("\nSOURCES: ")
         else:
@@ -101,5 +101,5 @@ class VectorDBQA(Chain, BaseModel):
     def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         question = inputs[self.input_key]
         docs = self.vectorstore.similarity_search(question, k=self.k)
-        answer = self.combine_documents_chain.combine_docs(docs, question=question)
+        answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)
         return {self.output_key: answer}
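The last three hunks, in MapReduceChain, BaseQAWithSourcesChain, and VectorDBQA, are the caller side of the same change: code that invokes combine_docs but does not care about intermediate steps simply unpacks the pair and discards the dict. A trivial self-contained sketch of that call-site pattern:

from typing import Any, List, Tuple


def combine_docs(docs: List[str], **kwargs: Any) -> Tuple[str, dict]:
    # Toy stand-in for any combine-documents chain's combine_docs.
    return " ".join(docs), {"map_steps": docs}


# Callers that only want the final text drop the extras, exactly as the
# `answer, _ = ...` lines above now do.
answer, _ = combine_docs(["doc one", "doc two"], question="What happened?")
print(answer)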