mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
wip
This commit is contained in:
@@ -11,11 +11,11 @@ from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.base import Chain, RunnableChain
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
|
||||
class BaseCombineDocumentsChain(Chain, ABC):
|
||||
class BaseCombineDocumentsChain(RunnableChain, ABC):
|
||||
"""Base interface for chains combining documents.
|
||||
|
||||
Subclasses of this chain deal with combining documents in a variety of
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
"""Chain that combines documents by stuffing into context."""
|
||||
|
||||
from operator import itemgetter
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.prompts import BasePromptTemplate, format_document
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import Extra, Field, root_validator
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableLambda,
|
||||
RunnableParallel,
|
||||
RunnablePassthrough,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import (
|
||||
@@ -100,6 +106,24 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||
)
|
||||
return values
|
||||
|
||||
def as_runnable(self) -> Runnable:
|
||||
def _format_document(doc: Document):
|
||||
return format_document(doc, self.document_prompt)
|
||||
|
||||
format_docs = (
|
||||
itemgetter(self.input_key) | RunnableLambda(_format_document).map()
|
||||
)
|
||||
|
||||
def pop_raw_docs(input_: dict) -> dict:
|
||||
return {k: v for k, v in input_.items() if k != self.input_key}
|
||||
|
||||
chain = (
|
||||
RunnablePassthrough.assign(**{self.document_variable_name: format_docs})
|
||||
| pop_raw_docs
|
||||
| self.llm_chain.as_runnable()
|
||||
)
|
||||
return RunnableParallel({self.output_key: chain})
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
extra_keys = [
|
||||
|
||||
Reference in New Issue
Block a user