mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-28 18:48:50 +00:00
Create new RunnableSerializable base class in preparation for configurable runnables (#11279)
- Also move RunnableBranch to its own file <!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
commit
0638f7b83a
@ -21,7 +21,6 @@ from langchain.callbacks.manager import (
|
|||||||
Callbacks,
|
Callbacks,
|
||||||
)
|
)
|
||||||
from langchain.load.dump import dumpd
|
from langchain.load.dump import dumpd
|
||||||
from langchain.load.serializable import Serializable
|
|
||||||
from langchain.pydantic_v1 import (
|
from langchain.pydantic_v1 import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
Field,
|
Field,
|
||||||
@ -30,7 +29,7 @@ from langchain.pydantic_v1 import (
|
|||||||
validator,
|
validator,
|
||||||
)
|
)
|
||||||
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
|
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
|
||||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -39,7 +38,7 @@ def _get_verbosity() -> bool:
|
|||||||
return langchain.verbose
|
return langchain.verbose
|
||||||
|
|
||||||
|
|
||||||
class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||||
"""Abstract base class for creating structured sequences of calls to components.
|
"""Abstract base class for creating structured sequences of calls to components.
|
||||||
|
|
||||||
Chains should be used to encode a sequence of calls to components like
|
Chains should be used to encode a sequence of calls to components like
|
||||||
|
@ -14,7 +14,7 @@ from langchain.schema.runnable import RunnableConfig
|
|||||||
class FakeListLLM(LLM):
|
class FakeListLLM(LLM):
|
||||||
"""Fake LLM for testing purposes."""
|
"""Fake LLM for testing purposes."""
|
||||||
|
|
||||||
responses: List
|
responses: List[str]
|
||||||
sleep: Optional[float] = None
|
sleep: Optional[float] = None
|
||||||
i: int = 0
|
i: int = 0
|
||||||
|
|
||||||
|
@ -15,11 +15,10 @@ from typing import (
|
|||||||
|
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
|
||||||
from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string
|
from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string
|
||||||
from langchain.schema.output import LLMResult
|
from langchain.schema.output import LLMResult
|
||||||
from langchain.schema.prompt import PromptValue
|
from langchain.schema.prompt import PromptValue
|
||||||
from langchain.schema.runnable import Runnable
|
from langchain.schema.runnable import RunnableSerializable
|
||||||
from langchain.utils import get_pydantic_field_names
|
from langchain.utils import get_pydantic_field_names
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -54,7 +53,7 @@ LanguageModelOutput = TypeVar("LanguageModelOutput")
|
|||||||
|
|
||||||
|
|
||||||
class BaseLanguageModel(
|
class BaseLanguageModel(
|
||||||
Serializable, Runnable[LanguageModelInput, LanguageModelOutput], ABC
|
RunnableSerializable[LanguageModelInput, LanguageModelOutput], ABC
|
||||||
):
|
):
|
||||||
"""Abstract base class for interfacing with language models.
|
"""Abstract base class for interfacing with language models.
|
||||||
|
|
||||||
|
@ -16,7 +16,6 @@ from typing import (
|
|||||||
|
|
||||||
from typing_extensions import get_args
|
from typing_extensions import get_args
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
|
||||||
from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
|
from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
|
||||||
from langchain.schema.output import (
|
from langchain.schema.output import (
|
||||||
ChatGeneration,
|
ChatGeneration,
|
||||||
@ -25,12 +24,12 @@ from langchain.schema.output import (
|
|||||||
GenerationChunk,
|
GenerationChunk,
|
||||||
)
|
)
|
||||||
from langchain.schema.prompt import PromptValue
|
from langchain.schema.prompt import PromptValue
|
||||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class BaseLLMOutputParser(Serializable, Generic[T], ABC):
|
class BaseLLMOutputParser(Generic[T], ABC):
|
||||||
"""Abstract base class for parsing the outputs of a model."""
|
"""Abstract base class for parsing the outputs of a model."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -63,7 +62,7 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
|
|||||||
|
|
||||||
|
|
||||||
class BaseGenerationOutputParser(
|
class BaseGenerationOutputParser(
|
||||||
BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
|
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
|
||||||
):
|
):
|
||||||
"""Base class to parse the output of an LLM call."""
|
"""Base class to parse the output of an LLM call."""
|
||||||
|
|
||||||
@ -121,7 +120,9 @@ class BaseGenerationOutputParser(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]):
|
class BaseOutputParser(
|
||||||
|
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
|
||||||
|
):
|
||||||
"""Base class to parse the output of an LLM call.
|
"""Base class to parse the output of an LLM call.
|
||||||
|
|
||||||
Output parsers help structure language model responses.
|
Output parsers help structure language model responses.
|
||||||
|
@ -7,15 +7,14 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Union
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
|
||||||
from langchain.pydantic_v1 import BaseModel, Field, create_model, root_validator
|
from langchain.pydantic_v1 import BaseModel, Field, create_model, root_validator
|
||||||
from langchain.schema.document import Document
|
from langchain.schema.document import Document
|
||||||
from langchain.schema.output_parser import BaseOutputParser
|
from langchain.schema.output_parser import BaseOutputParser
|
||||||
from langchain.schema.prompt import PromptValue
|
from langchain.schema.prompt import PromptValue
|
||||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
|
||||||
|
|
||||||
|
|
||||||
class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):
|
class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
||||||
"""Base class for all prompt templates, returning a prompt."""
|
"""Base class for all prompt templates, returning a prompt."""
|
||||||
|
|
||||||
input_variables: List[str]
|
input_variables: List[str]
|
||||||
|
@ -6,9 +6,8 @@ from inspect import signature
|
|||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain.load.dump import dumpd
|
from langchain.load.dump import dumpd
|
||||||
from langchain.load.serializable import Serializable
|
|
||||||
from langchain.schema.document import Document
|
from langchain.schema.document import Document
|
||||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
@ -18,7 +17,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
|
class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
|
||||||
"""Abstract base class for a Document retrieval system.
|
"""Abstract base class for a Document retrieval system.
|
||||||
|
|
||||||
A retrieval system is defined as something that can take string queries and return
|
A retrieval system is defined as something that can take string queries and return
|
||||||
|
@ -2,13 +2,14 @@ from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar
|
|||||||
from langchain.schema.runnable.base import (
|
from langchain.schema.runnable.base import (
|
||||||
Runnable,
|
Runnable,
|
||||||
RunnableBinding,
|
RunnableBinding,
|
||||||
RunnableBranch,
|
|
||||||
RunnableLambda,
|
RunnableLambda,
|
||||||
RunnableMap,
|
RunnableMap,
|
||||||
RunnableSequence,
|
RunnableSequence,
|
||||||
RunnableWithFallbacks,
|
RunnableSerializable,
|
||||||
)
|
)
|
||||||
|
from langchain.schema.runnable.branch import RunnableBranch
|
||||||
from langchain.schema.runnable.config import RunnableConfig, patch_config
|
from langchain.schema.runnable.config import RunnableConfig, patch_config
|
||||||
|
from langchain.schema.runnable.fallbacks import RunnableWithFallbacks
|
||||||
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
||||||
from langchain.schema.runnable.router import RouterInput, RouterRunnable
|
from langchain.schema.runnable.router import RouterInput, RouterRunnable
|
||||||
|
|
||||||
@ -19,6 +20,7 @@ __all__ = [
|
|||||||
"RouterInput",
|
"RouterInput",
|
||||||
"RouterRunnable",
|
"RouterRunnable",
|
||||||
"Runnable",
|
"Runnable",
|
||||||
|
"RunnableSerializable",
|
||||||
"RunnableBinding",
|
"RunnableBinding",
|
||||||
"RunnableBranch",
|
"RunnableBranch",
|
||||||
"RunnableConfig",
|
"RunnableConfig",
|
||||||
|
@ -11,8 +11,7 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.schema.runnable.base import Input, Output, RunnableSerializable
|
||||||
from langchain.schema.runnable.base import Input, Output, Runnable
|
|
||||||
from langchain.schema.runnable.config import RunnableConfig
|
from langchain.schema.runnable.config import RunnableConfig
|
||||||
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
||||||
|
|
||||||
@ -104,7 +103,7 @@ class PutLocalVar(RunnablePassthrough):
|
|||||||
|
|
||||||
|
|
||||||
class GetLocalVar(
|
class GetLocalVar(
|
||||||
Serializable, Runnable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
|
RunnableSerializable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
|
||||||
):
|
):
|
||||||
key: str
|
key: str
|
||||||
"""The key to extract from the local state."""
|
"""The key to extract from the local state."""
|
||||||
|
@ -36,6 +36,9 @@ if TYPE_CHECKING:
|
|||||||
CallbackManagerForChainRun,
|
CallbackManagerForChainRun,
|
||||||
)
|
)
|
||||||
from langchain.callbacks.tracers.log_stream import RunLogPatch
|
from langchain.callbacks.tracers.log_stream import RunLogPatch
|
||||||
|
from langchain.schema.runnable.fallbacks import (
|
||||||
|
RunnableWithFallbacks as RunnableWithFallbacksT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
from langchain.load.dump import dumpd
|
from langchain.load.dump import dumpd
|
||||||
@ -119,6 +122,24 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
self.__class__.__name__ + "Output", __root__=(root_type, None)
|
self.__class__.__name__ + "Output", __root__=(root_type, None)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def config_schema(
|
||||||
|
self, *, include: Optional[Sequence[str]] = None
|
||||||
|
) -> Type[BaseModel]:
|
||||||
|
class _Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
include = include or []
|
||||||
|
|
||||||
|
return create_model( # type: ignore[call-overload]
|
||||||
|
self.__class__.__name__ + "Config",
|
||||||
|
__config__=_Config,
|
||||||
|
**{
|
||||||
|
field_name: (field_type, None)
|
||||||
|
for field_name, field_type in RunnableConfig.__annotations__.items()
|
||||||
|
if field_name in include
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def __or__(
|
def __or__(
|
||||||
self,
|
self,
|
||||||
other: Union[
|
other: Union[
|
||||||
@ -437,7 +458,9 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
fallbacks: Sequence[Runnable[Input, Output]],
|
fallbacks: Sequence[Runnable[Input, Output]],
|
||||||
*,
|
*,
|
||||||
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,),
|
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,),
|
||||||
) -> RunnableWithFallbacks[Input, Output]:
|
) -> RunnableWithFallbacksT[Input, Output]:
|
||||||
|
from langchain.schema.runnable.fallbacks import RunnableWithFallbacks
|
||||||
|
|
||||||
return RunnableWithFallbacks(
|
return RunnableWithFallbacks(
|
||||||
runnable=self,
|
runnable=self,
|
||||||
fallbacks=fallbacks,
|
fallbacks=fallbacks,
|
||||||
@ -812,462 +835,11 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
await run_manager.on_chain_end(final_output, inputs=final_input)
|
await run_manager.on_chain_end(final_output, inputs=final_input)
|
||||||
|
|
||||||
|
|
||||||
class RunnableBranch(Serializable, Runnable[Input, Output]):
|
class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||||
"""A Runnable that selects which branch to run based on a condition.
|
pass
|
||||||
|
|
||||||
The runnable is initialized with a list of (condition, runnable) pairs and
|
|
||||||
a default branch.
|
|
||||||
|
|
||||||
When operating on an input, the first condition that evaluates to True is
|
|
||||||
selected, and the corresponding runnable is run on the input.
|
|
||||||
|
|
||||||
If no condition evaluates to True, the default branch is run on the input.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
from langchain.schema.runnable import RunnableBranch
|
|
||||||
|
|
||||||
branch = RunnableBranch(
|
|
||||||
(lambda x: isinstance(x, str), lambda x: x.upper()),
|
|
||||||
(lambda x: isinstance(x, int), lambda x: x + 1),
|
|
||||||
(lambda x: isinstance(x, float), lambda x: x * 2),
|
|
||||||
lambda x: "goodbye",
|
|
||||||
)
|
|
||||||
|
|
||||||
branch.invoke("hello") # "HELLO"
|
|
||||||
branch.invoke(None) # "goodbye"
|
|
||||||
"""
|
|
||||||
|
|
||||||
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
|
|
||||||
default: Runnable[Input, Output]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*branches: Union[
|
|
||||||
Tuple[
|
|
||||||
Union[
|
|
||||||
Runnable[Input, bool],
|
|
||||||
Callable[[Input], bool],
|
|
||||||
Callable[[Input], Awaitable[bool]],
|
|
||||||
],
|
|
||||||
RunnableLike,
|
|
||||||
],
|
|
||||||
RunnableLike, # To accommodate the default branch
|
|
||||||
],
|
|
||||||
) -> None:
|
|
||||||
"""A Runnable that runs one of two branches based on a condition."""
|
|
||||||
if len(branches) < 2:
|
|
||||||
raise ValueError("RunnableBranch requires at least two branches")
|
|
||||||
|
|
||||||
default = branches[-1]
|
|
||||||
|
|
||||||
if not isinstance(
|
|
||||||
default, (Runnable, Callable, Mapping) # type: ignore[arg-type]
|
|
||||||
):
|
|
||||||
raise TypeError(
|
|
||||||
"RunnableBranch default must be runnable, callable or mapping."
|
|
||||||
)
|
|
||||||
|
|
||||||
default_ = cast(
|
|
||||||
Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
|
|
||||||
)
|
|
||||||
|
|
||||||
_branches = []
|
|
||||||
|
|
||||||
for branch in branches[:-1]:
|
|
||||||
if not isinstance(branch, (tuple, list)): # type: ignore[arg-type]
|
|
||||||
raise TypeError(
|
|
||||||
f"RunnableBranch branches must be "
|
|
||||||
f"tuples or lists, not {type(branch)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not len(branch) == 2:
|
|
||||||
raise ValueError(
|
|
||||||
f"RunnableBranch branches must be "
|
|
||||||
f"tuples or lists of length 2, not {len(branch)}"
|
|
||||||
)
|
|
||||||
condition, runnable = branch
|
|
||||||
condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
|
|
||||||
runnable = coerce_to_runnable(runnable)
|
|
||||||
_branches.append((condition, runnable))
|
|
||||||
|
|
||||||
super().__init__(branches=_branches, default=default_)
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def is_lc_serializable(cls) -> bool:
|
|
||||||
"""RunnableBranch is serializable if all its branches are serializable."""
|
|
||||||
return True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
|
||||||
"""The namespace of a RunnableBranch is the namespace of its default branch."""
|
|
||||||
return cls.__module__.split(".")[:-1]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input_schema(self) -> type[BaseModel]:
|
|
||||||
runnables = (
|
|
||||||
[self.default]
|
|
||||||
+ [r for _, r in self.branches]
|
|
||||||
+ [r for r, _ in self.branches]
|
|
||||||
)
|
|
||||||
|
|
||||||
for runnable in runnables:
|
|
||||||
if runnable.input_schema.schema().get("type") is not None:
|
|
||||||
return runnable.input_schema
|
|
||||||
|
|
||||||
return super().input_schema
|
|
||||||
|
|
||||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
|
||||||
"""First evaluates the condition, then delegate to true or false branch."""
|
|
||||||
config = ensure_config(config)
|
|
||||||
callback_manager = get_callback_manager_for_config(config)
|
|
||||||
run_manager = callback_manager.on_chain_start(
|
|
||||||
dumpd(self),
|
|
||||||
input,
|
|
||||||
name=config.get("run_name"),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
for idx, branch in enumerate(self.branches):
|
|
||||||
condition, runnable = branch
|
|
||||||
|
|
||||||
expression_value = condition.invoke(
|
|
||||||
input,
|
|
||||||
config=patch_config(
|
|
||||||
config,
|
|
||||||
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if expression_value:
|
|
||||||
output = runnable.invoke(
|
|
||||||
input,
|
|
||||||
config=patch_config(
|
|
||||||
config,
|
|
||||||
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
output = self.default.invoke(
|
|
||||||
input,
|
|
||||||
config=patch_config(
|
|
||||||
config, callbacks=run_manager.get_child(tag="branch:default")
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
run_manager.on_chain_error(e)
|
|
||||||
raise
|
|
||||||
run_manager.on_chain_end(dumpd(output))
|
|
||||||
return output
|
|
||||||
|
|
||||||
async def ainvoke(
|
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
|
||||||
) -> Output:
|
|
||||||
"""Async version of invoke."""
|
|
||||||
config = ensure_config(config)
|
|
||||||
callback_manager = get_callback_manager_for_config(config)
|
|
||||||
run_manager = callback_manager.on_chain_start(
|
|
||||||
dumpd(self),
|
|
||||||
input,
|
|
||||||
name=config.get("run_name"),
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
for idx, branch in enumerate(self.branches):
|
|
||||||
condition, runnable = branch
|
|
||||||
|
|
||||||
expression_value = await condition.ainvoke(
|
|
||||||
input,
|
|
||||||
config=patch_config(
|
|
||||||
config,
|
|
||||||
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if expression_value:
|
|
||||||
output = await runnable.ainvoke(
|
|
||||||
input,
|
|
||||||
config=patch_config(
|
|
||||||
config,
|
|
||||||
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
|
|
||||||
),
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
output = await self.default.ainvoke(
|
|
||||||
input,
|
|
||||||
config=patch_config(
|
|
||||||
config, callbacks=run_manager.get_child(tag="branch:default")
|
|
||||||
),
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
run_manager.on_chain_error(e)
|
|
||||||
raise
|
|
||||||
run_manager.on_chain_end(dumpd(output))
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||||
"""
|
|
||||||
A Runnable that can fallback to other Runnables if it fails.
|
|
||||||
"""
|
|
||||||
|
|
||||||
runnable: Runnable[Input, Output]
|
|
||||||
fallbacks: Sequence[Runnable[Input, Output]]
|
|
||||||
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def InputType(self) -> Type[Input]:
|
|
||||||
return self.runnable.InputType
|
|
||||||
|
|
||||||
@property
|
|
||||||
def OutputType(self) -> Type[Output]:
|
|
||||||
return self.runnable.OutputType
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input_schema(self) -> Type[BaseModel]:
|
|
||||||
return self.runnable.input_schema
|
|
||||||
|
|
||||||
@property
|
|
||||||
def output_schema(self) -> Type[BaseModel]:
|
|
||||||
return self.runnable.output_schema
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def is_lc_serializable(cls) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
|
||||||
return cls.__module__.split(".")[:-1]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def runnables(self) -> Iterator[Runnable[Input, Output]]:
|
|
||||||
yield self.runnable
|
|
||||||
yield from self.fallbacks
|
|
||||||
|
|
||||||
def invoke(
|
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
|
||||||
) -> Output:
|
|
||||||
# setup callbacks
|
|
||||||
config = ensure_config(config)
|
|
||||||
callback_manager = get_callback_manager_for_config(config)
|
|
||||||
# start the root run
|
|
||||||
run_manager = callback_manager.on_chain_start(
|
|
||||||
dumpd(self), input, name=config.get("run_name")
|
|
||||||
)
|
|
||||||
first_error = None
|
|
||||||
for runnable in self.runnables:
|
|
||||||
try:
|
|
||||||
output = runnable.invoke(
|
|
||||||
input,
|
|
||||||
patch_config(config, callbacks=run_manager.get_child()),
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
except self.exceptions_to_handle as e:
|
|
||||||
if first_error is None:
|
|
||||||
first_error = e
|
|
||||||
except BaseException as e:
|
|
||||||
run_manager.on_chain_error(e)
|
|
||||||
raise e
|
|
||||||
else:
|
|
||||||
run_manager.on_chain_end(output)
|
|
||||||
return output
|
|
||||||
if first_error is None:
|
|
||||||
raise ValueError("No error stored at end of fallbacks.")
|
|
||||||
run_manager.on_chain_error(first_error)
|
|
||||||
raise first_error
|
|
||||||
|
|
||||||
async def ainvoke(
|
|
||||||
self,
|
|
||||||
input: Input,
|
|
||||||
config: Optional[RunnableConfig] = None,
|
|
||||||
**kwargs: Optional[Any],
|
|
||||||
) -> Output:
|
|
||||||
# setup callbacks
|
|
||||||
config = ensure_config(config)
|
|
||||||
callback_manager = get_async_callback_manager_for_config(config)
|
|
||||||
# start the root run
|
|
||||||
run_manager = await callback_manager.on_chain_start(
|
|
||||||
dumpd(self), input, name=config.get("run_name")
|
|
||||||
)
|
|
||||||
|
|
||||||
first_error = None
|
|
||||||
for runnable in self.runnables:
|
|
||||||
try:
|
|
||||||
output = await runnable.ainvoke(
|
|
||||||
input,
|
|
||||||
patch_config(config, callbacks=run_manager.get_child()),
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
except self.exceptions_to_handle as e:
|
|
||||||
if first_error is None:
|
|
||||||
first_error = e
|
|
||||||
except BaseException as e:
|
|
||||||
await run_manager.on_chain_error(e)
|
|
||||||
raise e
|
|
||||||
else:
|
|
||||||
await run_manager.on_chain_end(output)
|
|
||||||
return output
|
|
||||||
if first_error is None:
|
|
||||||
raise ValueError("No error stored at end of fallbacks.")
|
|
||||||
await run_manager.on_chain_error(first_error)
|
|
||||||
raise first_error
|
|
||||||
|
|
||||||
def batch(
|
|
||||||
self,
|
|
||||||
inputs: List[Input],
|
|
||||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
|
||||||
*,
|
|
||||||
return_exceptions: bool = False,
|
|
||||||
**kwargs: Optional[Any],
|
|
||||||
) -> List[Output]:
|
|
||||||
from langchain.callbacks.manager import CallbackManager
|
|
||||||
|
|
||||||
if return_exceptions:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
if not inputs:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# setup callbacks
|
|
||||||
configs = get_config_list(config, len(inputs))
|
|
||||||
callback_managers = [
|
|
||||||
CallbackManager.configure(
|
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
local_callbacks=None,
|
|
||||||
verbose=False,
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
local_tags=None,
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
local_metadata=None,
|
|
||||||
)
|
|
||||||
for config in configs
|
|
||||||
]
|
|
||||||
# start the root runs, one per input
|
|
||||||
run_managers = [
|
|
||||||
cm.on_chain_start(
|
|
||||||
dumpd(self),
|
|
||||||
input if isinstance(input, dict) else {"input": input},
|
|
||||||
name=config.get("run_name"),
|
|
||||||
)
|
|
||||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
|
||||||
]
|
|
||||||
|
|
||||||
first_error = None
|
|
||||||
for runnable in self.runnables:
|
|
||||||
try:
|
|
||||||
outputs = runnable.batch(
|
|
||||||
inputs,
|
|
||||||
[
|
|
||||||
# each step a child run of the corresponding root run
|
|
||||||
patch_config(config, callbacks=rm.get_child())
|
|
||||||
for rm, config in zip(run_managers, configs)
|
|
||||||
],
|
|
||||||
return_exceptions=return_exceptions,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
except self.exceptions_to_handle as e:
|
|
||||||
if first_error is None:
|
|
||||||
first_error = e
|
|
||||||
except BaseException as e:
|
|
||||||
for rm in run_managers:
|
|
||||||
rm.on_chain_error(e)
|
|
||||||
raise e
|
|
||||||
else:
|
|
||||||
for rm, output in zip(run_managers, outputs):
|
|
||||||
rm.on_chain_end(output)
|
|
||||||
return outputs
|
|
||||||
if first_error is None:
|
|
||||||
raise ValueError("No error stored at end of fallbacks.")
|
|
||||||
for rm in run_managers:
|
|
||||||
rm.on_chain_error(first_error)
|
|
||||||
raise first_error
|
|
||||||
|
|
||||||
async def abatch(
|
|
||||||
self,
|
|
||||||
inputs: List[Input],
|
|
||||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
|
||||||
*,
|
|
||||||
return_exceptions: bool = False,
|
|
||||||
**kwargs: Optional[Any],
|
|
||||||
) -> List[Output]:
|
|
||||||
from langchain.callbacks.manager import AsyncCallbackManager
|
|
||||||
|
|
||||||
if return_exceptions:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
if not inputs:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# setup callbacks
|
|
||||||
configs = get_config_list(config, len(inputs))
|
|
||||||
callback_managers = [
|
|
||||||
AsyncCallbackManager.configure(
|
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
local_callbacks=None,
|
|
||||||
verbose=False,
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
local_tags=None,
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
local_metadata=None,
|
|
||||||
)
|
|
||||||
for config in configs
|
|
||||||
]
|
|
||||||
# start the root runs, one per input
|
|
||||||
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
|
||||||
*(
|
|
||||||
cm.on_chain_start(
|
|
||||||
dumpd(self),
|
|
||||||
input,
|
|
||||||
name=config.get("run_name"),
|
|
||||||
)
|
|
||||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
first_error = None
|
|
||||||
for runnable in self.runnables:
|
|
||||||
try:
|
|
||||||
outputs = await runnable.abatch(
|
|
||||||
inputs,
|
|
||||||
[
|
|
||||||
# each step a child run of the corresponding root run
|
|
||||||
patch_config(config, callbacks=rm.get_child())
|
|
||||||
for rm, config in zip(run_managers, configs)
|
|
||||||
],
|
|
||||||
return_exceptions=return_exceptions,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
except self.exceptions_to_handle as e:
|
|
||||||
if first_error is None:
|
|
||||||
first_error = e
|
|
||||||
except BaseException as e:
|
|
||||||
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
|
|
||||||
else:
|
|
||||||
await asyncio.gather(
|
|
||||||
*(
|
|
||||||
rm.on_chain_end(output)
|
|
||||||
for rm, output in zip(run_managers, outputs)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return outputs
|
|
||||||
if first_error is None:
|
|
||||||
raise ValueError("No error stored at end of fallbacks.")
|
|
||||||
await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers))
|
|
||||||
raise first_error
|
|
||||||
|
|
||||||
|
|
||||||
class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|
||||||
"""
|
"""
|
||||||
A sequence of runnables, where the output of each is the input of the next.
|
A sequence of runnables, where the output of each is the input of the next.
|
||||||
"""
|
"""
|
||||||
@ -1749,7 +1321,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
|
||||||
"""
|
"""
|
||||||
A runnable that runs a mapping of runnables in parallel,
|
A runnable that runs a mapping of runnables in parallel,
|
||||||
and returns a mapping of their outputs.
|
and returns a mapping of their outputs.
|
||||||
@ -1799,7 +1371,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
return create_model( # type: ignore[call-overload]
|
return create_model( # type: ignore[call-overload]
|
||||||
"RunnableMapInput",
|
"RunnableMapInput",
|
||||||
**{
|
**{
|
||||||
k: (v.type_, v.default)
|
k: (v.annotation, v.default)
|
||||||
for step in self.steps.values()
|
for step in self.steps.values()
|
||||||
for k, v in step.input_schema.__fields__.items()
|
for k, v in step.input_schema.__fields__.items()
|
||||||
if k != "__root__"
|
if k != "__root__"
|
||||||
@ -2374,7 +1946,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
return await super().ainvoke(input, config)
|
return await super().ainvoke(input, config)
|
||||||
|
|
||||||
|
|
||||||
class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
|
class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
||||||
"""
|
"""
|
||||||
A runnable that delegates calls to another runnable
|
A runnable that delegates calls to another runnable
|
||||||
with each element of the input sequence.
|
with each element of the input sequence.
|
||||||
@ -2413,6 +1985,11 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def config_schema(
|
||||||
|
self, *, include: Optional[Sequence[str]] = None
|
||||||
|
) -> Type[BaseModel]:
|
||||||
|
return self.bound.config_schema(include=include)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
@ -2455,7 +2032,7 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
|
|||||||
return await self._acall_with_config(self._ainvoke, input, config)
|
return await self._acall_with_config(self._ainvoke, input, config)
|
||||||
|
|
||||||
|
|
||||||
class RunnableBinding(Serializable, Runnable[Input, Output]):
|
class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||||
"""
|
"""
|
||||||
A runnable that delegates calls to another runnable with a set of kwargs.
|
A runnable that delegates calls to another runnable with a set of kwargs.
|
||||||
"""
|
"""
|
||||||
@ -2485,6 +2062,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
def output_schema(self) -> Type[BaseModel]:
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
return self.bound.output_schema
|
return self.bound.output_schema
|
||||||
|
|
||||||
|
def config_schema(
|
||||||
|
self, *, include: Optional[Sequence[str]] = None
|
||||||
|
) -> Type[BaseModel]:
|
||||||
|
return self.bound.config_schema(include=include)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
|
235
libs/langchain/langchain/schema/runnable/branch.py
Normal file
235
libs/langchain/langchain/schema/runnable/branch.py
Normal file
@ -0,0 +1,235 @@
|
|||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.load.dump import dumpd
|
||||||
|
from langchain.pydantic_v1 import BaseModel
|
||||||
|
from langchain.schema.runnable.base import (
|
||||||
|
Runnable,
|
||||||
|
RunnableLike,
|
||||||
|
RunnableSerializable,
|
||||||
|
coerce_to_runnable,
|
||||||
|
)
|
||||||
|
from langchain.schema.runnable.config import (
|
||||||
|
RunnableConfig,
|
||||||
|
ensure_config,
|
||||||
|
get_callback_manager_for_config,
|
||||||
|
patch_config,
|
||||||
|
)
|
||||||
|
from langchain.schema.runnable.utils import Input, Output
|
||||||
|
|
||||||
|
|
||||||
|
class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||||
|
"""A Runnable that selects which branch to run based on a condition.
|
||||||
|
|
||||||
|
The runnable is initialized with a list of (condition, runnable) pairs and
|
||||||
|
a default branch.
|
||||||
|
|
||||||
|
When operating on an input, the first condition that evaluates to True is
|
||||||
|
selected, and the corresponding runnable is run on the input.
|
||||||
|
|
||||||
|
If no condition evaluates to True, the default branch is run on the input.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.schema.runnable import RunnableBranch
|
||||||
|
|
||||||
|
branch = RunnableBranch(
|
||||||
|
(lambda x: isinstance(x, str), lambda x: x.upper()),
|
||||||
|
(lambda x: isinstance(x, int), lambda x: x + 1),
|
||||||
|
(lambda x: isinstance(x, float), lambda x: x * 2),
|
||||||
|
lambda x: "goodbye",
|
||||||
|
)
|
||||||
|
|
||||||
|
branch.invoke("hello") # "HELLO"
|
||||||
|
branch.invoke(None) # "goodbye"
|
||||||
|
"""
|
||||||
|
|
||||||
|
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
|
||||||
|
default: Runnable[Input, Output]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*branches: Union[
|
||||||
|
Tuple[
|
||||||
|
Union[
|
||||||
|
Runnable[Input, bool],
|
||||||
|
Callable[[Input], bool],
|
||||||
|
Callable[[Input], Awaitable[bool]],
|
||||||
|
],
|
||||||
|
RunnableLike,
|
||||||
|
],
|
||||||
|
RunnableLike, # To accommodate the default branch
|
||||||
|
],
|
||||||
|
) -> None:
|
||||||
|
"""A Runnable that runs one of two branches based on a condition."""
|
||||||
|
if len(branches) < 2:
|
||||||
|
raise ValueError("RunnableBranch requires at least two branches")
|
||||||
|
|
||||||
|
default = branches[-1]
|
||||||
|
|
||||||
|
if not isinstance(
|
||||||
|
default, (Runnable, Callable, Mapping) # type: ignore[arg-type]
|
||||||
|
):
|
||||||
|
raise TypeError(
|
||||||
|
"RunnableBranch default must be runnable, callable or mapping."
|
||||||
|
)
|
||||||
|
|
||||||
|
default_ = cast(
|
||||||
|
Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
|
||||||
|
)
|
||||||
|
|
||||||
|
_branches = []
|
||||||
|
|
||||||
|
for branch in branches[:-1]:
|
||||||
|
if not isinstance(branch, (tuple, list)): # type: ignore[arg-type]
|
||||||
|
raise TypeError(
|
||||||
|
f"RunnableBranch branches must be "
|
||||||
|
f"tuples or lists, not {type(branch)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not len(branch) == 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"RunnableBranch branches must be "
|
||||||
|
f"tuples or lists of length 2, not {len(branch)}"
|
||||||
|
)
|
||||||
|
condition, runnable = branch
|
||||||
|
condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
|
||||||
|
runnable = coerce_to_runnable(runnable)
|
||||||
|
_branches.append((condition, runnable))
|
||||||
|
|
||||||
|
super().__init__(branches=_branches, default=default_)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_lc_serializable(cls) -> bool:
|
||||||
|
"""RunnableBranch is serializable if all its branches are serializable."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
|
"""The namespace of a RunnableBranch is the namespace of its default branch."""
|
||||||
|
return cls.__module__.split(".")[:-1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
|
runnables = (
|
||||||
|
[self.default]
|
||||||
|
+ [r for _, r in self.branches]
|
||||||
|
+ [r for r, _ in self.branches]
|
||||||
|
)
|
||||||
|
|
||||||
|
for runnable in runnables:
|
||||||
|
if runnable.input_schema.schema().get("type") is not None:
|
||||||
|
return runnable.input_schema
|
||||||
|
|
||||||
|
return super().input_schema
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
|
) -> Output:
|
||||||
|
"""First evaluates the condition, then delegate to true or false branch."""
|
||||||
|
config = ensure_config(config)
|
||||||
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
|
run_manager = callback_manager.on_chain_start(
|
||||||
|
dumpd(self),
|
||||||
|
input,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for idx, branch in enumerate(self.branches):
|
||||||
|
condition, runnable = branch
|
||||||
|
|
||||||
|
expression_value = condition.invoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if expression_value:
|
||||||
|
output = runnable.invoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
output = self.default.invoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config, callbacks=run_manager.get_child(tag="branch:default")
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
run_manager.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
run_manager.on_chain_end(dumpd(output))
|
||||||
|
return output
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
|
) -> Output:
|
||||||
|
"""Async version of invoke."""
|
||||||
|
config = ensure_config(config)
|
||||||
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
|
run_manager = callback_manager.on_chain_start(
|
||||||
|
dumpd(self),
|
||||||
|
input,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
for idx, branch in enumerate(self.branches):
|
||||||
|
condition, runnable = branch
|
||||||
|
|
||||||
|
expression_value = await condition.ainvoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if expression_value:
|
||||||
|
output = await runnable.ainvoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
output = await self.default.ainvoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config, callbacks=run_manager.get_child(tag="branch:default")
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
run_manager.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
run_manager.on_chain_end(dumpd(output))
|
||||||
|
return output
|
286
libs/langchain/langchain/schema/runnable/fallbacks.py
Normal file
286
libs/langchain/langchain/schema/runnable/fallbacks.py
Normal file
@ -0,0 +1,286 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.load.dump import dumpd
|
||||||
|
from langchain.pydantic_v1 import BaseModel
|
||||||
|
from langchain.schema.runnable.base import Runnable, RunnableSerializable
|
||||||
|
from langchain.schema.runnable.config import (
|
||||||
|
RunnableConfig,
|
||||||
|
ensure_config,
|
||||||
|
get_async_callback_manager_for_config,
|
||||||
|
get_callback_manager_for_config,
|
||||||
|
get_config_list,
|
||||||
|
patch_config,
|
||||||
|
)
|
||||||
|
from langchain.schema.runnable.utils import Input, Output
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
|
||||||
|
|
||||||
|
|
||||||
|
class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||||
|
"""
|
||||||
|
A Runnable that can fallback to other Runnables if it fails.
|
||||||
|
"""
|
||||||
|
|
||||||
|
runnable: Runnable[Input, Output]
|
||||||
|
fallbacks: Sequence[Runnable[Input, Output]]
|
||||||
|
exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def InputType(self) -> Type[Input]:
|
||||||
|
return self.runnable.InputType
|
||||||
|
|
||||||
|
@property
|
||||||
|
def OutputType(self) -> Type[Output]:
|
||||||
|
return self.runnable.OutputType
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
|
return self.runnable.input_schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
|
return self.runnable.output_schema
|
||||||
|
|
||||||
|
def config_schema(
|
||||||
|
self, *, include: Optional[Sequence[str]] = None
|
||||||
|
) -> Type[BaseModel]:
|
||||||
|
return self.runnable.config_schema(include=include)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_lc_serializable(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
|
return cls.__module__.split(".")[:-1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def runnables(self) -> Iterator[Runnable[Input, Output]]:
|
||||||
|
yield self.runnable
|
||||||
|
yield from self.fallbacks
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
|
) -> Output:
|
||||||
|
# setup callbacks
|
||||||
|
config = ensure_config(config)
|
||||||
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
|
# start the root run
|
||||||
|
run_manager = callback_manager.on_chain_start(
|
||||||
|
dumpd(self), input, name=config.get("run_name")
|
||||||
|
)
|
||||||
|
first_error = None
|
||||||
|
for runnable in self.runnables:
|
||||||
|
try:
|
||||||
|
output = runnable.invoke(
|
||||||
|
input,
|
||||||
|
patch_config(config, callbacks=run_manager.get_child()),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except self.exceptions_to_handle as e:
|
||||||
|
if first_error is None:
|
||||||
|
first_error = e
|
||||||
|
except BaseException as e:
|
||||||
|
run_manager.on_chain_error(e)
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
run_manager.on_chain_end(output)
|
||||||
|
return output
|
||||||
|
if first_error is None:
|
||||||
|
raise ValueError("No error stored at end of fallbacks.")
|
||||||
|
run_manager.on_chain_error(first_error)
|
||||||
|
raise first_error
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> Output:
|
||||||
|
# setup callbacks
|
||||||
|
config = ensure_config(config)
|
||||||
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
|
# start the root run
|
||||||
|
run_manager = await callback_manager.on_chain_start(
|
||||||
|
dumpd(self), input, name=config.get("run_name")
|
||||||
|
)
|
||||||
|
|
||||||
|
first_error = None
|
||||||
|
for runnable in self.runnables:
|
||||||
|
try:
|
||||||
|
output = await runnable.ainvoke(
|
||||||
|
input,
|
||||||
|
patch_config(config, callbacks=run_manager.get_child()),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except self.exceptions_to_handle as e:
|
||||||
|
if first_error is None:
|
||||||
|
first_error = e
|
||||||
|
except BaseException as e:
|
||||||
|
await run_manager.on_chain_error(e)
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
await run_manager.on_chain_end(output)
|
||||||
|
return output
|
||||||
|
if first_error is None:
|
||||||
|
raise ValueError("No error stored at end of fallbacks.")
|
||||||
|
await run_manager.on_chain_error(first_error)
|
||||||
|
raise first_error
|
||||||
|
|
||||||
|
def batch(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: bool = False,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> List[Output]:
|
||||||
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
|
||||||
|
if return_exceptions:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
if not inputs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# setup callbacks
|
||||||
|
configs = get_config_list(config, len(inputs))
|
||||||
|
callback_managers = [
|
||||||
|
CallbackManager.configure(
|
||||||
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
|
local_callbacks=None,
|
||||||
|
verbose=False,
|
||||||
|
inheritable_tags=config.get("tags"),
|
||||||
|
local_tags=None,
|
||||||
|
inheritable_metadata=config.get("metadata"),
|
||||||
|
local_metadata=None,
|
||||||
|
)
|
||||||
|
for config in configs
|
||||||
|
]
|
||||||
|
# start the root runs, one per input
|
||||||
|
run_managers = [
|
||||||
|
cm.on_chain_start(
|
||||||
|
dumpd(self),
|
||||||
|
input if isinstance(input, dict) else {"input": input},
|
||||||
|
name=config.get("run_name"),
|
||||||
|
)
|
||||||
|
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||||
|
]
|
||||||
|
|
||||||
|
first_error = None
|
||||||
|
for runnable in self.runnables:
|
||||||
|
try:
|
||||||
|
outputs = runnable.batch(
|
||||||
|
inputs,
|
||||||
|
[
|
||||||
|
# each step a child run of the corresponding root run
|
||||||
|
patch_config(config, callbacks=rm.get_child())
|
||||||
|
for rm, config in zip(run_managers, configs)
|
||||||
|
],
|
||||||
|
return_exceptions=return_exceptions,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except self.exceptions_to_handle as e:
|
||||||
|
if first_error is None:
|
||||||
|
first_error = e
|
||||||
|
except BaseException as e:
|
||||||
|
for rm in run_managers:
|
||||||
|
rm.on_chain_error(e)
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
for rm, output in zip(run_managers, outputs):
|
||||||
|
rm.on_chain_end(output)
|
||||||
|
return outputs
|
||||||
|
if first_error is None:
|
||||||
|
raise ValueError("No error stored at end of fallbacks.")
|
||||||
|
for rm in run_managers:
|
||||||
|
rm.on_chain_error(first_error)
|
||||||
|
raise first_error
|
||||||
|
|
||||||
|
async def abatch(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: bool = False,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> List[Output]:
|
||||||
|
from langchain.callbacks.manager import AsyncCallbackManager
|
||||||
|
|
||||||
|
if return_exceptions:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
if not inputs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# setup callbacks
|
||||||
|
configs = get_config_list(config, len(inputs))
|
||||||
|
callback_managers = [
|
||||||
|
AsyncCallbackManager.configure(
|
||||||
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
|
local_callbacks=None,
|
||||||
|
verbose=False,
|
||||||
|
inheritable_tags=config.get("tags"),
|
||||||
|
local_tags=None,
|
||||||
|
inheritable_metadata=config.get("metadata"),
|
||||||
|
local_metadata=None,
|
||||||
|
)
|
||||||
|
for config in configs
|
||||||
|
]
|
||||||
|
# start the root runs, one per input
|
||||||
|
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||||
|
*(
|
||||||
|
cm.on_chain_start(
|
||||||
|
dumpd(self),
|
||||||
|
input,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
)
|
||||||
|
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
first_error = None
|
||||||
|
for runnable in self.runnables:
|
||||||
|
try:
|
||||||
|
outputs = await runnable.abatch(
|
||||||
|
inputs,
|
||||||
|
[
|
||||||
|
# each step a child run of the corresponding root run
|
||||||
|
patch_config(config, callbacks=rm.get_child())
|
||||||
|
for rm, config in zip(run_managers, configs)
|
||||||
|
],
|
||||||
|
return_exceptions=return_exceptions,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except self.exceptions_to_handle as e:
|
||||||
|
if first_error is None:
|
||||||
|
first_error = e
|
||||||
|
except BaseException as e:
|
||||||
|
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
|
||||||
|
else:
|
||||||
|
await asyncio.gather(
|
||||||
|
*(
|
||||||
|
rm.on_chain_end(output)
|
||||||
|
for rm, output in zip(run_managers, outputs)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return outputs
|
||||||
|
if first_error is None:
|
||||||
|
raise ValueError("No error stored at end of fallbacks.")
|
||||||
|
await asyncio.gather(*(rm.on_chain_error(first_error) for rm in run_managers))
|
||||||
|
raise first_error
|
@ -16,9 +16,13 @@ from typing import (
|
|||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
|
||||||
from langchain.pydantic_v1 import BaseModel, create_model
|
from langchain.pydantic_v1 import BaseModel, create_model
|
||||||
from langchain.schema.runnable.base import Input, Runnable, RunnableMap
|
from langchain.schema.runnable.base import (
|
||||||
|
Input,
|
||||||
|
Runnable,
|
||||||
|
RunnableMap,
|
||||||
|
RunnableSerializable,
|
||||||
|
)
|
||||||
from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config
|
from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config
|
||||||
from langchain.schema.runnable.utils import AddableDict
|
from langchain.schema.runnable.utils import AddableDict
|
||||||
from langchain.utils.aiter import atee, py_anext
|
from langchain.utils.aiter import atee, py_anext
|
||||||
@ -33,7 +37,7 @@ async def aidentity(x: Input) -> Input:
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
class RunnablePassthrough(RunnableSerializable[Input, Input]):
|
||||||
"""
|
"""
|
||||||
A runnable that passes through the input.
|
A runnable that passes through the input.
|
||||||
"""
|
"""
|
||||||
@ -109,7 +113,7 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
class RunnableAssign(Serializable, Runnable[Dict[str, Any], Dict[str, Any]]):
|
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||||
"""
|
"""
|
||||||
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
|
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
|
||||||
"""
|
"""
|
||||||
|
@ -14,8 +14,13 @@ from typing import (
|
|||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.schema.runnable.base import (
|
||||||
from langchain.schema.runnable.base import Input, Output, Runnable, coerce_to_runnable
|
Input,
|
||||||
|
Output,
|
||||||
|
Runnable,
|
||||||
|
RunnableSerializable,
|
||||||
|
coerce_to_runnable,
|
||||||
|
)
|
||||||
from langchain.schema.runnable.config import (
|
from langchain.schema.runnable.config import (
|
||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
get_config_list,
|
get_config_list,
|
||||||
@ -36,7 +41,7 @@ class RouterInput(TypedDict):
|
|||||||
input: Any
|
input: Any
|
||||||
|
|
||||||
|
|
||||||
class RouterRunnable(Serializable, Runnable[RouterInput, Output]):
|
class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
||||||
"""
|
"""
|
||||||
A runnable that routes to a set of runnables based on Input['key'].
|
A runnable that routes to a set of runnables based on Input['key'].
|
||||||
Returns the output of the selected runnable.
|
Returns the output of the selected runnable.
|
||||||
|
@ -17,6 +17,7 @@ from langchain.callbacks.manager import (
|
|||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
Callbacks,
|
Callbacks,
|
||||||
)
|
)
|
||||||
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.pydantic_v1 import (
|
from langchain.pydantic_v1 import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
Extra,
|
Extra,
|
||||||
@ -25,7 +26,7 @@ from langchain.pydantic_v1 import (
|
|||||||
root_validator,
|
root_validator,
|
||||||
validate_arguments,
|
validate_arguments,
|
||||||
)
|
)
|
||||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
from langchain.schema.runnable import Runnable, RunnableConfig, RunnableSerializable
|
||||||
|
|
||||||
|
|
||||||
class SchemaAnnotationError(TypeError):
|
class SchemaAnnotationError(TypeError):
|
||||||
@ -97,7 +98,7 @@ class ToolException(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BaseTool(BaseModel, Runnable[Union[str, Dict], Any]):
|
class BaseTool(RunnableSerializable[Union[str, Dict], Any]):
|
||||||
"""Interface LangChain tools must implement."""
|
"""Interface LangChain tools must implement."""
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||||
@ -165,10 +166,9 @@ class ChildTool(BaseTool):
|
|||||||
] = False
|
] = False
|
||||||
"""Handle the content of the ToolException thrown."""
|
"""Handle the content of the ToolException thrown."""
|
||||||
|
|
||||||
class Config:
|
class Config(Serializable.Config):
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
extra = Extra.forbid
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
"""Tools for interacting with Spark SQL."""
|
"""Tools for interacting with Spark SQL."""
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
from langchain.pydantic_v1 import BaseModel, Field, root_validator
|
||||||
|
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
@ -21,13 +21,8 @@ class BaseSparkSQLTool(BaseModel):
|
|||||||
|
|
||||||
db: SparkSQL = Field(exclude=True)
|
db: SparkSQL = Field(exclude=True)
|
||||||
|
|
||||||
# Override BaseTool.Config to appease mypy
|
|
||||||
# See https://github.com/pydantic/pydantic/issues/4173
|
|
||||||
class Config(BaseTool.Config):
|
class Config(BaseTool.Config):
|
||||||
"""Configuration for this pydantic object."""
|
pass
|
||||||
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
extra = Extra.forbid
|
|
||||||
|
|
||||||
|
|
||||||
class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):
|
class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):
|
||||||
|
@ -21,13 +21,8 @@ class BaseSQLDatabaseTool(BaseModel):
|
|||||||
|
|
||||||
db: SQLDatabase = Field(exclude=True)
|
db: SQLDatabase = Field(exclude=True)
|
||||||
|
|
||||||
# Override BaseTool.Config to appease mypy
|
|
||||||
# See https://github.com/pydantic/pydantic/issues/4173
|
|
||||||
class Config(BaseTool.Config):
|
class Config(BaseTool.Config):
|
||||||
"""Configuration for this pydantic object."""
|
pass
|
||||||
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
extra = Extra.forbid
|
|
||||||
|
|
||||||
|
|
||||||
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
|
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||||
|
@ -18,9 +18,7 @@ class BaseVectorStoreTool(BaseModel):
|
|||||||
llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
|
llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
|
||||||
|
|
||||||
class Config(BaseTool.Config):
|
class Config(BaseTool.Config):
|
||||||
"""Configuration for this pydantic object."""
|
pass
|
||||||
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
|
|
||||||
def _create_description_from_template(values: Dict[str, Any]) -> Dict[str, Any]:
|
def _create_description_from_template(values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
Loading…
Reference in New Issue
Block a user