From 52e5a8b43e46977f0f6fdea1788488b7201fccc0 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 2 Oct 2023 10:07:30 +0100 Subject: [PATCH 1/9] Create new RunnableSerializable class in preparation for configurable runnables - Also move RunnableBranch to its own file --- libs/langchain/langchain/chains/base.py | 5 +- libs/langchain/langchain/llms/fake.py | 2 +- .../langchain/schema/language_model.py | 5 +- .../langchain/schema/output_parser.py | 11 +- .../langchain/schema/prompt_template.py | 5 +- libs/langchain/langchain/schema/retriever.py | 5 +- .../langchain/schema/runnable/__init__.py | 4 +- .../langchain/schema/runnable/_locals.py | 5 +- .../langchain/schema/runnable/base.py | 247 +++--------------- .../langchain/schema/runnable/branch.py | 234 +++++++++++++++++ .../langchain/schema/runnable/passthrough.py | 12 +- .../langchain/schema/runnable/router.py | 11 +- libs/langchain/langchain/tools/base.py | 5 +- .../langchain/tools/spark_sql/tool.py | 8 - .../langchain/tools/sql_database/tool.py | 8 - .../langchain/tools/vectorstore/tool.py | 5 - 16 files changed, 313 insertions(+), 259 deletions(-) create mode 100644 libs/langchain/langchain/schema/runnable/branch.py diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 94a9bdfd0aa..fd54cfe6bce 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -21,7 +21,6 @@ from langchain.callbacks.manager import ( Callbacks, ) from langchain.load.dump import dumpd -from langchain.load.serializable import Serializable from langchain.pydantic_v1 import ( BaseModel, Field, @@ -30,7 +29,7 @@ from langchain.pydantic_v1 import ( validator, ) from langchain.schema import RUN_KEY, BaseMemory, RunInfo -from langchain.schema.runnable import Runnable, RunnableConfig +from langchain.schema.runnable import RunnableConfig, RunnableSerializable logger = logging.getLogger(__name__) @@ -39,7 +38,7 @@ def _get_verbosity() -> bool: return langchain.verbose -class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): +class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): """Abstract base class for creating structured sequences of calls to components. Chains should be used to encode a sequence of calls to components like diff --git a/libs/langchain/langchain/llms/fake.py b/libs/langchain/langchain/llms/fake.py index b3c22eb644d..cb3ea2792ff 100644 --- a/libs/langchain/langchain/llms/fake.py +++ b/libs/langchain/langchain/llms/fake.py @@ -14,7 +14,7 @@ from langchain.schema.runnable import RunnableConfig class FakeListLLM(LLM): """Fake LLM for testing purposes.""" - responses: List + responses: List[str] sleep: Optional[float] = None i: int = 0 diff --git a/libs/langchain/langchain/schema/language_model.py b/libs/langchain/langchain/schema/language_model.py index 16e8edbc9cd..c4e8e5169de 100644 --- a/libs/langchain/langchain/schema/language_model.py +++ b/libs/langchain/langchain/schema/language_model.py @@ -15,11 +15,10 @@ from typing import ( from typing_extensions import TypeAlias -from langchain.load.serializable import Serializable from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string from langchain.schema.output import LLMResult from langchain.schema.prompt import PromptValue -from langchain.schema.runnable import Runnable +from langchain.schema.runnable import RunnableSerializable from langchain.utils import get_pydantic_field_names if TYPE_CHECKING: @@ -54,7 +53,7 @@ LanguageModelOutput = TypeVar("LanguageModelOutput") class BaseLanguageModel( - Serializable, Runnable[LanguageModelInput, LanguageModelOutput], ABC + RunnableSerializable[LanguageModelInput, LanguageModelOutput], ABC ): """Abstract base class for interfacing with language models. diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index 157c1cd5f0f..c675dfe49b9 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -16,7 +16,6 @@ from typing import ( from typing_extensions import get_args -from langchain.load.serializable import Serializable from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk from langchain.schema.output import ( ChatGeneration, @@ -25,12 +24,12 @@ from langchain.schema.output import ( GenerationChunk, ) from langchain.schema.prompt import PromptValue -from langchain.schema.runnable import Runnable, RunnableConfig +from langchain.schema.runnable import RunnableConfig, RunnableSerializable T = TypeVar("T") -class BaseLLMOutputParser(Serializable, Generic[T], ABC): +class BaseLLMOutputParser(Generic[T], ABC): """Abstract base class for parsing the outputs of a model.""" @abstractmethod @@ -63,7 +62,7 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC): class BaseGenerationOutputParser( - BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T] + BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T] ): """Base class to parse the output of an LLM call.""" @@ -121,7 +120,9 @@ class BaseGenerationOutputParser( ) -class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]): +class BaseOutputParser( + BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T] +): """Base class to parse the output of an LLM call. Output parsers help structure language model responses. diff --git a/libs/langchain/langchain/schema/prompt_template.py b/libs/langchain/langchain/schema/prompt_template.py index ab790753aaf..b72e2fe55ec 100644 --- a/libs/langchain/langchain/schema/prompt_template.py +++ b/libs/langchain/langchain/schema/prompt_template.py @@ -7,15 +7,14 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Union import yaml -from langchain.load.serializable import Serializable from langchain.pydantic_v1 import BaseModel, Field, create_model, root_validator from langchain.schema.document import Document from langchain.schema.output_parser import BaseOutputParser from langchain.schema.prompt import PromptValue -from langchain.schema.runnable import Runnable, RunnableConfig +from langchain.schema.runnable import RunnableConfig, RunnableSerializable -class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC): +class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC): """Base class for all prompt templates, returning a prompt.""" input_variables: List[str] diff --git a/libs/langchain/langchain/schema/retriever.py b/libs/langchain/langchain/schema/retriever.py index 04c28354342..25934eb3edd 100644 --- a/libs/langchain/langchain/schema/retriever.py +++ b/libs/langchain/langchain/schema/retriever.py @@ -6,9 +6,8 @@ from inspect import signature from typing import TYPE_CHECKING, Any, Dict, List, Optional from langchain.load.dump import dumpd -from langchain.load.serializable import Serializable from langchain.schema.document import Document -from langchain.schema.runnable import Runnable, RunnableConfig +from langchain.schema.runnable import RunnableConfig, RunnableSerializable if TYPE_CHECKING: from langchain.callbacks.manager import ( @@ -18,7 +17,7 @@ if TYPE_CHECKING: ) -class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): +class BaseRetriever(RunnableSerializable[str, List[Document]], ABC): """Abstract base class for a Document retrieval system. A retrieval system is defined as something that can take string queries and return diff --git a/libs/langchain/langchain/schema/runnable/__init__.py b/libs/langchain/langchain/schema/runnable/__init__.py index 2b068d5ebaa..bde6121bed0 100644 --- a/libs/langchain/langchain/schema/runnable/__init__.py +++ b/libs/langchain/langchain/schema/runnable/__init__.py @@ -2,12 +2,13 @@ from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar from langchain.schema.runnable.base import ( Runnable, RunnableBinding, - RunnableBranch, RunnableLambda, RunnableMap, RunnableSequence, + RunnableSerializable, RunnableWithFallbacks, ) +from langchain.schema.runnable.branch import RunnableBranch from langchain.schema.runnable.config import RunnableConfig, patch_config from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.router import RouterInput, RouterRunnable @@ -19,6 +20,7 @@ __all__ = [ "RouterInput", "RouterRunnable", "Runnable", + "RunnableSerializable", "RunnableBinding", "RunnableBranch", "RunnableConfig", diff --git a/libs/langchain/langchain/schema/runnable/_locals.py b/libs/langchain/langchain/schema/runnable/_locals.py index 839bb5fbc82..e2fe8541148 100644 --- a/libs/langchain/langchain/schema/runnable/_locals.py +++ b/libs/langchain/langchain/schema/runnable/_locals.py @@ -11,8 +11,7 @@ from typing import ( Union, ) -from langchain.load.serializable import Serializable -from langchain.schema.runnable.base import Input, Output, Runnable +from langchain.schema.runnable.base import Input, Output, RunnableSerializable from langchain.schema.runnable.config import RunnableConfig from langchain.schema.runnable.passthrough import RunnablePassthrough @@ -104,7 +103,7 @@ class PutLocalVar(RunnablePassthrough): class GetLocalVar( - Serializable, Runnable[Input, Union[Output, Dict[str, Union[Input, Output]]]] + RunnableSerializable[Input, Union[Output, Dict[str, Union[Input, Output]]]] ): key: str """The key to extract from the local state.""" diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 7286333cf41..35939b128c7 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -119,6 +119,24 @@ class Runnable(Generic[Input, Output], ABC): self.__class__.__name__ + "Output", __root__=(root_type, None) ) + def config_schema( + self, *, include: Optional[Sequence[str]] = None + ) -> Type[BaseModel]: + class _Config: + arbitrary_types_allowed = True + + include = include or [] + + return create_model( # type: ignore[call-overload] + self.__class__.__name__ + "Config", + __config__=_Config, + **{ + field_name: (field_type, None) + for field_name, field_type in RunnableConfig.__annotations__.items() + if field_name in include + }, + ) + def __or__( self, other: Union[ @@ -812,209 +830,11 @@ class Runnable(Generic[Input, Output], ABC): await run_manager.on_chain_end(final_output, inputs=final_input) -class RunnableBranch(Serializable, Runnable[Input, Output]): - """A Runnable that selects which branch to run based on a condition. - - The runnable is initialized with a list of (condition, runnable) pairs and - a default branch. - - When operating on an input, the first condition that evaluates to True is - selected, and the corresponding runnable is run on the input. - - If no condition evaluates to True, the default branch is run on the input. - - Examples: - - .. code-block:: python - - from langchain.schema.runnable import RunnableBranch - - branch = RunnableBranch( - (lambda x: isinstance(x, str), lambda x: x.upper()), - (lambda x: isinstance(x, int), lambda x: x + 1), - (lambda x: isinstance(x, float), lambda x: x * 2), - lambda x: "goodbye", - ) - - branch.invoke("hello") # "HELLO" - branch.invoke(None) # "goodbye" - """ - - branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]] - default: Runnable[Input, Output] - - def __init__( - self, - *branches: Union[ - Tuple[ - Union[ - Runnable[Input, bool], - Callable[[Input], bool], - Callable[[Input], Awaitable[bool]], - ], - RunnableLike, - ], - RunnableLike, # To accommodate the default branch - ], - ) -> None: - """A Runnable that runs one of two branches based on a condition.""" - if len(branches) < 2: - raise ValueError("RunnableBranch requires at least two branches") - - default = branches[-1] - - if not isinstance( - default, (Runnable, Callable, Mapping) # type: ignore[arg-type] - ): - raise TypeError( - "RunnableBranch default must be runnable, callable or mapping." - ) - - default_ = cast( - Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default)) - ) - - _branches = [] - - for branch in branches[:-1]: - if not isinstance(branch, (tuple, list)): # type: ignore[arg-type] - raise TypeError( - f"RunnableBranch branches must be " - f"tuples or lists, not {type(branch)}" - ) - - if not len(branch) == 2: - raise ValueError( - f"RunnableBranch branches must be " - f"tuples or lists of length 2, not {len(branch)}" - ) - condition, runnable = branch - condition = cast(Runnable[Input, bool], coerce_to_runnable(condition)) - runnable = coerce_to_runnable(runnable) - _branches.append((condition, runnable)) - - super().__init__(branches=_branches, default=default_) - - class Config: - arbitrary_types_allowed = True - - @classmethod - def is_lc_serializable(cls) -> bool: - """RunnableBranch is serializable if all its branches are serializable.""" - return True - - @classmethod - def get_lc_namespace(cls) -> List[str]: - """The namespace of a RunnableBranch is the namespace of its default branch.""" - return cls.__module__.split(".")[:-1] - - @property - def input_schema(self) -> type[BaseModel]: - runnables = ( - [self.default] - + [r for _, r in self.branches] - + [r for r, _ in self.branches] - ) - - for runnable in runnables: - if runnable.input_schema.schema().get("type") is not None: - return runnable.input_schema - - return super().input_schema - - def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: - """First evaluates the condition, then delegate to true or false branch.""" - config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) - run_manager = callback_manager.on_chain_start( - dumpd(self), - input, - name=config.get("run_name"), - ) - - try: - for idx, branch in enumerate(self.branches): - condition, runnable = branch - - expression_value = condition.invoke( - input, - config=patch_config( - config, - callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"), - ), - ) - - if expression_value: - output = runnable.invoke( - input, - config=patch_config( - config, - callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"), - ), - ) - break - else: - output = self.default.invoke( - input, - config=patch_config( - config, callbacks=run_manager.get_child(tag="branch:default") - ), - ) - except Exception as e: - run_manager.on_chain_error(e) - raise - run_manager.on_chain_end(dumpd(output)) - return output - - async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Output: - """Async version of invoke.""" - config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) - run_manager = callback_manager.on_chain_start( - dumpd(self), - input, - name=config.get("run_name"), - ) - try: - for idx, branch in enumerate(self.branches): - condition, runnable = branch - - expression_value = await condition.ainvoke( - input, - config=patch_config( - config, - callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"), - ), - ) - - if expression_value: - output = await runnable.ainvoke( - input, - config=patch_config( - config, - callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"), - ), - **kwargs, - ) - break - else: - output = await self.default.ainvoke( - input, - config=patch_config( - config, callbacks=run_manager.get_child(tag="branch:default") - ), - **kwargs, - ) - except Exception as e: - run_manager.on_chain_error(e) - raise - run_manager.on_chain_end(dumpd(output)) - return output +class RunnableSerializable(Serializable, Runnable[Input, Output]): + pass -class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): +class RunnableWithFallbacks(RunnableSerializable[Input, Output]): """ A Runnable that can fallback to other Runnables if it fails. """ @@ -1042,6 +862,11 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): def output_schema(self) -> Type[BaseModel]: return self.runnable.output_schema + def config_schema( + self, *, include: Optional[Sequence[str]] = None + ) -> Type[BaseModel]: + return self.runnable.config_schema(include=include) + @classmethod def is_lc_serializable(cls) -> bool: return True @@ -1267,7 +1092,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): raise first_error -class RunnableSequence(Serializable, Runnable[Input, Output]): +class RunnableSequence(RunnableSerializable[Input, Output]): """ A sequence of runnables, where the output of each is the input of the next. """ @@ -1749,7 +1574,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): yield chunk -class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): +class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]): """ A runnable that runs a mapping of runnables in parallel, and returns a mapping of their outputs. @@ -1799,7 +1624,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): return create_model( # type: ignore[call-overload] "RunnableMapInput", **{ - k: (v.type_, v.default) + k: (v.annotation, v.default) for step in self.steps.values() for k, v in step.input_schema.__fields__.items() if k != "__root__" @@ -2374,7 +2199,7 @@ class RunnableLambda(Runnable[Input, Output]): return await super().ainvoke(input, config) -class RunnableEach(Serializable, Runnable[List[Input], List[Output]]): +class RunnableEach(RunnableSerializable[List[Input], List[Output]]): """ A runnable that delegates calls to another runnable with each element of the input sequence. @@ -2413,6 +2238,11 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]): ), ) + def config_schema( + self, *, include: Optional[Sequence[str]] = None + ) -> Type[BaseModel]: + return self.bound.config_schema(include=include) + @classmethod def is_lc_serializable(cls) -> bool: return True @@ -2455,7 +2285,7 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]): return await self._acall_with_config(self._ainvoke, input, config) -class RunnableBinding(Serializable, Runnable[Input, Output]): +class RunnableBinding(RunnableSerializable[Input, Output]): """ A runnable that delegates calls to another runnable with a set of kwargs. """ @@ -2485,6 +2315,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): def output_schema(self) -> Type[BaseModel]: return self.bound.output_schema + def config_schema( + self, *, include: Optional[Sequence[str]] = None + ) -> Type[BaseModel]: + return self.bound.config_schema(include=include) + @classmethod def is_lc_serializable(cls) -> bool: return True diff --git a/libs/langchain/langchain/schema/runnable/branch.py b/libs/langchain/langchain/schema/runnable/branch.py new file mode 100644 index 00000000000..d609fedeffc --- /dev/null +++ b/libs/langchain/langchain/schema/runnable/branch.py @@ -0,0 +1,234 @@ +from typing import ( + Any, + Awaitable, + Callable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, + cast, +) + +from langchain.load.dump import dumpd +from langchain.pydantic_v1 import BaseModel +from langchain.schema.runnable.base import ( + Runnable, + RunnableLike, + RunnableSerializable, + coerce_to_runnable, +) +from langchain.schema.runnable.config import ( + RunnableConfig, + ensure_config, + get_callback_manager_for_config, + patch_config, +) +from langchain.schema.runnable.utils import Input, Output + + +class RunnableBranch(RunnableSerializable[Input, Output]): + """A Runnable that selects which branch to run based on a condition. + + The runnable is initialized with a list of (condition, runnable) pairs and + a default branch. + + When operating on an input, the first condition that evaluates to True is + selected, and the corresponding runnable is run on the input. + + If no condition evaluates to True, the default branch is run on the input. + + Examples: + + .. code-block:: python + + from langchain.schema.runnable import RunnableBranch + + branch = RunnableBranch( + (lambda x: isinstance(x, str), lambda x: x.upper()), + (lambda x: isinstance(x, int), lambda x: x + 1), + (lambda x: isinstance(x, float), lambda x: x * 2), + lambda x: "goodbye", + ) + + branch.invoke("hello") # "HELLO" + branch.invoke(None) # "goodbye" + """ + + branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]] + default: Runnable[Input, Output] + + def __init__( + self, + *branches: Union[ + Tuple[ + Union[ + Runnable[Input, bool], + Callable[[Input], bool], + Callable[[Input], Awaitable[bool]], + ], + RunnableLike, + ], + RunnableLike, # To accommodate the default branch + ], + ) -> None: + """A Runnable that runs one of two branches based on a condition.""" + if len(branches) < 2: + raise ValueError("RunnableBranch requires at least two branches") + + default = branches[-1] + + if not isinstance( + default, (Runnable, Callable, Mapping) # type: ignore[arg-type] + ): + raise TypeError( + "RunnableBranch default must be runnable, callable or mapping." + ) + + default_ = cast( + Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default)) + ) + + _branches = [] + + for branch in branches[:-1]: + if not isinstance(branch, (tuple, list)): # type: ignore[arg-type] + raise TypeError( + f"RunnableBranch branches must be " + f"tuples or lists, not {type(branch)}" + ) + + if not len(branch) == 2: + raise ValueError( + f"RunnableBranch branches must be " + f"tuples or lists of length 2, not {len(branch)}" + ) + condition, runnable = branch + condition = cast(Runnable[Input, bool], coerce_to_runnable(condition)) + runnable = coerce_to_runnable(runnable) + _branches.append((condition, runnable)) + + super().__init__(branches=_branches, default=default_) + + class Config: + arbitrary_types_allowed = True + + @classmethod + def is_lc_serializable(cls) -> bool: + """RunnableBranch is serializable if all its branches are serializable.""" + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + """The namespace of a RunnableBranch is the namespace of its default branch.""" + return cls.__module__.split(".")[:-1] + + @property + def input_schema(self) -> type[BaseModel]: + runnables = ( + [self.default] + + [r for _, r in self.branches] + + [r for r, _ in self.branches] + ) + + for runnable in runnables: + if runnable.input_schema.schema().get("type") is not None: + return runnable.input_schema + + return super().input_schema + + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + """First evaluates the condition, then delegate to true or false branch.""" + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) + run_manager = callback_manager.on_chain_start( + dumpd(self), + input, + name=config.get("run_name"), + ) + + try: + for idx, branch in enumerate(self.branches): + condition, runnable = branch + + expression_value = condition.invoke( + input, + config=patch_config( + config, + callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"), + ), + ) + + if expression_value: + output = runnable.invoke( + input, + config=patch_config( + config, + callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"), + ), + **kwargs, + ) + break + else: + output = self.default.invoke( + input, + config=patch_config( + config, callbacks=run_manager.get_child(tag="branch:default") + ), + **kwargs, + ) + except Exception as e: + run_manager.on_chain_error(e) + raise + run_manager.on_chain_end(dumpd(output)) + return output + + async def ainvoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + """Async version of invoke.""" + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) + run_manager = callback_manager.on_chain_start( + dumpd(self), + input, + name=config.get("run_name"), + ) + try: + for idx, branch in enumerate(self.branches): + condition, runnable = branch + + expression_value = await condition.ainvoke( + input, + config=patch_config( + config, + callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"), + ), + ) + + if expression_value: + output = await runnable.ainvoke( + input, + config=patch_config( + config, + callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"), + ), + **kwargs, + ) + break + else: + output = await self.default.ainvoke( + input, + config=patch_config( + config, callbacks=run_manager.get_child(tag="branch:default") + ), + **kwargs, + ) + except Exception as e: + run_manager.on_chain_error(e) + raise + run_manager.on_chain_end(dumpd(output)) + return output diff --git a/libs/langchain/langchain/schema/runnable/passthrough.py b/libs/langchain/langchain/schema/runnable/passthrough.py index 18afe82591d..1d1b046a572 100644 --- a/libs/langchain/langchain/schema/runnable/passthrough.py +++ b/libs/langchain/langchain/schema/runnable/passthrough.py @@ -16,9 +16,13 @@ from typing import ( cast, ) -from langchain.load.serializable import Serializable from langchain.pydantic_v1 import BaseModel, create_model -from langchain.schema.runnable.base import Input, Runnable, RunnableMap +from langchain.schema.runnable.base import ( + Input, + Runnable, + RunnableMap, + RunnableSerializable, +) from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config from langchain.schema.runnable.utils import AddableDict from langchain.utils.aiter import atee, py_anext @@ -33,7 +37,7 @@ async def aidentity(x: Input) -> Input: return x -class RunnablePassthrough(Serializable, Runnable[Input, Input]): +class RunnablePassthrough(RunnableSerializable[Input, Input]): """ A runnable that passes through the input. """ @@ -109,7 +113,7 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]): yield chunk -class RunnableAssign(Serializable, Runnable[Dict[str, Any], Dict[str, Any]]): +class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): """ A runnable that assigns key-value pairs to Dict[str, Any] inputs. """ diff --git a/libs/langchain/langchain/schema/runnable/router.py b/libs/langchain/langchain/schema/runnable/router.py index f697c0328c9..9638235fc87 100644 --- a/libs/langchain/langchain/schema/runnable/router.py +++ b/libs/langchain/langchain/schema/runnable/router.py @@ -14,8 +14,13 @@ from typing import ( from typing_extensions import TypedDict -from langchain.load.serializable import Serializable -from langchain.schema.runnable.base import Input, Output, Runnable, coerce_to_runnable +from langchain.schema.runnable.base import ( + Input, + Output, + Runnable, + RunnableSerializable, + coerce_to_runnable, +) from langchain.schema.runnable.config import ( RunnableConfig, get_config_list, @@ -36,7 +41,7 @@ class RouterInput(TypedDict): input: Any -class RouterRunnable(Serializable, Runnable[RouterInput, Output]): +class RouterRunnable(RunnableSerializable[RouterInput, Output]): """ A runnable that routes to a set of runnables based on Input['key']. Returns the output of the selected runnable. diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index 269e2b4846a..2ae81d246bc 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -25,7 +25,7 @@ from langchain.pydantic_v1 import ( root_validator, validate_arguments, ) -from langchain.schema.runnable import Runnable, RunnableConfig +from langchain.schema.runnable import Runnable, RunnableConfig, RunnableSerializable class SchemaAnnotationError(TypeError): @@ -97,7 +97,7 @@ class ToolException(Exception): pass -class BaseTool(BaseModel, Runnable[Union[str, Dict], Any]): +class BaseTool(RunnableSerializable[Union[str, Dict], Any]): """Interface LangChain tools must implement.""" def __init_subclass__(cls, **kwargs: Any) -> None: @@ -168,7 +168,6 @@ class ChildTool(BaseTool): class Config: """Configuration for this pydantic object.""" - extra = Extra.forbid arbitrary_types_allowed = True @property diff --git a/libs/langchain/langchain/tools/spark_sql/tool.py b/libs/langchain/langchain/tools/spark_sql/tool.py index 4a650f7d518..c79bfd193ad 100644 --- a/libs/langchain/langchain/tools/spark_sql/tool.py +++ b/libs/langchain/langchain/tools/spark_sql/tool.py @@ -21,14 +21,6 @@ class BaseSparkSQLTool(BaseModel): db: SparkSQL = Field(exclude=True) - # Override BaseTool.Config to appease mypy - # See https://github.com/pydantic/pydantic/issues/4173 - class Config(BaseTool.Config): - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - extra = Extra.forbid - class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool): """Tool for querying a Spark SQL.""" diff --git a/libs/langchain/langchain/tools/sql_database/tool.py b/libs/langchain/langchain/tools/sql_database/tool.py index 75f45c7b9e7..eba921c163c 100644 --- a/libs/langchain/langchain/tools/sql_database/tool.py +++ b/libs/langchain/langchain/tools/sql_database/tool.py @@ -21,14 +21,6 @@ class BaseSQLDatabaseTool(BaseModel): db: SQLDatabase = Field(exclude=True) - # Override BaseTool.Config to appease mypy - # See https://github.com/pydantic/pydantic/issues/4173 - class Config(BaseTool.Config): - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - extra = Extra.forbid - class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): """Tool for querying a SQL database.""" diff --git a/libs/langchain/langchain/tools/vectorstore/tool.py b/libs/langchain/langchain/tools/vectorstore/tool.py index a0507964e78..c62145c4c90 100644 --- a/libs/langchain/langchain/tools/vectorstore/tool.py +++ b/libs/langchain/langchain/tools/vectorstore/tool.py @@ -17,11 +17,6 @@ class BaseVectorStoreTool(BaseModel): vectorstore: VectorStore = Field(exclude=True) llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0)) - class Config(BaseTool.Config): - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - def _create_description_from_template(values: Dict[str, Any]) -> Dict[str, Any]: values["description"] = values["template"].format(name=values["name"]) From 040bb2983d72eb6c309d608cc02c7982886362fd Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 2 Oct 2023 10:11:26 +0100 Subject: [PATCH 2/9] Lint --- libs/langchain/langchain/tools/base.py | 3 ++- libs/langchain/langchain/tools/spark_sql/tool.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index 2ae81d246bc..e974070ece9 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -17,6 +17,7 @@ from langchain.callbacks.manager import ( CallbackManagerForToolRun, Callbacks, ) +from langchain.load.serializable import Serializable from langchain.pydantic_v1 import ( BaseModel, Extra, @@ -165,7 +166,7 @@ class ChildTool(BaseTool): ] = False """Handle the content of the ToolException thrown.""" - class Config: + class Config(Serializable.Config): """Configuration for this pydantic object.""" arbitrary_types_allowed = True diff --git a/libs/langchain/langchain/tools/spark_sql/tool.py b/libs/langchain/langchain/tools/spark_sql/tool.py index c79bfd193ad..ad1c6f5ba15 100644 --- a/libs/langchain/langchain/tools/spark_sql/tool.py +++ b/libs/langchain/langchain/tools/spark_sql/tool.py @@ -2,7 +2,7 @@ """Tools for interacting with Spark SQL.""" from typing import Any, Dict, Optional -from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator +from langchain.pydantic_v1 import BaseModel, Field, root_validator from langchain.schema.language_model import BaseLanguageModel from langchain.callbacks.manager import ( @@ -21,6 +21,9 @@ class BaseSparkSQLTool(BaseModel): db: SparkSQL = Field(exclude=True) + class Config(BaseTool.Config): + pass + class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool): """Tool for querying a Spark SQL.""" From f7dd10b820737b9f461f5947a1a2920ba9dc012a Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 2 Oct 2023 10:13:09 +0100 Subject: [PATCH 3/9] Lint --- libs/langchain/langchain/tools/sql_database/tool.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/libs/langchain/langchain/tools/sql_database/tool.py b/libs/langchain/langchain/tools/sql_database/tool.py index eba921c163c..99289e4a97c 100644 --- a/libs/langchain/langchain/tools/sql_database/tool.py +++ b/libs/langchain/langchain/tools/sql_database/tool.py @@ -21,6 +21,9 @@ class BaseSQLDatabaseTool(BaseModel): db: SQLDatabase = Field(exclude=True) + class Config(BaseTool.Config): + pass + class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): """Tool for querying a SQL database.""" From a6afd45c63e4bcb11e8e33fdc92b5d46765d96df Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 2 Oct 2023 10:14:56 +0100 Subject: [PATCH 4/9] Lint --- libs/langchain/langchain/schema/runnable/branch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/schema/runnable/branch.py b/libs/langchain/langchain/schema/runnable/branch.py index d609fedeffc..105582f0c2f 100644 --- a/libs/langchain/langchain/schema/runnable/branch.py +++ b/libs/langchain/langchain/schema/runnable/branch.py @@ -7,6 +7,7 @@ from typing import ( Optional, Sequence, Tuple, + Type, Union, cast, ) @@ -125,7 +126,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): return cls.__module__.split(".")[:-1] @property - def input_schema(self) -> type[BaseModel]: + def input_schema(self) -> Type[BaseModel]: runnables = ( [self.default] + [r for _, r in self.branches] From 01dbfc2bc719a3ad75e693ec6ab4dac7bc1a0566 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 2 Oct 2023 10:21:40 +0100 Subject: [PATCH 5/9] Lint --- libs/langchain/langchain/tools/vectorstore/tool.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/libs/langchain/langchain/tools/vectorstore/tool.py b/libs/langchain/langchain/tools/vectorstore/tool.py index c62145c4c90..1504bdce834 100644 --- a/libs/langchain/langchain/tools/vectorstore/tool.py +++ b/libs/langchain/langchain/tools/vectorstore/tool.py @@ -17,6 +17,9 @@ class BaseVectorStoreTool(BaseModel): vectorstore: VectorStore = Field(exclude=True) llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0)) + class Config(BaseTool.Config): + pass + def _create_description_from_template(values: Dict[str, Any]) -> Dict[str, Any]: values["description"] = values["template"].format(name=values["name"]) From a3b82d18319fce2c55aeee959c782cedadc2440a Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 2 Oct 2023 10:26:10 +0100 Subject: [PATCH 6/9] Move RunnableWithFallbacks to its own file --- .../langchain/schema/runnable/base.py | 7 +- .../langchain/schema/runnable/fallbacks.py | 277 ++++++++++++++++++ 2 files changed, 283 insertions(+), 1 deletion(-) create mode 100644 libs/langchain/langchain/schema/runnable/fallbacks.py diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 35939b128c7..384338b1921 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -36,6 +36,9 @@ if TYPE_CHECKING: CallbackManagerForChainRun, ) from langchain.callbacks.tracers.log_stream import RunLogPatch + from langchain.schema.runnable.fallbacks import ( + RunnableWithFallbacks as RunnableWithFallbacksT, + ) from langchain.load.dump import dumpd @@ -455,7 +458,9 @@ class Runnable(Generic[Input, Output], ABC): fallbacks: Sequence[Runnable[Input, Output]], *, exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,), - ) -> RunnableWithFallbacks[Input, Output]: + ) -> RunnableWithFallbacksT[Input, Output]: + from langchain.schema.runnable.fallbacks import RunnableWithFallbacks + return RunnableWithFallbacks( runnable=self, fallbacks=fallbacks, diff --git a/libs/langchain/langchain/schema/runnable/fallbacks.py b/libs/langchain/langchain/schema/runnable/fallbacks.py new file mode 100644 index 00000000000..958c8119969 --- /dev/null +++ b/libs/langchain/langchain/schema/runnable/fallbacks.py @@ -0,0 +1,277 @@ +import asyncio +from ctypes import Union +from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Sequence, Tuple, Type +from langchain.load.dump import dumpd + +from langchain.pydantic_v1 import BaseModel +from langchain.schema.runnable.base import Runnable, RunnableSerializable +from langchain.schema.runnable.config import ( + RunnableConfig, + ensure_config, + get_async_callback_manager_for_config, + get_callback_manager_for_config, + get_config_list, + patch_config, +) +from langchain.schema.runnable.utils import Input, Output + +if TYPE_CHECKING: + from langchain.callbacks.manager import AsyncCallbackManagerForChainRun + + +class RunnableWithFallbacks(RunnableSerializable[Input, Output]): + """ + A Runnable that can fallback to other Runnables if it fails. + """ + + runnable: Runnable[Input, Output] + fallbacks: Sequence[Runnable[Input, Output]] + exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,) + + class Config: + arbitrary_types_allowed = True + + @property + def InputType(self) -> Type[Input]: + return self.runnable.InputType + + @property + def OutputType(self) -> Type[Output]: + return self.runnable.OutputType + + @property + def input_schema(self) -> Type[BaseModel]: + return self.runnable.input_schema + + @property + def output_schema(self) -> Type[BaseModel]: + return self.runnable.output_schema + + def config_schema( + self, *, include: Optional[Sequence[str]] = None + ) -> Type[BaseModel]: + return self.runnable.config_schema(include=include) + + @classmethod + def is_lc_serializable(cls) -> bool: + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + return cls.__module__.split(".")[:-1] + + @property + def runnables(self) -> Iterator[Runnable[Input, Output]]: + yield self.runnable + yield from self.fallbacks + + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + # setup callbacks + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) + # start the root run + run_manager = callback_manager.on_chain_start( + dumpd(self), input, name=config.get("run_name") + ) + first_error = None + for runnable in self.runnables: + try: + output = runnable.invoke( + input, + patch_config(config, callbacks=run_manager.get_child()), + **kwargs, + ) + except self.exceptions_to_handle as e: + if first_error is None: + first_error = e + except BaseException as e: + run_manager.on_chain_error(e) + raise e + else: + run_manager.on_chain_end(output) + return output + if first_error is None: + raise ValueError("No error stored at end of fallbacks.") + run_manager.on_chain_error(first_error) + raise first_error + + async def ainvoke( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Output: + # setup callbacks + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) + # start the root run + run_manager = await callback_manager.on_chain_start( + dumpd(self), input, name=config.get("run_name") + ) + + first_error = None + for runnable in self.runnables: + try: + output = await runnable.ainvoke( + input, + patch_config(config, callbacks=run_manager.get_child()), + **kwargs, + ) + except self.exceptions_to_handle as e: + if first_error is None: + first_error = e + except BaseException as e: + await run_manager.on_chain_error(e) + raise e + else: + await run_manager.on_chain_end(output) + return output + if first_error is None: + raise ValueError("No error stored at end of fallbacks.") + await run_manager.on_chain_error(first_error) + raise first_error + + def batch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> List[Output]: + from langchain.callbacks.manager import CallbackManager + + if return_exceptions: + raise NotImplementedError() + + if not inputs: + return [] + + # setup callbacks + configs = get_config_list(config, len(inputs)) + callback_managers = [ + CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + for config in configs + ] + # start the root runs, one per input + run_managers = [ + cm.on_chain_start( + dumpd(self), + input if isinstance(input, dict) else {"input": input}, + name=config.get("run_name"), + ) + for cm, input, config in zip(callback_managers, inputs, configs) + ] + + first_error = None + for runnable in self.runnables: + try: + outputs = runnable.batch( + inputs, + [ + # each step a child run of the corresponding root run + patch_config(config, callbacks=rm.get_child()) + for rm, config in zip(run_managers, configs) + ], + return_exceptions=return_exceptions, + **kwargs, + ) + except self.exceptions_to_handle as e: + if first_error is None: + first_error = e + except BaseException as e: + for rm in run_managers: + rm.on_chain_error(e) + raise e + else: + for rm, output in zip(run_managers, outputs): + rm.on_chain_end(output) + return outputs + if first_error is None: + raise ValueError("No error stored at end of fallbacks.") + for rm in run_managers: + rm.on_chain_error(first_error) + raise first_error + + async def abatch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> List[Output]: + from langchain.callbacks.manager import AsyncCallbackManager + + if return_exceptions: + raise NotImplementedError() + + if not inputs: + return [] + + # setup callbacks + configs = get_config_list(config, len(inputs)) + callback_managers = [ + AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) + for config in configs + ] + # start the root runs, one per input + run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( + *( + cm.on_chain_start( + dumpd(self), + input, + name=config.get("run_name"), + ) + for cm, input, config in zip(callback_managers, inputs, configs) + ) + ) + + first_error = None + for runnable in self.runnables: + try: + outputs = await runnable.abatch( + inputs, + [ + # each step a child run of the corresponding root run + patch_config(config, callbacks=rm.get_child()) + for rm, config in zip(run_managers, configs) + ], + return_exceptions=return_exceptions, + **kwargs, + ) + except self.exceptions_to_handle as e: + if first_error is None: + first_error = e + except BaseException as e: + await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers)) + else: + await asyncio.gather( + *( + rm.on_chain_end(output) + for rm, output in zip(run_managers, outputs) + ) + ) + return outputs + if first_error is None: + raise ValueError("No error stored at end of fallbacks.") + await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers)) + raise first_error From 17708fc156e44ffbc5adde714fce2557c799f000 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 2 Oct 2023 10:28:58 +0100 Subject: [PATCH 7/9] Lint --- .../langchain/schema/runnable/base.py | 258 ------------------ 1 file changed, 258 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 384338b1921..0d3d38aa8e1 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -839,264 +839,6 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]): pass -class RunnableWithFallbacks(RunnableSerializable[Input, Output]): - """ - A Runnable that can fallback to other Runnables if it fails. - """ - - runnable: Runnable[Input, Output] - fallbacks: Sequence[Runnable[Input, Output]] - exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,) - - class Config: - arbitrary_types_allowed = True - - @property - def InputType(self) -> Type[Input]: - return self.runnable.InputType - - @property - def OutputType(self) -> Type[Output]: - return self.runnable.OutputType - - @property - def input_schema(self) -> Type[BaseModel]: - return self.runnable.input_schema - - @property - def output_schema(self) -> Type[BaseModel]: - return self.runnable.output_schema - - def config_schema( - self, *, include: Optional[Sequence[str]] = None - ) -> Type[BaseModel]: - return self.runnable.config_schema(include=include) - - @classmethod - def is_lc_serializable(cls) -> bool: - return True - - @classmethod - def get_lc_namespace(cls) -> List[str]: - return cls.__module__.split(".")[:-1] - - @property - def runnables(self) -> Iterator[Runnable[Input, Output]]: - yield self.runnable - yield from self.fallbacks - - def invoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Output: - # setup callbacks - config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) - # start the root run - run_manager = callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") - ) - first_error = None - for runnable in self.runnables: - try: - output = runnable.invoke( - input, - patch_config(config, callbacks=run_manager.get_child()), - **kwargs, - ) - except self.exceptions_to_handle as e: - if first_error is None: - first_error = e - except BaseException as e: - run_manager.on_chain_error(e) - raise e - else: - run_manager.on_chain_end(output) - return output - if first_error is None: - raise ValueError("No error stored at end of fallbacks.") - run_manager.on_chain_error(first_error) - raise first_error - - async def ainvoke( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Output: - # setup callbacks - config = ensure_config(config) - callback_manager = get_async_callback_manager_for_config(config) - # start the root run - run_manager = await callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") - ) - - first_error = None - for runnable in self.runnables: - try: - output = await runnable.ainvoke( - input, - patch_config(config, callbacks=run_manager.get_child()), - **kwargs, - ) - except self.exceptions_to_handle as e: - if first_error is None: - first_error = e - except BaseException as e: - await run_manager.on_chain_error(e) - raise e - else: - await run_manager.on_chain_end(output) - return output - if first_error is None: - raise ValueError("No error stored at end of fallbacks.") - await run_manager.on_chain_error(first_error) - raise first_error - - def batch( - self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - **kwargs: Optional[Any], - ) -> List[Output]: - from langchain.callbacks.manager import CallbackManager - - if return_exceptions: - raise NotImplementedError() - - if not inputs: - return [] - - # setup callbacks - configs = get_config_list(config, len(inputs)) - callback_managers = [ - CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) - for config in configs - ] - # start the root runs, one per input - run_managers = [ - cm.on_chain_start( - dumpd(self), - input if isinstance(input, dict) else {"input": input}, - name=config.get("run_name"), - ) - for cm, input, config in zip(callback_managers, inputs, configs) - ] - - first_error = None - for runnable in self.runnables: - try: - outputs = runnable.batch( - inputs, - [ - # each step a child run of the corresponding root run - patch_config(config, callbacks=rm.get_child()) - for rm, config in zip(run_managers, configs) - ], - return_exceptions=return_exceptions, - **kwargs, - ) - except self.exceptions_to_handle as e: - if first_error is None: - first_error = e - except BaseException as e: - for rm in run_managers: - rm.on_chain_error(e) - raise e - else: - for rm, output in zip(run_managers, outputs): - rm.on_chain_end(output) - return outputs - if first_error is None: - raise ValueError("No error stored at end of fallbacks.") - for rm in run_managers: - rm.on_chain_error(first_error) - raise first_error - - async def abatch( - self, - inputs: List[Input], - config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - *, - return_exceptions: bool = False, - **kwargs: Optional[Any], - ) -> List[Output]: - from langchain.callbacks.manager import AsyncCallbackManager - - if return_exceptions: - raise NotImplementedError() - - if not inputs: - return [] - - # setup callbacks - configs = get_config_list(config, len(inputs)) - callback_managers = [ - AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) - for config in configs - ] - # start the root runs, one per input - run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( - *( - cm.on_chain_start( - dumpd(self), - input, - name=config.get("run_name"), - ) - for cm, input, config in zip(callback_managers, inputs, configs) - ) - ) - - first_error = None - for runnable in self.runnables: - try: - outputs = await runnable.abatch( - inputs, - [ - # each step a child run of the corresponding root run - patch_config(config, callbacks=rm.get_child()) - for rm, config in zip(run_managers, configs) - ], - return_exceptions=return_exceptions, - **kwargs, - ) - except self.exceptions_to_handle as e: - if first_error is None: - first_error = e - except BaseException as e: - await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers)) - else: - await asyncio.gather( - *( - rm.on_chain_end(output) - for rm, output in zip(run_managers, outputs) - ) - ) - return outputs - if first_error is None: - raise ValueError("No error stored at end of fallbacks.") - await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers)) - raise first_error - - class RunnableSequence(RunnableSerializable[Input, Output]): """ A sequence of runnables, where the output of each is the input of the next. From 1d46ddd16d1f56ac8ae11a35eeeb0dcc4da46e26 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 2 Oct 2023 10:29:20 +0100 Subject: [PATCH 8/9] Lint --- libs/langchain/langchain/schema/runnable/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/langchain/schema/runnable/__init__.py b/libs/langchain/langchain/schema/runnable/__init__.py index bde6121bed0..e9277319d91 100644 --- a/libs/langchain/langchain/schema/runnable/__init__.py +++ b/libs/langchain/langchain/schema/runnable/__init__.py @@ -6,10 +6,10 @@ from langchain.schema.runnable.base import ( RunnableMap, RunnableSequence, RunnableSerializable, - RunnableWithFallbacks, ) from langchain.schema.runnable.branch import RunnableBranch from langchain.schema.runnable.config import RunnableConfig, patch_config +from langchain.schema.runnable.fallbacks import RunnableWithFallbacks from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.router import RouterInput, RouterRunnable From c6a720f256af002062da810e533c0ba6553433e5 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 2 Oct 2023 10:34:13 +0100 Subject: [PATCH 9/9] Lint --- .../langchain/schema/runnable/fallbacks.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/fallbacks.py b/libs/langchain/langchain/schema/runnable/fallbacks.py index 958c8119969..bba8d9a9e11 100644 --- a/libs/langchain/langchain/schema/runnable/fallbacks.py +++ b/libs/langchain/langchain/schema/runnable/fallbacks.py @@ -1,8 +1,17 @@ import asyncio -from ctypes import Union -from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Sequence, Tuple, Type -from langchain.load.dump import dumpd +from typing import ( + TYPE_CHECKING, + Any, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) +from langchain.load.dump import dumpd from langchain.pydantic_v1 import BaseModel from langchain.schema.runnable.base import Runnable, RunnableSerializable from langchain.schema.runnable.config import (