mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-28 18:48:50 +00:00
Create new RunnableSerializable base class in preparation for configurable runnables (#11279)
- Also move RunnableBranch to its own file <!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
commit
0638f7b83a
@ -21,7 +21,6 @@ from langchain.callbacks.manager import (
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
@ -30,7 +29,7 @@ from langchain.pydantic_v1 import (
|
||||
validator,
|
||||
)
|
||||
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
|
||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
||||
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -39,7 +38,7 @@ def _get_verbosity() -> bool:
|
||||
return langchain.verbose
|
||||
|
||||
|
||||
class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||
"""Abstract base class for creating structured sequences of calls to components.
|
||||
|
||||
Chains should be used to encode a sequence of calls to components like
|
||||
|
@ -14,7 +14,7 @@ from langchain.schema.runnable import RunnableConfig
|
||||
class FakeListLLM(LLM):
|
||||
"""Fake LLM for testing purposes."""
|
||||
|
||||
responses: List
|
||||
responses: List[str]
|
||||
sleep: Optional[float] = None
|
||||
i: int = 0
|
||||
|
||||
|
@ -15,11 +15,10 @@ from typing import (
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string
|
||||
from langchain.schema.output import LLMResult
|
||||
from langchain.schema.prompt import PromptValue
|
||||
from langchain.schema.runnable import Runnable
|
||||
from langchain.schema.runnable import RunnableSerializable
|
||||
from langchain.utils import get_pydantic_field_names
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -54,7 +53,7 @@ LanguageModelOutput = TypeVar("LanguageModelOutput")
|
||||
|
||||
|
||||
class BaseLanguageModel(
|
||||
Serializable, Runnable[LanguageModelInput, LanguageModelOutput], ABC
|
||||
RunnableSerializable[LanguageModelInput, LanguageModelOutput], ABC
|
||||
):
|
||||
"""Abstract base class for interfacing with language models.
|
||||
|
||||
|
@ -16,7 +16,6 @@ from typing import (
|
||||
|
||||
from typing_extensions import get_args
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
|
||||
from langchain.schema.output import (
|
||||
ChatGeneration,
|
||||
@ -25,12 +24,12 @@ from langchain.schema.output import (
|
||||
GenerationChunk,
|
||||
)
|
||||
from langchain.schema.prompt import PromptValue
|
||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
||||
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BaseLLMOutputParser(Serializable, Generic[T], ABC):
|
||||
class BaseLLMOutputParser(Generic[T], ABC):
|
||||
"""Abstract base class for parsing the outputs of a model."""
|
||||
|
||||
@abstractmethod
|
||||
@ -63,7 +62,7 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
|
||||
|
||||
|
||||
class BaseGenerationOutputParser(
|
||||
BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
|
||||
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
|
||||
):
|
||||
"""Base class to parse the output of an LLM call."""
|
||||
|
||||
@ -121,7 +120,9 @@ class BaseGenerationOutputParser(
|
||||
)
|
||||
|
||||
|
||||
class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]):
|
||||
class BaseOutputParser(
|
||||
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
|
||||
):
|
||||
"""Base class to parse the output of an LLM call.
|
||||
|
||||
Output parsers help structure language model responses.
|
||||
|
@ -7,15 +7,14 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.pydantic_v1 import BaseModel, Field, create_model, root_validator
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.output_parser import BaseOutputParser
|
||||
from langchain.schema.prompt import PromptValue
|
||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
||||
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
|
||||
|
||||
|
||||
class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):
|
||||
class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
||||
"""Base class for all prompt templates, returning a prompt."""
|
||||
|
||||
input_variables: List[str]
|
||||
|
@ -6,9 +6,8 @@ from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
||||
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.callbacks.manager import (
|
||||
@ -18,7 +17,7 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
|
||||
class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
|
||||
"""Abstract base class for a Document retrieval system.
|
||||
|
||||
A retrieval system is defined as something that can take string queries and return
|
||||
|
@ -2,13 +2,14 @@ from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar
|
||||
from langchain.schema.runnable.base import (
|
||||
Runnable,
|
||||
RunnableBinding,
|
||||
RunnableBranch,
|
||||
RunnableLambda,
|
||||
RunnableMap,
|
||||
RunnableSequence,
|
||||
RunnableWithFallbacks,
|
||||
RunnableSerializable,
|
||||
)
|
||||
from langchain.schema.runnable.branch import RunnableBranch
|
||||
from langchain.schema.runnable.config import RunnableConfig, patch_config
|
||||
from langchain.schema.runnable.fallbacks import RunnableWithFallbacks
|
||||
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
||||
from langchain.schema.runnable.router import RouterInput, RouterRunnable
|
||||
|
||||
@ -19,6 +20,7 @@ __all__ = [
|
||||
"RouterInput",
|
||||
"RouterRunnable",
|
||||
"Runnable",
|
||||
"RunnableSerializable",
|
||||
"RunnableBinding",
|
||||
"RunnableBranch",
|
||||
"RunnableConfig",
|
||||
|
@ -11,8 +11,7 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.runnable.base import Input, Output, Runnable
|
||||
from langchain.schema.runnable.base import Input, Output, RunnableSerializable
|
||||
from langchain.schema.runnable.config import RunnableConfig
|
||||
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
||||
|
||||
@ -104,7 +103,7 @@ class PutLocalVar(RunnablePassthrough):
|
||||
|
||||
|
||||
class GetLocalVar(
|
||||
Serializable, Runnable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
|
||||
RunnableSerializable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
|
||||
):
|
||||
key: str
|
||||
"""The key to extract from the local state."""
|
||||
|
@ -36,6 +36,9 @@ if TYPE_CHECKING:
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.callbacks.tracers.log_stream import RunLogPatch
|
||||
from langchain.schema.runnable.fallbacks import (
|
||||
RunnableWithFallbacks as RunnableWithFallbacksT,
|
||||
)
|
||||
|
||||
|
||||
from langchain.load.dump import dumpd
|
||||
@ -119,6 +122,24 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
self.__class__.__name__ + "Output", __root__=(root_type, None)
|
||||
)
|
||||
|
||||
def config_schema(
|
||||
self, *, include: Optional[Sequence[str]] = None
|
||||
) -> Type[BaseModel]:
|
||||
class _Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
include = include or []
|
||||
|
||||
return create_model( # type: ignore[call-overload]
|
||||
self.__class__.__name__ + "Config",
|
||||
__config__=_Config,
|
||||
**{
|
||||
field_name: (field_type, None)
|
||||
for field_name, field_type in RunnableConfig.__annotations__.items()
|
||||
if field_name in include
|
||||
},
|
||||
)
|
||||
|
||||
def __or__(
|
||||
self,
|
||||
other: Union[
|
||||
@ -437,7 +458,9 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
fallbacks: Sequence[Runnable[Input, Output]],
|
||||
*,
|
||||
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,),
|
||||
) -> RunnableWithFallbacks[Input, Output]:
|
||||
) -> RunnableWithFallbacksT[Input, Output]:
|
||||
from langchain.schema.runnable.fallbacks import RunnableWithFallbacks
|
||||
|
||||
return RunnableWithFallbacks(
|
||||
runnable=self,
|
||||
fallbacks=fallbacks,
|
||||
@ -812,462 +835,11 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
await run_manager.on_chain_end(final_output, inputs=final_input)
|
||||
|
||||
|
||||
class RunnableBranch(Serializable, Runnable[Input, Output]):
|
||||
"""A Runnable that selects which branch to run based on a condition.
|
||||
|
||||
The runnable is initialized with a list of (condition, runnable) pairs and
|
||||
a default branch.
|
||||
|
||||
When operating on an input, the first condition that evaluates to True is
|
||||
selected, and the corresponding runnable is run on the input.
|
||||
|
||||
If no condition evaluates to True, the default branch is run on the input.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.schema.runnable import RunnableBranch
|
||||
|
||||
branch = RunnableBranch(
|
||||
(lambda x: isinstance(x, str), lambda x: x.upper()),
|
||||
(lambda x: isinstance(x, int), lambda x: x + 1),
|
||||
(lambda x: isinstance(x, float), lambda x: x * 2),
|
||||
lambda x: "goodbye",
|
||||
)
|
||||
|
||||
branch.invoke("hello") # "HELLO"
|
||||
branch.invoke(None) # "goodbye"
|
||||
"""
|
||||
|
||||
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
|
||||
default: Runnable[Input, Output]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*branches: Union[
|
||||
Tuple[
|
||||
Union[
|
||||
Runnable[Input, bool],
|
||||
Callable[[Input], bool],
|
||||
Callable[[Input], Awaitable[bool]],
|
||||
],
|
||||
RunnableLike,
|
||||
],
|
||||
RunnableLike, # To accommodate the default branch
|
||||
],
|
||||
) -> None:
|
||||
"""A Runnable that runs one of two branches based on a condition."""
|
||||
if len(branches) < 2:
|
||||
raise ValueError("RunnableBranch requires at least two branches")
|
||||
|
||||
default = branches[-1]
|
||||
|
||||
if not isinstance(
|
||||
default, (Runnable, Callable, Mapping) # type: ignore[arg-type]
|
||||
):
|
||||
raise TypeError(
|
||||
"RunnableBranch default must be runnable, callable or mapping."
|
||||
)
|
||||
|
||||
default_ = cast(
|
||||
Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
|
||||
)
|
||||
|
||||
_branches = []
|
||||
|
||||
for branch in branches[:-1]:
|
||||
if not isinstance(branch, (tuple, list)): # type: ignore[arg-type]
|
||||
raise TypeError(
|
||||
f"RunnableBranch branches must be "
|
||||
f"tuples or lists, not {type(branch)}"
|
||||
)
|
||||
|
||||
if not len(branch) == 2:
|
||||
raise ValueError(
|
||||
f"RunnableBranch branches must be "
|
||||
f"tuples or lists of length 2, not {len(branch)}"
|
||||
)
|
||||
condition, runnable = branch
|
||||
condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
|
||||
runnable = coerce_to_runnable(runnable)
|
||||
_branches.append((condition, runnable))
|
||||
|
||||
super().__init__(branches=_branches, default=default_)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""RunnableBranch is serializable if all its branches are serializable."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""The namespace of a RunnableBranch is the namespace of its default branch."""
|
||||
return cls.__module__.split(".")[:-1]
|
||||
|
||||
@property
|
||||
def input_schema(self) -> type[BaseModel]:
|
||||
runnables = (
|
||||
[self.default]
|
||||
+ [r for _, r in self.branches]
|
||||
+ [r for r, _ in self.branches]
|
||||
)
|
||||
|
||||
for runnable in runnables:
|
||||
if runnable.input_schema.schema().get("type") is not None:
|
||||
return runnable.input_schema
|
||||
|
||||
return super().input_schema
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
"""First evaluates the condition, then delegate to true or false branch."""
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
|
||||
try:
|
||||
for idx, branch in enumerate(self.branches):
|
||||
condition, runnable = branch
|
||||
|
||||
expression_value = condition.invoke(
|
||||
input,
|
||||
config=patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
|
||||
),
|
||||
)
|
||||
|
||||
if expression_value:
|
||||
output = runnable.invoke(
|
||||
input,
|
||||
config=patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
|
||||
),
|
||||
)
|
||||
break
|
||||
else:
|
||||
output = self.default.invoke(
|
||||
input,
|
||||
config=patch_config(
|
||||
config, callbacks=run_manager.get_child(tag="branch:default")
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
run_manager.on_chain_end(dumpd(output))
|
||||
return output
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
"""Async version of invoke."""
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
for idx, branch in enumerate(self.branches):
|
||||
condition, runnable = branch
|
||||
|
||||
expression_value = await condition.ainvoke(
|
||||
input,
|
||||
config=patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
|
||||
),
|
||||
)
|
||||
|
||||
if expression_value:
|
||||
output = await runnable.ainvoke(
|
||||
input,
|
||||
config=patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
break
|
||||
else:
|
||||
output = await self.default.ainvoke(
|
||||
input,
|
||||
config=patch_config(
|
||||
config, callbacks=run_manager.get_child(tag="branch:default")
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
run_manager.on_chain_end(dumpd(output))
|
||||
return output
|
||||
class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
pass
|
||||
|
||||
|
||||
class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
"""
|
||||
A Runnable that can fallback to other Runnables if it fails.
|
||||
"""
|
||||
|
||||
runnable: Runnable[Input, Output]
|
||||
fallbacks: Sequence[Runnable[Input, Output]]
|
||||
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
return self.runnable.InputType
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[Output]:
|
||||
return self.runnable.OutputType
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
return self.runnable.input_schema
|
||||
|
||||
@property
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
return self.runnable.output_schema
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
return cls.__module__.split(".")[:-1]
|
||||
|
||||
@property
|
||||
def runnables(self) -> Iterator[Runnable[Input, Output]]:
|
||||
yield self.runnable
|
||||
yield from self.fallbacks
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
# setup callbacks
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
)
|
||||
first_error = None
|
||||
for runnable in self.runnables:
|
||||
try:
|
||||
output = runnable.invoke(
|
||||
input,
|
||||
patch_config(config, callbacks=run_manager.get_child()),
|
||||
**kwargs,
|
||||
)
|
||||
except self.exceptions_to_handle as e:
|
||||
if first_error is None:
|
||||
first_error = e
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
else:
|
||||
run_manager.on_chain_end(output)
|
||||
return output
|
||||
if first_error is None:
|
||||
raise ValueError("No error stored at end of fallbacks.")
|
||||
run_manager.on_chain_error(first_error)
|
||||
raise first_error
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
# setup callbacks
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
)
|
||||
|
||||
first_error = None
|
||||
for runnable in self.runnables:
|
||||
try:
|
||||
output = await runnable.ainvoke(
|
||||
input,
|
||||
patch_config(config, callbacks=run_manager.get_child()),
|
||||
**kwargs,
|
||||
)
|
||||
except self.exceptions_to_handle as e:
|
||||
if first_error is None:
|
||||
first_error = e
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
else:
|
||||
await run_manager.on_chain_end(output)
|
||||
return output
|
||||
if first_error is None:
|
||||
raise ValueError("No error stored at end of fallbacks.")
|
||||
await run_manager.on_chain_error(first_error)
|
||||
raise first_error
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
if return_exceptions:
|
||||
raise NotImplementedError()
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
# setup callbacks
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
for config in configs
|
||||
]
|
||||
# start the root runs, one per input
|
||||
run_managers = [
|
||||
cm.on_chain_start(
|
||||
dumpd(self),
|
||||
input if isinstance(input, dict) else {"input": input},
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
]
|
||||
|
||||
first_error = None
|
||||
for runnable in self.runnables:
|
||||
try:
|
||||
outputs = runnable.batch(
|
||||
inputs,
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
patch_config(config, callbacks=rm.get_child())
|
||||
for rm, config in zip(run_managers, configs)
|
||||
],
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
except self.exceptions_to_handle as e:
|
||||
if first_error is None:
|
||||
first_error = e
|
||||
except BaseException as e:
|
||||
for rm in run_managers:
|
||||
rm.on_chain_error(e)
|
||||
raise e
|
||||
else:
|
||||
for rm, output in zip(run_managers, outputs):
|
||||
rm.on_chain_end(output)
|
||||
return outputs
|
||||
if first_error is None:
|
||||
raise ValueError("No error stored at end of fallbacks.")
|
||||
for rm in run_managers:
|
||||
rm.on_chain_error(first_error)
|
||||
raise first_error
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
if return_exceptions:
|
||||
raise NotImplementedError()
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
# setup callbacks
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
for config in configs
|
||||
]
|
||||
# start the root runs, one per input
|
||||
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||
*(
|
||||
cm.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
)
|
||||
)
|
||||
|
||||
first_error = None
|
||||
for runnable in self.runnables:
|
||||
try:
|
||||
outputs = await runnable.abatch(
|
||||
inputs,
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
patch_config(config, callbacks=rm.get_child())
|
||||
for rm, config in zip(run_managers, configs)
|
||||
],
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
except self.exceptions_to_handle as e:
|
||||
if first_error is None:
|
||||
first_error = e
|
||||
except BaseException as e:
|
||||
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
|
||||
else:
|
||||
await asyncio.gather(
|
||||
*(
|
||||
rm.on_chain_end(output)
|
||||
for rm, output in zip(run_managers, outputs)
|
||||
)
|
||||
)
|
||||
return outputs
|
||||
if first_error is None:
|
||||
raise ValueError("No error stored at end of fallbacks.")
|
||||
await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers))
|
||||
raise first_error
|
||||
|
||||
|
||||
class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
"""
|
||||
A sequence of runnables, where the output of each is the input of the next.
|
||||
"""
|
||||
@ -1749,7 +1321,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
yield chunk
|
||||
|
||||
|
||||
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
"""
|
||||
A runnable that runs a mapping of runnables in parallel,
|
||||
and returns a mapping of their outputs.
|
||||
@ -1799,7 +1371,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableMapInput",
|
||||
**{
|
||||
k: (v.type_, v.default)
|
||||
k: (v.annotation, v.default)
|
||||
for step in self.steps.values()
|
||||
for k, v in step.input_schema.__fields__.items()
|
||||
if k != "__root__"
|
||||
@ -2374,7 +1946,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
return await super().ainvoke(input, config)
|
||||
|
||||
|
||||
class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
|
||||
class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
||||
"""
|
||||
A runnable that delegates calls to another runnable
|
||||
with each element of the input sequence.
|
||||
@ -2413,6 +1985,11 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
|
||||
),
|
||||
)
|
||||
|
||||
def config_schema(
|
||||
self, *, include: Optional[Sequence[str]] = None
|
||||
) -> Type[BaseModel]:
|
||||
return self.bound.config_schema(include=include)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
@ -2455,7 +2032,7 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
|
||||
return await self._acall_with_config(self._ainvoke, input, config)
|
||||
|
||||
|
||||
class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
"""
|
||||
A runnable that delegates calls to another runnable with a set of kwargs.
|
||||
"""
|
||||
@ -2485,6 +2062,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
return self.bound.output_schema
|
||||
|
||||
def config_schema(
|
||||
self, *, include: Optional[Sequence[str]] = None
|
||||
) -> Type[BaseModel]:
|
||||
return self.bound.config_schema(include=include)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
235
libs/langchain/langchain/schema/runnable/branch.py
Normal file
235
libs/langchain/langchain/schema/runnable/branch.py
Normal file
@ -0,0 +1,235 @@
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
from langchain.schema.runnable.base import (
|
||||
Runnable,
|
||||
RunnableLike,
|
||||
RunnableSerializable,
|
||||
coerce_to_runnable,
|
||||
)
|
||||
from langchain.schema.runnable.config import (
|
||||
RunnableConfig,
|
||||
ensure_config,
|
||||
get_callback_manager_for_config,
|
||||
patch_config,
|
||||
)
|
||||
from langchain.schema.runnable.utils import Input, Output
|
||||
|
||||
|
||||
class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
"""A Runnable that selects which branch to run based on a condition.
|
||||
|
||||
The runnable is initialized with a list of (condition, runnable) pairs and
|
||||
a default branch.
|
||||
|
||||
When operating on an input, the first condition that evaluates to True is
|
||||
selected, and the corresponding runnable is run on the input.
|
||||
|
||||
If no condition evaluates to True, the default branch is run on the input.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.schema.runnable import RunnableBranch
|
||||
|
||||
branch = RunnableBranch(
|
||||
(lambda x: isinstance(x, str), lambda x: x.upper()),
|
||||
(lambda x: isinstance(x, int), lambda x: x + 1),
|
||||
(lambda x: isinstance(x, float), lambda x: x * 2),
|
||||
lambda x: "goodbye",
|
||||
)
|
||||
|
||||
branch.invoke("hello") # "HELLO"
|
||||
branch.invoke(None) # "goodbye"
|
||||
"""
|
||||
|
||||
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
|
||||
default: Runnable[Input, Output]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*branches: Union[
|
||||
Tuple[
|
||||
Union[
|
||||
Runnable[Input, bool],
|
||||
Callable[[Input], bool],
|
||||
Callable[[Input], Awaitable[bool]],
|
||||
],
|
||||
RunnableLike,
|
||||
],
|
||||
RunnableLike, # To accommodate the default branch
|
||||
],
|
||||
) -> None:
|
||||
"""A Runnable that runs one of two branches based on a condition."""
|
||||
if len(branches) < 2:
|
||||
raise ValueError("RunnableBranch requires at least two branches")
|
||||
|
||||
default = branches[-1]
|
||||
|
||||
if not isinstance(
|
||||
default, (Runnable, Callable, Mapping) # type: ignore[arg-type]
|
||||
):
|
||||
raise TypeError(
|
||||
"RunnableBranch default must be runnable, callable or mapping."
|
||||
)
|
||||
|
||||
default_ = cast(
|
||||
Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
|
||||
)
|
||||
|
||||
_branches = []
|
||||
|
||||
for branch in branches[:-1]:
|
||||
if not isinstance(branch, (tuple, list)): # type: ignore[arg-type]
|
||||
raise TypeError(
|
||||
f"RunnableBranch branches must be "
|
||||
f"tuples or lists, not {type(branch)}"
|
||||
)
|
||||
|
||||
if not len(branch) == 2:
|
||||
raise ValueError(
|
||||
f"RunnableBranch branches must be "
|
||||
f"tuples or lists of length 2, not {len(branch)}"
|
||||
)
|
||||
condition, runnable = branch
|
||||
condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
|
||||
runnable = coerce_to_runnable(runnable)
|
||||
_branches.append((condition, runnable))
|
||||
|
||||
super().__init__(branches=_branches, default=default_)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""RunnableBranch is serializable if all its branches are serializable."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""The namespace of a RunnableBranch is the namespace of its default branch."""
|
||||
return cls.__module__.split(".")[:-1]
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
runnables = (
|
||||
[self.default]
|
||||
+ [r for _, r in self.branches]
|
||||
+ [r for r, _ in self.branches]
|
||||
)
|
||||
|
||||
for runnable in runnables:
|
||||
if runnable.input_schema.schema().get("type") is not None:
|
||||
return runnable.input_schema
|
||||
|
||||
return super().input_schema
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
"""First evaluates the condition, then delegate to true or false branch."""
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
|
||||
try:
|
||||
for idx, branch in enumerate(self.branches):
|
||||
condition, runnable = branch
|
||||
|
||||
expression_value = condition.invoke(
|
||||
input,
|
||||
config=patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
|
||||
),
|
||||
)
|
||||
|
||||
if expression_value:
|
||||
output = runnable.invoke(
|
||||
input,
|
||||
config=patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
break
|
||||
else:
|
||||
output = self.default.invoke(
|
||||
input,
|
||||
config=patch_config(
|
||||
config, callbacks=run_manager.get_child(tag="branch:default")
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
run_manager.on_chain_end(dumpd(output))
|
||||
return output
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
"""Async version of invoke."""
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
for idx, branch in enumerate(self.branches):
|
||||
condition, runnable = branch
|
||||
|
||||
expression_value = await condition.ainvoke(
|
||||
input,
|
||||
config=patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
|
||||
),
|
||||
)
|
||||
|
||||
if expression_value:
|
||||
output = await runnable.ainvoke(
|
||||
input,
|
||||
config=patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
break
|
||||
else:
|
||||
output = await self.default.ainvoke(
|
||||
input,
|
||||
config=patch_config(
|
||||
config, callbacks=run_manager.get_child(tag="branch:default")
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
run_manager.on_chain_end(dumpd(output))
|
||||
return output
|
286
libs/langchain/langchain/schema/runnable/fallbacks.py
Normal file
286
libs/langchain/langchain/schema/runnable/fallbacks.py
Normal file
@ -0,0 +1,286 @@
|
||||
import asyncio
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
from langchain.schema.runnable.base import Runnable, RunnableSerializable
|
||||
from langchain.schema.runnable.config import (
|
||||
RunnableConfig,
|
||||
ensure_config,
|
||||
get_async_callback_manager_for_config,
|
||||
get_callback_manager_for_config,
|
||||
get_config_list,
|
||||
patch_config,
|
||||
)
|
||||
from langchain.schema.runnable.utils import Input, Output
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
|
||||
|
||||
|
||||
class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
"""
|
||||
A Runnable that can fallback to other Runnables if it fails.
|
||||
"""
|
||||
|
||||
runnable: Runnable[Input, Output]
|
||||
fallbacks: Sequence[Runnable[Input, Output]]
|
||||
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
return self.runnable.InputType
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[Output]:
|
||||
return self.runnable.OutputType
|
||||
|
||||
@property
|
||||
def input_schema(self) -> Type[BaseModel]:
|
||||
return self.runnable.input_schema
|
||||
|
||||
@property
|
||||
def output_schema(self) -> Type[BaseModel]:
|
||||
return self.runnable.output_schema
|
||||
|
||||
def config_schema(
|
||||
self, *, include: Optional[Sequence[str]] = None
|
||||
) -> Type[BaseModel]:
|
||||
return self.runnable.config_schema(include=include)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
return cls.__module__.split(".")[:-1]
|
||||
|
||||
@property
|
||||
def runnables(self) -> Iterator[Runnable[Input, Output]]:
|
||||
yield self.runnable
|
||||
yield from self.fallbacks
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
# setup callbacks
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
)
|
||||
first_error = None
|
||||
for runnable in self.runnables:
|
||||
try:
|
||||
output = runnable.invoke(
|
||||
input,
|
||||
patch_config(config, callbacks=run_manager.get_child()),
|
||||
**kwargs,
|
||||
)
|
||||
except self.exceptions_to_handle as e:
|
||||
if first_error is None:
|
||||
first_error = e
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
else:
|
||||
run_manager.on_chain_end(output)
|
||||
return output
|
||||
if first_error is None:
|
||||
raise ValueError("No error stored at end of fallbacks.")
|
||||
run_manager.on_chain_error(first_error)
|
||||
raise first_error
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
# setup callbacks
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
)
|
||||
|
||||
first_error = None
|
||||
for runnable in self.runnables:
|
||||
try:
|
||||
output = await runnable.ainvoke(
|
||||
input,
|
||||
patch_config(config, callbacks=run_manager.get_child()),
|
||||
**kwargs,
|
||||
)
|
||||
except self.exceptions_to_handle as e:
|
||||
if first_error is None:
|
||||
first_error = e
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
else:
|
||||
await run_manager.on_chain_end(output)
|
||||
return output
|
||||
if first_error is None:
|
||||
raise ValueError("No error stored at end of fallbacks.")
|
||||
await run_manager.on_chain_error(first_error)
|
||||
raise first_error
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
if return_exceptions:
|
||||
raise NotImplementedError()
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
# setup callbacks
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
for config in configs
|
||||
]
|
||||
# start the root runs, one per input
|
||||
run_managers = [
|
||||
cm.on_chain_start(
|
||||
dumpd(self),
|
||||
input if isinstance(input, dict) else {"input": input},
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
]
|
||||
|
||||
first_error = None
|
||||
for runnable in self.runnables:
|
||||
try:
|
||||
outputs = runnable.batch(
|
||||
inputs,
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
patch_config(config, callbacks=rm.get_child())
|
||||
for rm, config in zip(run_managers, configs)
|
||||
],
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
except self.exceptions_to_handle as e:
|
||||
if first_error is None:
|
||||
first_error = e
|
||||
except BaseException as e:
|
||||
for rm in run_managers:
|
||||
rm.on_chain_error(e)
|
||||
raise e
|
||||
else:
|
||||
for rm, output in zip(run_managers, outputs):
|
||||
rm.on_chain_end(output)
|
||||
return outputs
|
||||
if first_error is None:
|
||||
raise ValueError("No error stored at end of fallbacks.")
|
||||
for rm in run_managers:
|
||||
rm.on_chain_error(first_error)
|
||||
raise first_error
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
if return_exceptions:
|
||||
raise NotImplementedError()
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
# setup callbacks
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
verbose=False,
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=None,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
local_metadata=None,
|
||||
)
|
||||
for config in configs
|
||||
]
|
||||
# start the root runs, one per input
|
||||
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||
*(
|
||||
cm.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
)
|
||||
)
|
||||
|
||||
first_error = None
|
||||
for runnable in self.runnables:
|
||||
try:
|
||||
outputs = await runnable.abatch(
|
||||
inputs,
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
patch_config(config, callbacks=rm.get_child())
|
||||
for rm, config in zip(run_managers, configs)
|
||||
],
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
except self.exceptions_to_handle as e:
|
||||
if first_error is None:
|
||||
first_error = e
|
||||
except BaseException as e:
|
||||
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
|
||||
else:
|
||||
await asyncio.gather(
|
||||
*(
|
||||
rm.on_chain_end(output)
|
||||
for rm, output in zip(run_managers, outputs)
|
||||
)
|
||||
)
|
||||
return outputs
|
||||
if first_error is None:
|
||||
raise ValueError("No error stored at end of fallbacks.")
|
||||
await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers))
|
||||
raise first_error
|
@ -16,9 +16,13 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.pydantic_v1 import BaseModel, create_model
|
||||
from langchain.schema.runnable.base import Input, Runnable, RunnableMap
|
||||
from langchain.schema.runnable.base import (
|
||||
Input,
|
||||
Runnable,
|
||||
RunnableMap,
|
||||
RunnableSerializable,
|
||||
)
|
||||
from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config
|
||||
from langchain.schema.runnable.utils import AddableDict
|
||||
from langchain.utils.aiter import atee, py_anext
|
||||
@ -33,7 +37,7 @@ async def aidentity(x: Input) -> Input:
|
||||
return x
|
||||
|
||||
|
||||
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
||||
class RunnablePassthrough(RunnableSerializable[Input, Input]):
|
||||
"""
|
||||
A runnable that passes through the input.
|
||||
"""
|
||||
@ -109,7 +113,7 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
||||
yield chunk
|
||||
|
||||
|
||||
class RunnableAssign(Serializable, Runnable[Dict[str, Any], Dict[str, Any]]):
|
||||
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
"""
|
||||
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
|
||||
"""
|
||||
|
@ -14,8 +14,13 @@ from typing import (
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.runnable.base import Input, Output, Runnable, coerce_to_runnable
|
||||
from langchain.schema.runnable.base import (
|
||||
Input,
|
||||
Output,
|
||||
Runnable,
|
||||
RunnableSerializable,
|
||||
coerce_to_runnable,
|
||||
)
|
||||
from langchain.schema.runnable.config import (
|
||||
RunnableConfig,
|
||||
get_config_list,
|
||||
@ -36,7 +41,7 @@ class RouterInput(TypedDict):
|
||||
input: Any
|
||||
|
||||
|
||||
class RouterRunnable(Serializable, Runnable[RouterInput, Output]):
|
||||
class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
||||
"""
|
||||
A runnable that routes to a set of runnables based on Input['key'].
|
||||
Returns the output of the selected runnable.
|
||||
|
@ -17,6 +17,7 @@ from langchain.callbacks.manager import (
|
||||
CallbackManagerForToolRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
@ -25,7 +26,7 @@ from langchain.pydantic_v1 import (
|
||||
root_validator,
|
||||
validate_arguments,
|
||||
)
|
||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
||||
from langchain.schema.runnable import Runnable, RunnableConfig, RunnableSerializable
|
||||
|
||||
|
||||
class SchemaAnnotationError(TypeError):
|
||||
@ -97,7 +98,7 @@ class ToolException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BaseTool(BaseModel, Runnable[Union[str, Dict], Any]):
|
||||
class BaseTool(RunnableSerializable[Union[str, Dict], Any]):
|
||||
"""Interface LangChain tools must implement."""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
@ -165,10 +166,9 @@ class ChildTool(BaseTool):
|
||||
] = False
|
||||
"""Handle the content of the ToolException thrown."""
|
||||
|
||||
class Config:
|
||||
class Config(Serializable.Config):
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
|
@ -2,7 +2,7 @@
|
||||
"""Tools for interacting with Spark SQL."""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
||||
from langchain.pydantic_v1 import BaseModel, Field, root_validator
|
||||
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
@ -21,13 +21,8 @@ class BaseSparkSQLTool(BaseModel):
|
||||
|
||||
db: SparkSQL = Field(exclude=True)
|
||||
|
||||
# Override BaseTool.Config to appease mypy
|
||||
# See https://github.com/pydantic/pydantic/issues/4173
|
||||
class Config(BaseTool.Config):
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
extra = Extra.forbid
|
||||
pass
|
||||
|
||||
|
||||
class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):
|
||||
|
@ -21,13 +21,8 @@ class BaseSQLDatabaseTool(BaseModel):
|
||||
|
||||
db: SQLDatabase = Field(exclude=True)
|
||||
|
||||
# Override BaseTool.Config to appease mypy
|
||||
# See https://github.com/pydantic/pydantic/issues/4173
|
||||
class Config(BaseTool.Config):
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
extra = Extra.forbid
|
||||
pass
|
||||
|
||||
|
||||
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
|
@ -18,9 +18,7 @@ class BaseVectorStoreTool(BaseModel):
|
||||
llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
|
||||
|
||||
class Config(BaseTool.Config):
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
pass
|
||||
|
||||
|
||||
def _create_description_from_template(values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
Loading…
Reference in New Issue
Block a user