From 05838ebcce48432c0221d72fce1634578748276c Mon Sep 17 00:00:00 2001
From: Bagatur
Date: Thu, 21 Dec 2023 18:55:52 -0500
Subject: [PATCH] refine

Add a runnable implementation of RefineDocumentsChain and move the
document-formatting runnable used by StuffDocumentsChain into a helper.

- chains/base.py: drop the debug prints in RunnableChain.
- combine_documents/base.py: add the _format_docs_chain helper.
- combine_documents/refine.py: add RefineDocumentsChain.as_runnable.
- combine_documents/stuff.py: build as_runnable on _format_docs_chain.
---
Review notes:
- refine as_runnable formats document number len(intermediate_steps)
  instead of len(intermediate_steps) - 1, stops once every document has
  been processed, and returns the latest response rather than the one
  recorded before it.
- format_inputs no longer mutates the dict it receives.
- The "index" lines still carry the blob ids of the pre-review patch.

 libs/langchain/langchain/chains/base.py       |  3 -
 .../chains/combine_documents/base.py          | 40 +++++++++++++
 .../chains/combine_documents/refine.py        | 56 ++++++++++++++++++-
 .../chains/combine_documents/stuff.py         | 23 +++-----
 4 files changed, 103 insertions(+), 19 deletions(-)

diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py
index 3b8617b252d..42a9f291a46 100644
--- a/libs/langchain/langchain/chains/base.py
+++ b/libs/langchain/langchain/chains/base.py
@@ -21,7 +21,6 @@ from langchain_core.pydantic_v1 import (
     validator,
 )
 from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
-from langchain_core.runnables.base import RunnableLambda
 from langchain_core.runnables.configurable import ConfigurableFieldSpec
 from langchain_core.runnables.passthrough import RunnablePassthrough
 
@@ -681,12 +680,10 @@ class RunnableChain(Chain):
         context = Context.create_scope("runnable-chain")
 
         def prep_outputs(all: Dict[str, Any]) -> Dict[str, Any]:
-            print("before outprep", all)
             return self.prep_outputs(all["inputs"], all["outputs"], return_only_outputs)
 
         return (
             self.prep_inputs
-            | RunnableLambda(lambda i: print("after prep", i) or i.copy())
             | context.setter("inputs")
             | self.as_runnable()
             | {"outputs": RunnablePassthrough(), "inputs": context.getter("inputs")}
diff --git a/libs/langchain/langchain/chains/combine_documents/base.py b/libs/langchain/langchain/chains/combine_documents/base.py
index 26ef0a48928..148bc6bcfbf 100644
--- a/libs/langchain/langchain/chains/combine_documents/base.py
+++ b/libs/langchain/langchain/chains/combine_documents/base.py
@@ -1,10 +1,13 @@
 """Base interface for chains combining documents."""
 
 from abc import ABC, abstractmethod
+from operator import itemgetter
 from typing import Any, Dict, List, Optional, Tuple, Type
 
 from langchain_core.documents import Document
+from langchain_core.prompts import BasePromptTemplate, format_document
 from langchain_core.pydantic_v1 import BaseModel, Field, create_model
+from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
 from langchain_core.runnables.config import RunnableConfig
 
 from langchain.callbacks.manager import (
@@ -15,6 +18,43 @@ from langchain.chains.base import Chain, RunnableChain
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
 
 
+def _format_docs_chain(
+    input_key: str,
+    document_prompt: BasePromptTemplate,
+    document_variable_name: str,
+    document_separator: str,
+) -> Runnable:
+    """Build a runnable that renders the documents in its input into one string.
+
+    Args:
+        input_key: Key of the input dict that holds the documents.
+        document_prompt: Prompt used to format each document.
+        document_variable_name: Key under which the joined text is stored.
+        document_separator: String used to join the formatted documents.
+
+    Returns:
+        A runnable that maps the input dict to a new dict which gains
+        ``document_variable_name`` and no longer contains ``input_key``.
+    """
+
+    def format_document_(doc: Document) -> str:
+        return format_document(doc, document_prompt)
+
+    format_docs = (
+        itemgetter(input_key)
+        | RunnableLambda(format_document_).map()
+        | document_separator.join
+    )
+
+    def pop_raw_docs(input_: dict) -> dict:
+        return {k: v for k, v in input_.items() if k != input_key}
+
+    return (
+        RunnablePassthrough.assign(**{document_variable_name: format_docs})
+        | pop_raw_docs
+    )
+
+
 class BaseCombineDocumentsChain(RunnableChain, ABC):
     """Base interface for chains combining documents.
 
diff --git a/libs/langchain/langchain/chains/combine_documents/refine.py b/libs/langchain/langchain/chains/combine_documents/refine.py
index fb9da0eb6d1..e5c0dc69094 100644
--- a/libs/langchain/langchain/chains/combine_documents/refine.py
+++ b/libs/langchain/langchain/chains/combine_documents/refine.py
@@ -2,12 +2,13 @@
 
 from __future__ import annotations
 
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Union
 
 from langchain_core.documents import Document
 from langchain_core.prompts import BasePromptTemplate, format_document
 from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.pydantic_v1 import Extra, Field, root_validator
+from langchain_core.runnables import Runnable, RunnablePassthrough
 
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import (
@@ -133,6 +134,59 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
             )
         return values
 
+    def as_runnable(self) -> Runnable:
+        """Return a runnable that refines a response over the documents.
+
+        The first document is sent through ``initial_llm_chain``; every
+        following document is sent through ``refine_llm_chain`` along with
+        the previous response, stored under ``initial_response_name``.
+        """
+
+        def format_inputs(inputs: dict) -> dict:
+            # One response is recorded per processed document, so the number
+            # of recorded steps is the index of the next document to format.
+            doc = inputs[self.input_key][len(inputs.get("intermediate_steps", []))]
+            # Build a new dict instead of mutating the caller's ``inputs``.
+            formatted = {
+                k: v
+                for k, v in inputs.items()
+                if k not in ("intermediate_steps", self.input_key)
+            }
+            formatted[self.document_variable_name] = format_document(
+                doc, self.document_prompt
+            )
+            return formatted
+
+        first_chain = format_inputs | self.initial_llm_chain.as_runnable()
+        refine_chain = format_inputs | self.refine_llm_chain.as_runnable()
+
+        def loop(inputs: dict) -> Union[Runnable, dict]:
+            steps = inputs.get("intermediate_steps", [])
+            # ``steps`` does not yet include the current response, so
+            # ``len(steps) + 1`` documents have been processed at this point.
+            if len(steps) + 1 < len(inputs[self.input_key]):
+                return (
+                    RunnablePassthrough.assign(
+                        intermediate_steps=lambda x: x.get("intermediate_steps", [])
+                        + [x[self.initial_response_name]]
+                    )
+                    | RunnablePassthrough.assign(
+                        **{self.initial_response_name: refine_chain}
+                    )
+                    | loop
+                )
+            # All documents are processed: the current response is the result.
+            steps = steps + [inputs[self.initial_response_name]]
+            res = {self.output_key: steps[-1]}
+            if self.return_intermediate_steps:
+                res["intermediate_steps"] = steps
+            return res
+
+        return (
+            RunnablePassthrough.assign(**{self.initial_response_name: first_chain})
+            | loop
+        )
+
     def combine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py
index 865a049f16b..937787484ce 100644
--- a/libs/langchain/langchain/chains/combine_documents/stuff.py
+++ b/libs/langchain/langchain/chains/combine_documents/stuff.py
@@ -1,5 +1,4 @@
 """Chain that combines documents by stuffing into context."""
-from operator import itemgetter
 from typing import Any, Dict, List, Optional, Tuple
 
 from langchain_core.documents import Document
@@ -8,14 +7,13 @@ from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.pydantic_v1 import Extra, Field, root_validator
 from langchain_core.runnables import (
     Runnable,
-    RunnableLambda,
     RunnableParallel,
-    RunnablePassthrough,
 )
 
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import (
     BaseCombineDocumentsChain,
+    _format_docs_chain,
 )
 from langchain.chains.llm import LLMChain
 
@@ -107,19 +105,14 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
         return values
 
     def as_runnable(self) -> Runnable:
-        def _format_document(doc: Document):
-            return format_document(doc, self.document_prompt)
-
-        format_docs = (
-            itemgetter(self.input_key) | RunnableLambda(_format_document).map()
-        )
-
-        def pop_raw_docs(input_: dict) -> dict:
-            return {k: v for k, v in input_.items() if k != self.input_key}
-
+        """Return a runnable that stuffs the formatted documents into ``llm_chain``."""
         chain = (
-            RunnablePassthrough.assign(**{self.document_variable_name: format_docs})
-            | pop_raw_docs
+            _format_docs_chain(
+                self.input_key,
+                self.document_prompt,
+                self.document_variable_name,
+                self.document_separator,
+            )
             | self.llm_chain.as_runnable()
         )
         return RunnableParallel({self.output_key: chain})