This commit is contained in:
Bagatur
2023-12-21 18:55:52 -05:00
parent 243a74735b
commit 05838ebcce
4 changed files with 76 additions and 19 deletions

View File

@@ -21,7 +21,6 @@ from langchain_core.pydantic_v1 import (
validator,
)
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
from langchain_core.runnables.base import RunnableLambda
from langchain_core.runnables.configurable import ConfigurableFieldSpec
from langchain_core.runnables.passthrough import RunnablePassthrough
@@ -681,12 +680,12 @@ class RunnableChain(Chain):
context = Context.create_scope("runnable-chain")
def prep_outputs(all: Dict[str, Any]) -> Dict[str, Any]:
print("before outprep", all)
# print("before outprep", all)
return self.prep_outputs(all["inputs"], all["outputs"], return_only_outputs)
return (
self.prep_inputs
| RunnableLambda(lambda i: print("after prep", i) or i.copy())
# | RunnableLambda(lambda i: print("after prep", i) or i.copy())
| context.setter("inputs")
| self.as_runnable()
| {"outputs": RunnablePassthrough(), "inputs": context.getter("inputs")}

View File

@@ -1,10 +1,13 @@
"""Base interface for chains combining documents."""
from abc import ABC, abstractmethod
from operator import itemgetter
from typing import Any, Dict, List, Optional, Tuple, Type
from langchain_core.documents import Document
from langchain_core.prompts import BasePromptTemplate, format_document
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from langchain_core.runnables.config import RunnableConfig
from langchain.callbacks.manager import (
@@ -15,6 +18,30 @@ from langchain.chains.base import Chain, RunnableChain
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
def _format_docs_chain(
input_key: str,
document_prompt: BasePromptTemplate,
document_variable_name: str,
document_separator: str,
) -> Runnable:
def format_document_(doc: Document) -> str:
return format_document(doc, document_prompt)
format_docs = (
itemgetter(input_key)
| RunnableLambda(format_document_).map()
| document_separator.join
)
def pop_raw_docs(input_: dict) -> dict:
return {k: v for k, v in input_.items() if k != input_key}
return (
RunnablePassthrough.assign(**{document_variable_name: format_docs})
| pop_raw_docs
)
class BaseCombineDocumentsChain(RunnableChain, ABC):
"""Base interface for chains combining documents.

View File

@@ -2,12 +2,13 @@
from __future__ import annotations
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Union
from langchain_core.documents import Document
from langchain_core.prompts import BasePromptTemplate, format_document
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import (
@@ -133,6 +134,44 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
)
return values
def as_runnable(self) -> Runnable:
def format_inputs(inputs: dict) -> dict:
doc = inputs[self.input_key][len(inputs.get("intermediate_steps", [])) - 1]
inputs[self.document_variable_name] = format_document(
doc, self.document_prompt
)
return {
k: v
for k, v in inputs.items()
if k not in ("intermediate_steps", self.input_key)
}
first_chain = format_inputs | self.initial_llm_chain.as_runnable()
refine_chain = format_inputs | self.refine_llm_chain.as_runnable()
def loop(inputs: dict) -> Union[Runnable, dict]:
if len(inputs.get("intermediate_steps", [])) < len(inputs[self.input_key]):
return (
RunnablePassthrough.assign(
intermediate_steps=lambda x: x.get("intermediate_steps", [])
+ [x[self.initial_response_name]]
)
| RunnablePassthrough.assign(
**{self.initial_response_name: refine_chain}
)
| loop
)
else:
res = {self.output_key: inputs["intermediate_steps"][-1]}
if self.return_intermediate_steps:
res["intermediate_steps"] = inputs["intermediate_steps"]
return res
return (
RunnablePassthrough.assign(**{self.initial_response_name: first_chain})
| loop
)
def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:

View File

@@ -1,5 +1,4 @@
"""Chain that combines documents by stuffing into context."""
from operator import itemgetter
from typing import Any, Dict, List, Optional, Tuple
from langchain_core.documents import Document
@@ -8,14 +7,13 @@ from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.runnables import (
Runnable,
RunnableLambda,
RunnableParallel,
RunnablePassthrough,
)
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import (
BaseCombineDocumentsChain,
_format_docs_chain,
)
from langchain.chains.llm import LLMChain
@@ -107,19 +105,13 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
return values
def as_runnable(self) -> Runnable:
def _format_document(doc: Document):
return format_document(doc, self.document_prompt)
format_docs = (
itemgetter(self.input_key) | RunnableLambda(_format_document).map()
)
def pop_raw_docs(input_: dict) -> dict:
return {k: v for k, v in input_.items() if k != self.input_key}
chain = (
RunnablePassthrough.assign(**{self.document_variable_name: format_docs})
| pop_raw_docs
_format_docs_chain(
self.input_key,
self.document_prompt,
self.document_variable_name,
self.document_separator,
)
| self.llm_chain.as_runnable()
)
return RunnableParallel({self.output_key: chain})