mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
refine
This commit is contained in:
@@ -21,7 +21,6 @@ from langchain_core.pydantic_v1 import (
|
||||
validator,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
|
||||
from langchain_core.runnables.base import RunnableLambda
|
||||
from langchain_core.runnables.configurable import ConfigurableFieldSpec
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
|
||||
@@ -681,12 +680,12 @@ class RunnableChain(Chain):
|
||||
context = Context.create_scope("runnable-chain")
|
||||
|
||||
def prep_outputs(all: Dict[str, Any]) -> Dict[str, Any]:
|
||||
print("before outprep", all)
|
||||
# print("before outprep", all)
|
||||
return self.prep_outputs(all["inputs"], all["outputs"], return_only_outputs)
|
||||
|
||||
return (
|
||||
self.prep_inputs
|
||||
| RunnableLambda(lambda i: print("after prep", i) or i.copy())
|
||||
# | RunnableLambda(lambda i: print("after prep", i) or i.copy())
|
||||
| context.setter("inputs")
|
||||
| self.as_runnable()
|
||||
| {"outputs": RunnablePassthrough(), "inputs": context.getter("inputs")}
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
"""Base interface for chains combining documents."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from operator import itemgetter
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.prompts import BasePromptTemplate, format_document
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
||||
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
@@ -15,6 +18,30 @@ from langchain.chains.base import Chain, RunnableChain
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
|
||||
def _format_docs_chain(
|
||||
input_key: str,
|
||||
document_prompt: BasePromptTemplate,
|
||||
document_variable_name: str,
|
||||
document_separator: str,
|
||||
) -> Runnable:
|
||||
def format_document_(doc: Document) -> str:
|
||||
return format_document(doc, document_prompt)
|
||||
|
||||
format_docs = (
|
||||
itemgetter(input_key)
|
||||
| RunnableLambda(format_document_).map()
|
||||
| document_separator.join
|
||||
)
|
||||
|
||||
def pop_raw_docs(input_: dict) -> dict:
|
||||
return {k: v for k, v in input_.items() if k != input_key}
|
||||
|
||||
return (
|
||||
RunnablePassthrough.assign(**{document_variable_name: format_docs})
|
||||
| pop_raw_docs
|
||||
)
|
||||
|
||||
|
||||
class BaseCombineDocumentsChain(RunnableChain, ABC):
|
||||
"""Base interface for chains combining documents.
|
||||
|
||||
|
||||
@@ -2,12 +2,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.prompts import BasePromptTemplate, format_document
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import Extra, Field, root_validator
|
||||
from langchain_core.runnables import Runnable, RunnablePassthrough
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import (
|
||||
@@ -133,6 +134,44 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
|
||||
)
|
||||
return values
|
||||
|
||||
def as_runnable(self) -> Runnable:
|
||||
def format_inputs(inputs: dict) -> dict:
|
||||
doc = inputs[self.input_key][len(inputs.get("intermediate_steps", [])) - 1]
|
||||
inputs[self.document_variable_name] = format_document(
|
||||
doc, self.document_prompt
|
||||
)
|
||||
return {
|
||||
k: v
|
||||
for k, v in inputs.items()
|
||||
if k not in ("intermediate_steps", self.input_key)
|
||||
}
|
||||
|
||||
first_chain = format_inputs | self.initial_llm_chain.as_runnable()
|
||||
refine_chain = format_inputs | self.refine_llm_chain.as_runnable()
|
||||
|
||||
def loop(inputs: dict) -> Union[Runnable, dict]:
|
||||
if len(inputs.get("intermediate_steps", [])) < len(inputs[self.input_key]):
|
||||
return (
|
||||
RunnablePassthrough.assign(
|
||||
intermediate_steps=lambda x: x.get("intermediate_steps", [])
|
||||
+ [x[self.initial_response_name]]
|
||||
)
|
||||
| RunnablePassthrough.assign(
|
||||
**{self.initial_response_name: refine_chain}
|
||||
)
|
||||
| loop
|
||||
)
|
||||
else:
|
||||
res = {self.output_key: inputs["intermediate_steps"][-1]}
|
||||
if self.return_intermediate_steps:
|
||||
res["intermediate_steps"] = inputs["intermediate_steps"]
|
||||
return res
|
||||
|
||||
return (
|
||||
RunnablePassthrough.assign(**{self.initial_response_name: first_chain})
|
||||
| loop
|
||||
)
|
||||
|
||||
def combine_docs(
|
||||
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Chain that combines documents by stuffing into context."""
|
||||
from operator import itemgetter
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from langchain_core.documents import Document
|
||||
@@ -8,14 +7,13 @@ from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import Extra, Field, root_validator
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableLambda,
|
||||
RunnableParallel,
|
||||
RunnablePassthrough,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import (
|
||||
BaseCombineDocumentsChain,
|
||||
_format_docs_chain,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
|
||||
@@ -107,19 +105,13 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||
return values
|
||||
|
||||
def as_runnable(self) -> Runnable:
|
||||
def _format_document(doc: Document):
|
||||
return format_document(doc, self.document_prompt)
|
||||
|
||||
format_docs = (
|
||||
itemgetter(self.input_key) | RunnableLambda(_format_document).map()
|
||||
)
|
||||
|
||||
def pop_raw_docs(input_: dict) -> dict:
|
||||
return {k: v for k, v in input_.items() if k != self.input_key}
|
||||
|
||||
chain = (
|
||||
RunnablePassthrough.assign(**{self.document_variable_name: format_docs})
|
||||
| pop_raw_docs
|
||||
_format_docs_chain(
|
||||
self.input_key,
|
||||
self.document_prompt,
|
||||
self.document_variable_name,
|
||||
self.document_separator,
|
||||
)
|
||||
| self.llm_chain.as_runnable()
|
||||
)
|
||||
return RunnableParallel({self.output_key: chain})
|
||||
|
||||
Reference in New Issue
Block a user