Compare commits

...

9 Commits

Author       SHA1        Message                                             Date
Bagatur      05838ebcce  refine                                              2023-12-21 18:55:52 -05:00
Bagatur      243a74735b  wip                                                 2023-12-21 15:55:21 -05:00
Bagatur      8096fa2027  Merge branch 'master' into nc/20dec/runnable-chain  2023-12-21 14:36:34 -05:00
Bagatur      7f5faf1076  fmt                                                 2023-12-21 10:19:32 -05:00
Bagatur      062ad47c65  Merge branch 'master' into nc/20dec/runnable-chain  2023-12-21 10:14:02 -05:00
Nuno Campos  e343aec97b  Try removing ABC from child class                   2023-12-20 15:21:05 -08:00
Nuno Campos  775bd6fb90  Add tests                                           2023-12-20 15:17:00 -08:00
Nuno Campos  582461945f  Add conv retrieval                                  2023-12-20 13:37:35 -08:00
Nuno Campos  ad1ab2b566  WIP Runnable Chain                                  2023-12-20 12:55:54 -08:00
10 changed files with 336 additions and 20 deletions

View File

@@ -108,8 +108,6 @@ def _config_with_context(
                 raise ValueError(
                     f"Deadlock detected between context keys {key} and {dep}"
                 )
-        if len(getters) < 1:
-            raise ValueError(f"Expected at least one getter for context key {key}")
         if len(setters) != 1:
             raise ValueError(f"Expected exactly one setter for context key {key}")
         setter_idx = setters[0][1]
@@ -118,7 +116,8 @@ def _config_with_context(
                 f"Context setter for key {key} must be defined after all getters."
             )
-        context_funcs[getters[0][0].id] = partial(getter, events[key], values)
+        if getters:
+            context_funcs[getters[0][0].id] = partial(getter, events[key], values)
         context_funcs[setters[0][0].id] = partial(setter, events[key], values)
 
     return patch_config(config, configurable=context_funcs)
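Net effect of this hunk: a context key may now have a setter with no corresponding getters. A minimal sketch of what that permits, using the beta Context API from langchain_core (behavior inferred from the diff, not verified on this branch):

from langchain_core.beta.runnables.context import Context
from langchain_core.runnables import RunnableLambda

# A setter with zero getters previously raised
# "Expected at least one getter for context key ..."; now the set value is
# simply stored and never read back.
chain = Context.setter("inputs") | RunnableLambda(lambda x: x.upper())
# chain.invoke("hello") -> "HELLO"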

View File

@@ -6,9 +6,10 @@ import logging
 import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Type, Union
 
 import yaml
+from langchain_core.beta.runnables.context import Context
 from langchain_core.load.dump import dumpd
 from langchain_core.memory import BaseMemory
 from langchain_core.outputs import RunInfo
@@ -19,7 +20,9 @@ from langchain_core.pydantic_v1 import (
     root_validator,
     validator,
 )
-from langchain_core.runnables import RunnableConfig, RunnableSerializable
+from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
+from langchain_core.runnables.configurable import ConfigurableFieldSpec
+from langchain_core.runnables.passthrough import RunnablePassthrough
 
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import (
@@ -666,3 +669,141 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
     ) -> List[Dict[str, str]]:
         """Call the chain on all inputs in the list."""
         return [self(inputs, callbacks=callbacks) for inputs in input_list]
+
+
+class RunnableChain(Chain):
+    @abstractmethod
+    def as_runnable(self) -> Runnable:
+        ...
+
+    def as_runnable_wrapped(self, return_only_outputs: bool = False) -> Runnable:
+        context = Context.create_scope("runnable-chain")
+
+        def prep_outputs(all: Dict[str, Any]) -> Dict[str, Any]:
+            # print("before outprep", all)
+            return self.prep_outputs(all["inputs"], all["outputs"], return_only_outputs)
+
+        return (
+            self.prep_inputs
+            # | RunnableLambda(lambda i: print("after prep", i) or i.copy())
+            | context.setter("inputs")
+            | self.as_runnable()
+            | {"outputs": RunnablePassthrough(), "inputs": context.getter("inputs")}
+            | prep_outputs
+        )
+
+    @property
+    def InputType(self) -> Type[Dict[str, Any]]:
+        return self.as_runnable().InputType
+
+    @property
+    def OutputType(self) -> Type[Dict[str, Any]]:
+        return self.as_runnable().OutputType
+
+    def get_input_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self.as_runnable().get_input_schema(config)
+
+    def get_output_schema(
+        self, config: Optional[RunnableConfig] = None
+    ) -> Type[BaseModel]:
+        return self.as_runnable().get_output_schema(config)
+
+    @property
+    def config_specs(self) -> List[ConfigurableFieldSpec]:
+        return self.as_runnable_wrapped().config_specs
+
+    def invoke(
+        self,
+        input: Dict[str, Any],
+        config: Optional[RunnableConfig] = None,
+        return_only_outputs: bool = False,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        return self.as_runnable_wrapped(return_only_outputs).invoke(
+            input, config, **kwargs
+        )
+
+    async def ainvoke(
+        self,
+        input: Dict[str, Any],
+        config: Optional[RunnableConfig] = None,
+        return_only_outputs: bool = False,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        return await self.as_runnable_wrapped(return_only_outputs).ainvoke(
+            input, config, **kwargs
+        )
+
+    def stream(
+        self,
+        input: Dict[str, Any],
+        config: Optional[RunnableConfig] = None,
+        return_only_outputs: bool = False,
+        **kwargs: Any,
+    ) -> Iterator[Dict[str, Any]]:
+        yield from self.as_runnable_wrapped(return_only_outputs).stream(
+            input, config, **kwargs
+        )
+
+    async def astream(
+        self,
+        input: Dict[str, Any],
+        config: Optional[RunnableConfig] = None,
+        return_only_outputs: bool = False,
+        **kwargs: Any,
+    ) -> AsyncIterator[Dict[str, Any]]:
+        async for item in self.as_runnable_wrapped(return_only_outputs).astream(
+            input, config, **kwargs
+        ):
+            yield item
+
+    def batch(
+        self,
+        inputs: List[Dict[str, Any]],
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
+        return_only_outputs: bool = False,
+        **kwargs: Any,
+    ) -> List[Dict[str, Any]]:
+        return self.as_runnable_wrapped(return_only_outputs).batch(
+            inputs, config, **kwargs, return_exceptions=return_exceptions
+        )
+
+    async def abatch(
+        self,
+        inputs: List[Dict[str, Any]],
+        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        *,
+        return_exceptions: bool = False,
+        return_only_outputs: bool = False,
+        **kwargs: Any,
+    ) -> List[Dict[str, Any]]:
+        return await self.as_runnable_wrapped(return_only_outputs).abatch(
+            inputs, config, **kwargs, return_exceptions=return_exceptions
+        )
+
+    def transform(
+        self,
+        input: Iterator[Dict[str, Any]],
+        config: Optional[RunnableConfig] = None,
+        return_only_outputs: bool = False,
+        **kwargs: Any,
+    ) -> Iterator[Dict[str, Any]]:
+        yield from self.as_runnable_wrapped(return_only_outputs).transform(
+            input, config, **kwargs
+        )
+
+    async def atransform(
+        self,
+        input: AsyncIterator[Dict[str, Any]],
+        config: Optional[RunnableConfig] = None,
+        return_only_outputs: bool = False,
+        **kwargs: Any,
+    ) -> AsyncIterator[Dict[str, Any]]:
+        async for chunk in self.as_runnable_wrapped(return_only_outputs).atransform(
+            input, config, **kwargs
+        ):
+            yield chunk
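The pattern: a legacy Chain subclass implements as_runnable() once, and RunnableChain derives invoke/stream/batch and their async variants from the wrapped runnable, with prep_inputs/prep_outputs still applied around it. A hypothetical subclass sketched against this branch's API (ShoutChain and its keys are invented for illustration):

from typing import Any, Dict, List, Optional

from langchain_core.runnables import Runnable, RunnableLambda

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import RunnableChain  # exists only on this branch


class ShoutChain(RunnableChain):
    @property
    def input_keys(self) -> List[str]:
        return ["text"]

    @property
    def output_keys(self) -> List[str]:
        return ["shout"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        # Legacy execution path; the runnable path below supersedes it.
        return self.as_runnable().invoke(inputs)

    def as_runnable(self) -> Runnable:
        return RunnableLambda(lambda x: {"shout": x["text"].upper() + "!"})


# ShoutChain().invoke({"text": "hi"}) -> {"text": "hi", "shout": "HI!"}
# (inputs are merged back in because return_only_outputs defaults to False)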

View File

@@ -1,21 +1,48 @@
 """Base interface for chains combining documents."""
 from abc import ABC, abstractmethod
+from operator import itemgetter
 from typing import Any, Dict, List, Optional, Tuple, Type
 
 from langchain_core.documents import Document
+from langchain_core.prompts import BasePromptTemplate, format_document
 from langchain_core.pydantic_v1 import BaseModel, Field, create_model
+from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
 from langchain_core.runnables.config import RunnableConfig
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
 )
-from langchain.chains.base import Chain
+from langchain.chains.base import Chain, RunnableChain
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
 
 
-class BaseCombineDocumentsChain(Chain, ABC):
+def _format_docs_chain(
+    input_key: str,
+    document_prompt: BasePromptTemplate,
+    document_variable_name: str,
+    document_separator: str,
+) -> Runnable:
+    def format_document_(doc: Document) -> str:
+        return format_document(doc, document_prompt)
+
+    format_docs = (
+        itemgetter(input_key)
+        | RunnableLambda(format_document_).map()
+        | document_separator.join
+    )
+
+    def pop_raw_docs(input_: dict) -> dict:
+        return {k: v for k, v in input_.items() if k != input_key}
+
+    return (
+        RunnablePassthrough.assign(**{document_variable_name: format_docs})
+        | pop_raw_docs
+    )
+
+
+class BaseCombineDocumentsChain(RunnableChain, ABC):
     """Base interface for chains combining documents.
 
     Subclasses of this chain deal with combining documents in a variety of
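What _format_docs_chain does, concretely: it formats the documents under input_key into one string, assigns it to document_variable_name, and drops the raw documents from the dict. An illustrative invocation (the key names here are just the chain defaults):

from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate

docs_chain = _format_docs_chain(
    input_key="input_documents",
    document_prompt=PromptTemplate.from_template("{page_content}"),
    document_variable_name="context",
    document_separator="\n\n",
)
docs_chain.invoke(
    {
        "input_documents": [Document(page_content="a"), Document(page_content="b")],
        "question": "q",
    }
)
# -> {"question": "q", "context": "a\n\nb"}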

View File

@@ -2,12 +2,13 @@
 from __future__ import annotations
 
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Union
 
 from langchain_core.documents import Document
 from langchain_core.prompts import BasePromptTemplate, format_document
 from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.pydantic_v1 import Extra, Field, root_validator
+from langchain_core.runnables import Runnable, RunnablePassthrough
 
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import (
@@ -133,6 +134,44 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
         )
         return values
 
+    def as_runnable(self) -> Runnable:
+        def format_inputs(inputs: dict) -> dict:
+            doc = inputs[self.input_key][len(inputs.get("intermediate_steps", [])) - 1]
+            inputs[self.document_variable_name] = format_document(
+                doc, self.document_prompt
+            )
+            return {
+                k: v
+                for k, v in inputs.items()
+                if k not in ("intermediate_steps", self.input_key)
+            }
+
+        first_chain = format_inputs | self.initial_llm_chain.as_runnable()
+        refine_chain = format_inputs | self.refine_llm_chain.as_runnable()
+
+        def loop(inputs: dict) -> Union[Runnable, dict]:
+            if len(inputs.get("intermediate_steps", [])) < len(inputs[self.input_key]):
+                return (
+                    RunnablePassthrough.assign(
+                        intermediate_steps=lambda x: x.get("intermediate_steps", [])
+                        + [x[self.initial_response_name]]
+                    )
+                    | RunnablePassthrough.assign(
+                        **{self.initial_response_name: refine_chain}
+                    )
+                    | loop
+                )
+            else:
+                res = {self.output_key: inputs["intermediate_steps"][-1]}
+                if self.return_intermediate_steps:
+                    res["intermediate_steps"] = inputs["intermediate_steps"]
+                return res
+
+        return (
+            RunnablePassthrough.assign(**{self.initial_response_name: first_chain})
+            | loop
+        )
+
     def combine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
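The self-referential loop above relies on a documented RunnableLambda behavior: if the wrapped function returns a Runnable, that runnable is invoked on the same input (subject to the config's recursion_limit). A toy reduction using the same trick:

from langchain_core.runnables import RunnableLambda

def countdown(n: int):
    if n > 0:
        # Return another runnable; LCEL invokes it on the current value.
        return RunnableLambda(lambda x: x - 1) | countdown
    return "done"

assert RunnableLambda(countdown).invoke(3) == "done"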

View File

@@ -1,15 +1,19 @@
 """Chain that combines documents by stuffing into context."""
 from typing import Any, Dict, List, Optional, Tuple
 
 from langchain_core.documents import Document
 from langchain_core.prompts import BasePromptTemplate, format_document
 from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.pydantic_v1 import Extra, Field, root_validator
+from langchain_core.runnables import (
+    Runnable,
+    RunnableParallel,
+)
 
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import (
     BaseCombineDocumentsChain,
+    _format_docs_chain,
 )
 from langchain.chains.llm import LLMChain
@@ -100,6 +104,18 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
         )
         return values
 
+    def as_runnable(self) -> Runnable:
+        chain = (
+            _format_docs_chain(
+                self.input_key,
+                self.document_prompt,
+                self.document_variable_name,
+                self.document_separator,
+            )
+            | self.llm_chain.as_runnable()
+        )
+        return RunnableParallel({self.output_key: chain})
+
     @property
     def input_keys(self) -> List[str]:
         extra_keys = [
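A hedged usage sketch of the stuff chain's runnable form (FakeListLLM is the stub model shipped in langchain; as_runnable exists only on this branch, and the WIP output nesting may still change):

from langchain.chains import LLMChain, StuffDocumentsChain
from langchain.llms.fake import FakeListLLM
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate

llm_chain = LLMChain(
    llm=FakeListLLM(responses=["42"]),
    prompt=PromptTemplate.from_template("Answer from:\n{context}\n\nQ: {question}"),
)
stuff = StuffDocumentsChain(llm_chain=llm_chain, document_variable_name="context")
runnable = stuff.as_runnable()
# runnable.invoke({"input_documents": [Document(page_content="The answer is 42.")],
#                  "question": "What is the answer?"})
# -> a dict keyed by stuff.output_key ("output_text")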

View File

@@ -4,16 +4,24 @@ from __future__ import annotations
 import inspect
 import warnings
 from abc import abstractmethod
+from operator import itemgetter
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
+from langchain_core.beta.runnables.context import Context
 from langchain_core.documents import Document
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.messages import BaseMessage
 from langchain_core.prompts import BasePromptTemplate
 from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
 from langchain_core.retrievers import BaseRetriever
-from langchain_core.runnables.config import RunnableConfig
+from langchain_core.runnables import (
+    Runnable,
+    RunnableConfig,
+    RunnableLambda,
+    RunnableMap,
+)
+from langchain_core.runnables.passthrough import RunnablePassthrough
 from langchain_core.vectorstores import VectorStore
 
 from langchain.callbacks.manager import (
@@ -21,7 +29,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForChainRun,
     Callbacks,
 )
-from langchain.chains.base import Chain
+from langchain.chains.base import RunnableChain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
@@ -63,7 +71,7 @@ class InputType(BaseModel):
     """The chat history to use for retrieval."""
 
 
-class BaseConversationalRetrievalChain(Chain):
+class BaseConversationalRetrievalChain(RunnableChain):
     """Chain for chatting with an index."""
 
     combine_docs_chain: BaseCombineDocumentsChain
@@ -289,6 +297,56 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
     """If set, enforces that the documents returned are less than this limit.
     This is only enforced if `combine_docs_chain` is of type StuffDocumentsChain."""
 
+    def as_runnable(self) -> Runnable:
+        context = Context.create_scope("conversational_retrieval")
+        get_chat_history = self.get_chat_history or _get_chat_history
+
+        def get_new_question(inputs: Dict[str, Any]):
+            if inputs["chat_history"]:
+                return self.question_generator
+            else:
+                return inputs["question"]
+
+        def get_answer(inputs: Dict[str, Any]):
+            if (
+                self.response_if_no_docs_found is not None
+                and len(inputs["input_documents"]) == 0
+            ):
+                return self.response_if_no_docs_found
+            else:
+                return self.combine_docs_chain | itemgetter(
+                    self.combine_docs_chain.output_key
+                )
+
+        output_map = {self.output_key: RunnablePassthrough()}
+        if self.return_source_documents:
+            output_map["source_documents"] = context.getter("input_documents")
+        if self.return_generated_question:
+            output_map["generated_question"] = context.getter("new_question")
+
+        return (
+            RunnableMap(
+                question=itemgetter("question") | context.setter("question"),
+                chat_history=itemgetter("chat_history")
+                | RunnableLambda(get_chat_history)
+                | context.setter("chat_history"),
+            )
+            | get_new_question
+            | context.setter("new_question")
+            | self.retriever
+            | self._reduce_tokens_below_limit
+            | context.setter("input_documents")
+            | {
+                "input_documents": context.getter("input_documents"),
+                "chat_history": context.getter("chat_history"),
+                "question": context.getter(
+                    "new_question" if self.rephrase_question else "question"
+                ),
+            }
+            | get_answer
+            | output_map
+        )
+
     def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
         num_docs = len(docs)
@@ -400,6 +458,9 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
     top_k_docs_for_context: int = 4
     search_kwargs: dict = Field(default_factory=dict)
 
+    def as_runnable(self) -> Runnable:
+        return self
+
     @property
     def _chain_type(self) -> str:
         return "chat-vector-db"

View File

@@ -31,10 +31,10 @@ from langchain.callbacks.manager import (
     CallbackManagerForChainRun,
     Callbacks,
 )
-from langchain.chains.base import Chain
+from langchain.chains.base import RunnableChain
 
 
-class LLMChain(Chain):
+class LLMChain(RunnableChain):
     """Chain to run queries against LLMs.
 
     Example:
@@ -76,6 +76,13 @@ class LLMChain(Chain):
         extra = Extra.forbid
         arbitrary_types_allowed = True
 
+    def as_runnable(self) -> Runnable:
+        return (
+            self.prompt
+            | self.llm.bind(**self.llm_kwargs)
+            | {self.output_key: self.prompt.output_parser or self.output_parser}
+        )
+
     @property
     def input_keys(self) -> List[str]:
         """Will be whatever keys the prompt expects.

View File

@@ -27,8 +27,10 @@ async def atest_simple() -> None:
         verbose=True,
     )
     got = await qa_chain.acall("What is the answer?")
+    memory.clear()
     assert got["chat_history"][1].content == fixed_resp
     assert got["answer"] == fixed_resp
+    assert got == await qa_chain.ainvoke("What is the answer?")
 
 
 @pytest.mark.asyncio
@@ -37,7 +39,7 @@ async def atest_fixed_message_response_when_docs_found() -> None:
     answer = "I know the answer!"
     llm = FakeListLLM(responses=[answer])
     retriever = SequentialRetriever(
-        sequential_responses=[[Document(page_content=answer)]]
+        sequential_responses=[[Document(page_content=answer)]], cycle=True
     )
     memory = ConversationBufferMemory(
         k=1, output_key="answer", memory_key="chat_history", return_messages=True
@@ -52,8 +54,10 @@ async def atest_fixed_message_response_when_docs_found() -> None:
         verbose=True,
     )
     got = await qa_chain.acall("What is the answer?")
+    memory.clear()
     assert got["chat_history"][1].content == answer
     assert got["answer"] == answer
+    assert got == await qa_chain.ainvoke("What is the answer?")
 
 
 def test_fixed_message_response_when_no_docs_found() -> None:
@@ -74,8 +78,10 @@ def test_fixed_message_response_when_no_docs_found() -> None:
         verbose=True,
     )
     got = qa_chain("What is the answer?")
+    memory.clear()
     assert got["chat_history"][1].content == fixed_resp
     assert got["answer"] == fixed_resp
+    assert got == qa_chain.invoke("What is the answer?")
 
 
 def test_fixed_message_response_when_docs_found() -> None:
@@ -83,7 +89,7 @@ def test_fixed_message_response_when_docs_found() -> None:
     answer = "I know the answer!"
     llm = FakeListLLM(responses=[answer])
     retriever = SequentialRetriever(
-        sequential_responses=[[Document(page_content=answer)]]
+        sequential_responses=[[Document(page_content=answer)]], cycle=True
    )
     memory = ConversationBufferMemory(
         k=1, output_key="answer", memory_key="chat_history", return_messages=True
@@ -98,5 +104,7 @@ def test_fixed_message_response_when_docs_found() -> None:
         verbose=True,
     )
     got = qa_chain("What is the answer?")
+    memory.clear()
     assert got["chat_history"][1].content == answer
     assert got["answer"] == answer
+    assert got == qa_chain.invoke("What is the answer?")

View File

@@ -51,6 +51,7 @@ def test_valid_call(fake_llm_chain: LLMChain) -> None:
     """Test valid call of LLM chain."""
     output = fake_llm_chain({"bar": "baz"})
     assert output == {"bar": "baz", "text1": "foo"}
+    assert fake_llm_chain.invoke({"bar": "baz"}) == output
 
     # Test with stop words.
     output = fake_llm_chain({"bar": "baz", "stop": ["foo"]})
@@ -58,6 +59,18 @@ def test_valid_call(fake_llm_chain: LLMChain) -> None:
     assert output == {"bar": "baz", "stop": ["foo"], "text1": "bar"}
 
 
+async def test_valid_acall(fake_llm_chain: LLMChain) -> None:
+    """Test valid call of LLM chain."""
+    output = await fake_llm_chain.acall({"bar": "baz"})
+    assert output == {"bar": "baz", "text1": "foo"}
+    assert await fake_llm_chain.ainvoke({"bar": "baz"}) == output
+
+    # Test with stop words.
+    output = await fake_llm_chain.acall({"bar": "baz", "stop": ["foo"]})
+    # Response should be `bar` now.
+    assert output == {"bar": "baz", "stop": ["foo"], "text1": "bar"}
+
+
 def test_predict_method(fake_llm_chain: LLMChain) -> None:
     """Test predict method works."""
     output = fake_llm_chain.predict(bar="baz")
@@ -73,3 +86,4 @@ def test_predict_and_parse() -> None:
     chain = LLMChain(prompt=prompt, llm=llm)
     output = chain.predict_and_parse(foo="foo")
     assert output == ["foo", "bar"]
+    assert output == chain.invoke({"foo": "foo"})["text"]

View File

@@ -8,16 +8,20 @@ class SequentialRetriever(BaseRetriever):
     sequential_responses: List[List[Document]]
     response_index: int = 0
+    cycle: bool = False
 
     def _get_relevant_documents(  # type: ignore[override]
         self,
         query: str,
     ) -> List[Document]:
         if self.response_index >= len(self.sequential_responses):
-            return []
-        else:
-            self.response_index += 1
-            return self.sequential_responses[self.response_index - 1]
+            if self.cycle:
+                self.response_index = 0
+            else:
+                return []
+        self.response_index += 1
+        return self.sequential_responses[self.response_index - 1]
 
     async def _aget_relevant_documents(  # type: ignore[override]
         self,
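With cycle=True the scripted retriever replays its responses after exhausting them instead of returning [], which is what lets the updated tests call the chain a second time via invoke/ainvoke. Illustrative behavior, assuming the class above:

from langchain_core.documents import Document

r = SequentialRetriever(
    sequential_responses=[[Document(page_content="a")]], cycle=True
)
assert r._get_relevant_documents("q") == [Document(page_content="a")]
assert r._get_relevant_documents("q") == [Document(page_content="a")]  # cycled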