From 243a74735bcb9b4c164d23a9d6e9893cd54aa2cb Mon Sep 17 00:00:00 2001 From: Bagatur Date: Thu, 21 Dec 2023 15:55:21 -0500 Subject: [PATCH] wip --- .../chains/combine_documents/base.py | 4 +-- .../chains/combine_documents/stuff.py | 26 ++++++++++++++++++- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/chains/combine_documents/base.py b/libs/langchain/langchain/chains/combine_documents/base.py index 9dd964db71a..26ef0a48928 100644 --- a/libs/langchain/langchain/chains/combine_documents/base.py +++ b/libs/langchain/langchain/chains/combine_documents/base.py @@ -11,11 +11,11 @@ from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, ) -from langchain.chains.base import Chain +from langchain.chains.base import Chain, RunnableChain from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter -class BaseCombineDocumentsChain(Chain, ABC): +class BaseCombineDocumentsChain(RunnableChain, ABC): """Base interface for chains combining documents. Subclasses of this chain deal with combining documents in a variety of diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py index a30d4a0e90b..865a049f16b 100644 --- a/libs/langchain/langchain/chains/combine_documents/stuff.py +++ b/libs/langchain/langchain/chains/combine_documents/stuff.py @@ -1,11 +1,17 @@ """Chain that combines documents by stuffing into context.""" - +from operator import itemgetter from typing import Any, Dict, List, Optional, Tuple from langchain_core.documents import Document from langchain_core.prompts import BasePromptTemplate, format_document from langchain_core.prompts.prompt import PromptTemplate from langchain_core.pydantic_v1 import Extra, Field, root_validator +from langchain_core.runnables import ( + Runnable, + RunnableLambda, + RunnableParallel, + RunnablePassthrough, +) from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import ( @@ -100,6 +106,24 @@ class StuffDocumentsChain(BaseCombineDocumentsChain): ) return values + def as_runnable(self) -> Runnable: + def _format_document(doc: Document): + return format_document(doc, self.document_prompt) + + format_docs = ( + itemgetter(self.input_key) | RunnableLambda(_format_document).map() + ) + + def pop_raw_docs(input_: dict) -> dict: + return {k: v for k, v in input_.items() if k != self.input_key} + + chain = ( + RunnablePassthrough.assign(**{self.document_variable_name: format_docs}) + | pop_raw_docs + | self.llm_chain.as_runnable() + ) + return RunnableParallel({self.output_key: chain}) + @property def input_keys(self) -> List[str]: extra_keys = [