This commit is contained in:
Bagatur
2023-12-21 15:55:21 -05:00
parent 8096fa2027
commit 243a74735b
2 changed files with 27 additions and 3 deletions

View File

@@ -11,11 +11,11 @@ from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.base import Chain, RunnableChain
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
class BaseCombineDocumentsChain(Chain, ABC):
class BaseCombineDocumentsChain(RunnableChain, ABC):
"""Base interface for chains combining documents.
Subclasses of this chain deal with combining documents in a variety of

View File

@@ -1,11 +1,17 @@
"""Chain that combines documents by stuffing into context."""
from operator import itemgetter
from typing import Any, Dict, List, Optional, Tuple
from langchain_core.documents import Document
from langchain_core.prompts import BasePromptTemplate, format_document
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.runnables import (
Runnable,
RunnableLambda,
RunnableParallel,
RunnablePassthrough,
)
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import (
@@ -100,6 +106,24 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
)
return values
def as_runnable(self) -> Runnable:
def _format_document(doc: Document):
return format_document(doc, self.document_prompt)
format_docs = (
itemgetter(self.input_key) | RunnableLambda(_format_document).map()
)
def pop_raw_docs(input_: dict) -> dict:
return {k: v for k, v in input_.items() if k != self.input_key}
chain = (
RunnablePassthrough.assign(**{self.document_variable_name: format_docs})
| pop_raw_docs
| self.llm_chain.as_runnable()
)
return RunnableParallel({self.output_key: chain})
@property
def input_keys(self) -> List[str]:
extra_keys = [