mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-10 15:06:18 +00:00
Create new RunnableSerializable class in preparation for configurable runnables
- Also move RunnableBranch to its own file
This commit is contained in:
parent
33eb5f8300
commit
52e5a8b43e
@ -21,7 +21,6 @@ from langchain.callbacks.manager import (
|
|||||||
Callbacks,
|
Callbacks,
|
||||||
)
|
)
|
||||||
from langchain.load.dump import dumpd
|
from langchain.load.dump import dumpd
|
||||||
from langchain.load.serializable import Serializable
|
|
||||||
from langchain.pydantic_v1 import (
|
from langchain.pydantic_v1 import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
Field,
|
Field,
|
||||||
@ -30,7 +29,7 @@ from langchain.pydantic_v1 import (
|
|||||||
validator,
|
validator,
|
||||||
)
|
)
|
||||||
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
|
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
|
||||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -39,7 +38,7 @@ def _get_verbosity() -> bool:
|
|||||||
return langchain.verbose
|
return langchain.verbose
|
||||||
|
|
||||||
|
|
||||||
class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
|
||||||
"""Abstract base class for creating structured sequences of calls to components.
|
"""Abstract base class for creating structured sequences of calls to components.
|
||||||
|
|
||||||
Chains should be used to encode a sequence of calls to components like
|
Chains should be used to encode a sequence of calls to components like
|
||||||
|
@ -14,7 +14,7 @@ from langchain.schema.runnable import RunnableConfig
|
|||||||
class FakeListLLM(LLM):
|
class FakeListLLM(LLM):
|
||||||
"""Fake LLM for testing purposes."""
|
"""Fake LLM for testing purposes."""
|
||||||
|
|
||||||
responses: List
|
responses: List[str]
|
||||||
sleep: Optional[float] = None
|
sleep: Optional[float] = None
|
||||||
i: int = 0
|
i: int = 0
|
||||||
|
|
||||||
|
@ -15,11 +15,10 @@ from typing import (
|
|||||||
|
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
|
||||||
from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string
|
from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string
|
||||||
from langchain.schema.output import LLMResult
|
from langchain.schema.output import LLMResult
|
||||||
from langchain.schema.prompt import PromptValue
|
from langchain.schema.prompt import PromptValue
|
||||||
from langchain.schema.runnable import Runnable
|
from langchain.schema.runnable import RunnableSerializable
|
||||||
from langchain.utils import get_pydantic_field_names
|
from langchain.utils import get_pydantic_field_names
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -54,7 +53,7 @@ LanguageModelOutput = TypeVar("LanguageModelOutput")
|
|||||||
|
|
||||||
|
|
||||||
class BaseLanguageModel(
|
class BaseLanguageModel(
|
||||||
Serializable, Runnable[LanguageModelInput, LanguageModelOutput], ABC
|
RunnableSerializable[LanguageModelInput, LanguageModelOutput], ABC
|
||||||
):
|
):
|
||||||
"""Abstract base class for interfacing with language models.
|
"""Abstract base class for interfacing with language models.
|
||||||
|
|
||||||
|
@ -16,7 +16,6 @@ from typing import (
|
|||||||
|
|
||||||
from typing_extensions import get_args
|
from typing_extensions import get_args
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
|
||||||
from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
|
from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
|
||||||
from langchain.schema.output import (
|
from langchain.schema.output import (
|
||||||
ChatGeneration,
|
ChatGeneration,
|
||||||
@ -25,12 +24,12 @@ from langchain.schema.output import (
|
|||||||
GenerationChunk,
|
GenerationChunk,
|
||||||
)
|
)
|
||||||
from langchain.schema.prompt import PromptValue
|
from langchain.schema.prompt import PromptValue
|
||||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class BaseLLMOutputParser(Serializable, Generic[T], ABC):
|
class BaseLLMOutputParser(Generic[T], ABC):
|
||||||
"""Abstract base class for parsing the outputs of a model."""
|
"""Abstract base class for parsing the outputs of a model."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -63,7 +62,7 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
|
|||||||
|
|
||||||
|
|
||||||
class BaseGenerationOutputParser(
|
class BaseGenerationOutputParser(
|
||||||
BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
|
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
|
||||||
):
|
):
|
||||||
"""Base class to parse the output of an LLM call."""
|
"""Base class to parse the output of an LLM call."""
|
||||||
|
|
||||||
@ -121,7 +120,9 @@ class BaseGenerationOutputParser(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]):
|
class BaseOutputParser(
|
||||||
|
BaseLLMOutputParser, RunnableSerializable[Union[str, BaseMessage], T]
|
||||||
|
):
|
||||||
"""Base class to parse the output of an LLM call.
|
"""Base class to parse the output of an LLM call.
|
||||||
|
|
||||||
Output parsers help structure language model responses.
|
Output parsers help structure language model responses.
|
||||||
|
@ -7,15 +7,14 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Union
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
|
||||||
from langchain.pydantic_v1 import BaseModel, Field, create_model, root_validator
|
from langchain.pydantic_v1 import BaseModel, Field, create_model, root_validator
|
||||||
from langchain.schema.document import Document
|
from langchain.schema.document import Document
|
||||||
from langchain.schema.output_parser import BaseOutputParser
|
from langchain.schema.output_parser import BaseOutputParser
|
||||||
from langchain.schema.prompt import PromptValue
|
from langchain.schema.prompt import PromptValue
|
||||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
|
||||||
|
|
||||||
|
|
||||||
class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):
|
class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
||||||
"""Base class for all prompt templates, returning a prompt."""
|
"""Base class for all prompt templates, returning a prompt."""
|
||||||
|
|
||||||
input_variables: List[str]
|
input_variables: List[str]
|
||||||
|
@ -6,9 +6,8 @@ from inspect import signature
|
|||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain.load.dump import dumpd
|
from langchain.load.dump import dumpd
|
||||||
from langchain.load.serializable import Serializable
|
|
||||||
from langchain.schema.document import Document
|
from langchain.schema.document import Document
|
||||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
@ -18,7 +17,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
|
class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
|
||||||
"""Abstract base class for a Document retrieval system.
|
"""Abstract base class for a Document retrieval system.
|
||||||
|
|
||||||
A retrieval system is defined as something that can take string queries and return
|
A retrieval system is defined as something that can take string queries and return
|
||||||
|
@ -2,12 +2,13 @@ from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar
|
|||||||
from langchain.schema.runnable.base import (
|
from langchain.schema.runnable.base import (
|
||||||
Runnable,
|
Runnable,
|
||||||
RunnableBinding,
|
RunnableBinding,
|
||||||
RunnableBranch,
|
|
||||||
RunnableLambda,
|
RunnableLambda,
|
||||||
RunnableMap,
|
RunnableMap,
|
||||||
RunnableSequence,
|
RunnableSequence,
|
||||||
|
RunnableSerializable,
|
||||||
RunnableWithFallbacks,
|
RunnableWithFallbacks,
|
||||||
)
|
)
|
||||||
|
from langchain.schema.runnable.branch import RunnableBranch
|
||||||
from langchain.schema.runnable.config import RunnableConfig, patch_config
|
from langchain.schema.runnable.config import RunnableConfig, patch_config
|
||||||
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
||||||
from langchain.schema.runnable.router import RouterInput, RouterRunnable
|
from langchain.schema.runnable.router import RouterInput, RouterRunnable
|
||||||
@ -19,6 +20,7 @@ __all__ = [
|
|||||||
"RouterInput",
|
"RouterInput",
|
||||||
"RouterRunnable",
|
"RouterRunnable",
|
||||||
"Runnable",
|
"Runnable",
|
||||||
|
"RunnableSerializable",
|
||||||
"RunnableBinding",
|
"RunnableBinding",
|
||||||
"RunnableBranch",
|
"RunnableBranch",
|
||||||
"RunnableConfig",
|
"RunnableConfig",
|
||||||
|
@ -11,8 +11,7 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.schema.runnable.base import Input, Output, RunnableSerializable
|
||||||
from langchain.schema.runnable.base import Input, Output, Runnable
|
|
||||||
from langchain.schema.runnable.config import RunnableConfig
|
from langchain.schema.runnable.config import RunnableConfig
|
||||||
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
||||||
|
|
||||||
@ -104,7 +103,7 @@ class PutLocalVar(RunnablePassthrough):
|
|||||||
|
|
||||||
|
|
||||||
class GetLocalVar(
|
class GetLocalVar(
|
||||||
Serializable, Runnable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
|
RunnableSerializable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
|
||||||
):
|
):
|
||||||
key: str
|
key: str
|
||||||
"""The key to extract from the local state."""
|
"""The key to extract from the local state."""
|
||||||
|
@ -119,6 +119,24 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
self.__class__.__name__ + "Output", __root__=(root_type, None)
|
self.__class__.__name__ + "Output", __root__=(root_type, None)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def config_schema(
|
||||||
|
self, *, include: Optional[Sequence[str]] = None
|
||||||
|
) -> Type[BaseModel]:
|
||||||
|
class _Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
include = include or []
|
||||||
|
|
||||||
|
return create_model( # type: ignore[call-overload]
|
||||||
|
self.__class__.__name__ + "Config",
|
||||||
|
__config__=_Config,
|
||||||
|
**{
|
||||||
|
field_name: (field_type, None)
|
||||||
|
for field_name, field_type in RunnableConfig.__annotations__.items()
|
||||||
|
if field_name in include
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def __or__(
|
def __or__(
|
||||||
self,
|
self,
|
||||||
other: Union[
|
other: Union[
|
||||||
@ -812,209 +830,11 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
await run_manager.on_chain_end(final_output, inputs=final_input)
|
await run_manager.on_chain_end(final_output, inputs=final_input)
|
||||||
|
|
||||||
|
|
||||||
class RunnableBranch(Serializable, Runnable[Input, Output]):
|
class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||||
"""A Runnable that selects which branch to run based on a condition.
|
pass
|
||||||
|
|
||||||
The runnable is initialized with a list of (condition, runnable) pairs and
|
|
||||||
a default branch.
|
|
||||||
|
|
||||||
When operating on an input, the first condition that evaluates to True is
|
|
||||||
selected, and the corresponding runnable is run on the input.
|
|
||||||
|
|
||||||
If no condition evaluates to True, the default branch is run on the input.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
from langchain.schema.runnable import RunnableBranch
|
|
||||||
|
|
||||||
branch = RunnableBranch(
|
|
||||||
(lambda x: isinstance(x, str), lambda x: x.upper()),
|
|
||||||
(lambda x: isinstance(x, int), lambda x: x + 1),
|
|
||||||
(lambda x: isinstance(x, float), lambda x: x * 2),
|
|
||||||
lambda x: "goodbye",
|
|
||||||
)
|
|
||||||
|
|
||||||
branch.invoke("hello") # "HELLO"
|
|
||||||
branch.invoke(None) # "goodbye"
|
|
||||||
"""
|
|
||||||
|
|
||||||
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
|
|
||||||
default: Runnable[Input, Output]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*branches: Union[
|
|
||||||
Tuple[
|
|
||||||
Union[
|
|
||||||
Runnable[Input, bool],
|
|
||||||
Callable[[Input], bool],
|
|
||||||
Callable[[Input], Awaitable[bool]],
|
|
||||||
],
|
|
||||||
RunnableLike,
|
|
||||||
],
|
|
||||||
RunnableLike, # To accommodate the default branch
|
|
||||||
],
|
|
||||||
) -> None:
|
|
||||||
"""A Runnable that runs one of two branches based on a condition."""
|
|
||||||
if len(branches) < 2:
|
|
||||||
raise ValueError("RunnableBranch requires at least two branches")
|
|
||||||
|
|
||||||
default = branches[-1]
|
|
||||||
|
|
||||||
if not isinstance(
|
|
||||||
default, (Runnable, Callable, Mapping) # type: ignore[arg-type]
|
|
||||||
):
|
|
||||||
raise TypeError(
|
|
||||||
"RunnableBranch default must be runnable, callable or mapping."
|
|
||||||
)
|
|
||||||
|
|
||||||
default_ = cast(
|
|
||||||
Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
|
|
||||||
)
|
|
||||||
|
|
||||||
_branches = []
|
|
||||||
|
|
||||||
for branch in branches[:-1]:
|
|
||||||
if not isinstance(branch, (tuple, list)): # type: ignore[arg-type]
|
|
||||||
raise TypeError(
|
|
||||||
f"RunnableBranch branches must be "
|
|
||||||
f"tuples or lists, not {type(branch)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not len(branch) == 2:
|
|
||||||
raise ValueError(
|
|
||||||
f"RunnableBranch branches must be "
|
|
||||||
f"tuples or lists of length 2, not {len(branch)}"
|
|
||||||
)
|
|
||||||
condition, runnable = branch
|
|
||||||
condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
|
|
||||||
runnable = coerce_to_runnable(runnable)
|
|
||||||
_branches.append((condition, runnable))
|
|
||||||
|
|
||||||
super().__init__(branches=_branches, default=default_)
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def is_lc_serializable(cls) -> bool:
|
|
||||||
"""RunnableBranch is serializable if all its branches are serializable."""
|
|
||||||
return True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
|
||||||
"""The namespace of a RunnableBranch is the namespace of its default branch."""
|
|
||||||
return cls.__module__.split(".")[:-1]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input_schema(self) -> type[BaseModel]:
|
|
||||||
runnables = (
|
|
||||||
[self.default]
|
|
||||||
+ [r for _, r in self.branches]
|
|
||||||
+ [r for r, _ in self.branches]
|
|
||||||
)
|
|
||||||
|
|
||||||
for runnable in runnables:
|
|
||||||
if runnable.input_schema.schema().get("type") is not None:
|
|
||||||
return runnable.input_schema
|
|
||||||
|
|
||||||
return super().input_schema
|
|
||||||
|
|
||||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
|
||||||
"""First evaluates the condition, then delegate to true or false branch."""
|
|
||||||
config = ensure_config(config)
|
|
||||||
callback_manager = get_callback_manager_for_config(config)
|
|
||||||
run_manager = callback_manager.on_chain_start(
|
|
||||||
dumpd(self),
|
|
||||||
input,
|
|
||||||
name=config.get("run_name"),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
for idx, branch in enumerate(self.branches):
|
|
||||||
condition, runnable = branch
|
|
||||||
|
|
||||||
expression_value = condition.invoke(
|
|
||||||
input,
|
|
||||||
config=patch_config(
|
|
||||||
config,
|
|
||||||
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if expression_value:
|
|
||||||
output = runnable.invoke(
|
|
||||||
input,
|
|
||||||
config=patch_config(
|
|
||||||
config,
|
|
||||||
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
output = self.default.invoke(
|
|
||||||
input,
|
|
||||||
config=patch_config(
|
|
||||||
config, callbacks=run_manager.get_child(tag="branch:default")
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
run_manager.on_chain_error(e)
|
|
||||||
raise
|
|
||||||
run_manager.on_chain_end(dumpd(output))
|
|
||||||
return output
|
|
||||||
|
|
||||||
async def ainvoke(
|
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
|
||||||
) -> Output:
|
|
||||||
"""Async version of invoke."""
|
|
||||||
config = ensure_config(config)
|
|
||||||
callback_manager = get_callback_manager_for_config(config)
|
|
||||||
run_manager = callback_manager.on_chain_start(
|
|
||||||
dumpd(self),
|
|
||||||
input,
|
|
||||||
name=config.get("run_name"),
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
for idx, branch in enumerate(self.branches):
|
|
||||||
condition, runnable = branch
|
|
||||||
|
|
||||||
expression_value = await condition.ainvoke(
|
|
||||||
input,
|
|
||||||
config=patch_config(
|
|
||||||
config,
|
|
||||||
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if expression_value:
|
|
||||||
output = await runnable.ainvoke(
|
|
||||||
input,
|
|
||||||
config=patch_config(
|
|
||||||
config,
|
|
||||||
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
|
|
||||||
),
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
output = await self.default.ainvoke(
|
|
||||||
input,
|
|
||||||
config=patch_config(
|
|
||||||
config, callbacks=run_manager.get_child(tag="branch:default")
|
|
||||||
),
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
run_manager.on_chain_error(e)
|
|
||||||
raise
|
|
||||||
run_manager.on_chain_end(dumpd(output))
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||||
"""
|
"""
|
||||||
A Runnable that can fallback to other Runnables if it fails.
|
A Runnable that can fallback to other Runnables if it fails.
|
||||||
"""
|
"""
|
||||||
@ -1042,6 +862,11 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
def output_schema(self) -> Type[BaseModel]:
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
return self.runnable.output_schema
|
return self.runnable.output_schema
|
||||||
|
|
||||||
|
def config_schema(
|
||||||
|
self, *, include: Optional[Sequence[str]] = None
|
||||||
|
) -> Type[BaseModel]:
|
||||||
|
return self.runnable.config_schema(include=include)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
@ -1267,7 +1092,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
raise first_error
|
raise first_error
|
||||||
|
|
||||||
|
|
||||||
class RunnableSequence(Serializable, Runnable[Input, Output]):
|
class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||||
"""
|
"""
|
||||||
A sequence of runnables, where the output of each is the input of the next.
|
A sequence of runnables, where the output of each is the input of the next.
|
||||||
"""
|
"""
|
||||||
@ -1749,7 +1574,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
|
||||||
"""
|
"""
|
||||||
A runnable that runs a mapping of runnables in parallel,
|
A runnable that runs a mapping of runnables in parallel,
|
||||||
and returns a mapping of their outputs.
|
and returns a mapping of their outputs.
|
||||||
@ -1799,7 +1624,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
return create_model( # type: ignore[call-overload]
|
return create_model( # type: ignore[call-overload]
|
||||||
"RunnableMapInput",
|
"RunnableMapInput",
|
||||||
**{
|
**{
|
||||||
k: (v.type_, v.default)
|
k: (v.annotation, v.default)
|
||||||
for step in self.steps.values()
|
for step in self.steps.values()
|
||||||
for k, v in step.input_schema.__fields__.items()
|
for k, v in step.input_schema.__fields__.items()
|
||||||
if k != "__root__"
|
if k != "__root__"
|
||||||
@ -2374,7 +2199,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
return await super().ainvoke(input, config)
|
return await super().ainvoke(input, config)
|
||||||
|
|
||||||
|
|
||||||
class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
|
class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
||||||
"""
|
"""
|
||||||
A runnable that delegates calls to another runnable
|
A runnable that delegates calls to another runnable
|
||||||
with each element of the input sequence.
|
with each element of the input sequence.
|
||||||
@ -2413,6 +2238,11 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def config_schema(
|
||||||
|
self, *, include: Optional[Sequence[str]] = None
|
||||||
|
) -> Type[BaseModel]:
|
||||||
|
return self.bound.config_schema(include=include)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
@ -2455,7 +2285,7 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
|
|||||||
return await self._acall_with_config(self._ainvoke, input, config)
|
return await self._acall_with_config(self._ainvoke, input, config)
|
||||||
|
|
||||||
|
|
||||||
class RunnableBinding(Serializable, Runnable[Input, Output]):
|
class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||||
"""
|
"""
|
||||||
A runnable that delegates calls to another runnable with a set of kwargs.
|
A runnable that delegates calls to another runnable with a set of kwargs.
|
||||||
"""
|
"""
|
||||||
@ -2485,6 +2315,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
def output_schema(self) -> Type[BaseModel]:
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
return self.bound.output_schema
|
return self.bound.output_schema
|
||||||
|
|
||||||
|
def config_schema(
|
||||||
|
self, *, include: Optional[Sequence[str]] = None
|
||||||
|
) -> Type[BaseModel]:
|
||||||
|
return self.bound.config_schema(include=include)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
|
234
libs/langchain/langchain/schema/runnable/branch.py
Normal file
234
libs/langchain/langchain/schema/runnable/branch.py
Normal file
@ -0,0 +1,234 @@
|
|||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.load.dump import dumpd
|
||||||
|
from langchain.pydantic_v1 import BaseModel
|
||||||
|
from langchain.schema.runnable.base import (
|
||||||
|
Runnable,
|
||||||
|
RunnableLike,
|
||||||
|
RunnableSerializable,
|
||||||
|
coerce_to_runnable,
|
||||||
|
)
|
||||||
|
from langchain.schema.runnable.config import (
|
||||||
|
RunnableConfig,
|
||||||
|
ensure_config,
|
||||||
|
get_callback_manager_for_config,
|
||||||
|
patch_config,
|
||||||
|
)
|
||||||
|
from langchain.schema.runnable.utils import Input, Output
|
||||||
|
|
||||||
|
|
||||||
|
class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||||
|
"""A Runnable that selects which branch to run based on a condition.
|
||||||
|
|
||||||
|
The runnable is initialized with a list of (condition, runnable) pairs and
|
||||||
|
a default branch.
|
||||||
|
|
||||||
|
When operating on an input, the first condition that evaluates to True is
|
||||||
|
selected, and the corresponding runnable is run on the input.
|
||||||
|
|
||||||
|
If no condition evaluates to True, the default branch is run on the input.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.schema.runnable import RunnableBranch
|
||||||
|
|
||||||
|
branch = RunnableBranch(
|
||||||
|
(lambda x: isinstance(x, str), lambda x: x.upper()),
|
||||||
|
(lambda x: isinstance(x, int), lambda x: x + 1),
|
||||||
|
(lambda x: isinstance(x, float), lambda x: x * 2),
|
||||||
|
lambda x: "goodbye",
|
||||||
|
)
|
||||||
|
|
||||||
|
branch.invoke("hello") # "HELLO"
|
||||||
|
branch.invoke(None) # "goodbye"
|
||||||
|
"""
|
||||||
|
|
||||||
|
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
|
||||||
|
default: Runnable[Input, Output]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*branches: Union[
|
||||||
|
Tuple[
|
||||||
|
Union[
|
||||||
|
Runnable[Input, bool],
|
||||||
|
Callable[[Input], bool],
|
||||||
|
Callable[[Input], Awaitable[bool]],
|
||||||
|
],
|
||||||
|
RunnableLike,
|
||||||
|
],
|
||||||
|
RunnableLike, # To accommodate the default branch
|
||||||
|
],
|
||||||
|
) -> None:
|
||||||
|
"""A Runnable that runs one of two branches based on a condition."""
|
||||||
|
if len(branches) < 2:
|
||||||
|
raise ValueError("RunnableBranch requires at least two branches")
|
||||||
|
|
||||||
|
default = branches[-1]
|
||||||
|
|
||||||
|
if not isinstance(
|
||||||
|
default, (Runnable, Callable, Mapping) # type: ignore[arg-type]
|
||||||
|
):
|
||||||
|
raise TypeError(
|
||||||
|
"RunnableBranch default must be runnable, callable or mapping."
|
||||||
|
)
|
||||||
|
|
||||||
|
default_ = cast(
|
||||||
|
Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
|
||||||
|
)
|
||||||
|
|
||||||
|
_branches = []
|
||||||
|
|
||||||
|
for branch in branches[:-1]:
|
||||||
|
if not isinstance(branch, (tuple, list)): # type: ignore[arg-type]
|
||||||
|
raise TypeError(
|
||||||
|
f"RunnableBranch branches must be "
|
||||||
|
f"tuples or lists, not {type(branch)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not len(branch) == 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"RunnableBranch branches must be "
|
||||||
|
f"tuples or lists of length 2, not {len(branch)}"
|
||||||
|
)
|
||||||
|
condition, runnable = branch
|
||||||
|
condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
|
||||||
|
runnable = coerce_to_runnable(runnable)
|
||||||
|
_branches.append((condition, runnable))
|
||||||
|
|
||||||
|
super().__init__(branches=_branches, default=default_)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_lc_serializable(cls) -> bool:
|
||||||
|
"""RunnableBranch is serializable if all its branches are serializable."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
|
"""The namespace of a RunnableBranch is the namespace of its default branch."""
|
||||||
|
return cls.__module__.split(".")[:-1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> type[BaseModel]:
|
||||||
|
runnables = (
|
||||||
|
[self.default]
|
||||||
|
+ [r for _, r in self.branches]
|
||||||
|
+ [r for r, _ in self.branches]
|
||||||
|
)
|
||||||
|
|
||||||
|
for runnable in runnables:
|
||||||
|
if runnable.input_schema.schema().get("type") is not None:
|
||||||
|
return runnable.input_schema
|
||||||
|
|
||||||
|
return super().input_schema
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
|
) -> Output:
|
||||||
|
"""First evaluates the condition, then delegate to true or false branch."""
|
||||||
|
config = ensure_config(config)
|
||||||
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
|
run_manager = callback_manager.on_chain_start(
|
||||||
|
dumpd(self),
|
||||||
|
input,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for idx, branch in enumerate(self.branches):
|
||||||
|
condition, runnable = branch
|
||||||
|
|
||||||
|
expression_value = condition.invoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if expression_value:
|
||||||
|
output = runnable.invoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
output = self.default.invoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config, callbacks=run_manager.get_child(tag="branch:default")
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
run_manager.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
run_manager.on_chain_end(dumpd(output))
|
||||||
|
return output
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
|
) -> Output:
|
||||||
|
"""Async version of invoke."""
|
||||||
|
config = ensure_config(config)
|
||||||
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
|
run_manager = callback_manager.on_chain_start(
|
||||||
|
dumpd(self),
|
||||||
|
input,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
for idx, branch in enumerate(self.branches):
|
||||||
|
condition, runnable = branch
|
||||||
|
|
||||||
|
expression_value = await condition.ainvoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if expression_value:
|
||||||
|
output = await runnable.ainvoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config,
|
||||||
|
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
output = await self.default.ainvoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config, callbacks=run_manager.get_child(tag="branch:default")
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
run_manager.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
run_manager.on_chain_end(dumpd(output))
|
||||||
|
return output
|
@ -16,9 +16,13 @@ from typing import (
|
|||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
|
||||||
from langchain.pydantic_v1 import BaseModel, create_model
|
from langchain.pydantic_v1 import BaseModel, create_model
|
||||||
from langchain.schema.runnable.base import Input, Runnable, RunnableMap
|
from langchain.schema.runnable.base import (
|
||||||
|
Input,
|
||||||
|
Runnable,
|
||||||
|
RunnableMap,
|
||||||
|
RunnableSerializable,
|
||||||
|
)
|
||||||
from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config
|
from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config
|
||||||
from langchain.schema.runnable.utils import AddableDict
|
from langchain.schema.runnable.utils import AddableDict
|
||||||
from langchain.utils.aiter import atee, py_anext
|
from langchain.utils.aiter import atee, py_anext
|
||||||
@ -33,7 +37,7 @@ async def aidentity(x: Input) -> Input:
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
class RunnablePassthrough(RunnableSerializable[Input, Input]):
|
||||||
"""
|
"""
|
||||||
A runnable that passes through the input.
|
A runnable that passes through the input.
|
||||||
"""
|
"""
|
||||||
@ -109,7 +113,7 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
class RunnableAssign(Serializable, Runnable[Dict[str, Any], Dict[str, Any]]):
|
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||||
"""
|
"""
|
||||||
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
|
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
|
||||||
"""
|
"""
|
||||||
|
@ -14,8 +14,13 @@ from typing import (
|
|||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.schema.runnable.base import (
|
||||||
from langchain.schema.runnable.base import Input, Output, Runnable, coerce_to_runnable
|
Input,
|
||||||
|
Output,
|
||||||
|
Runnable,
|
||||||
|
RunnableSerializable,
|
||||||
|
coerce_to_runnable,
|
||||||
|
)
|
||||||
from langchain.schema.runnable.config import (
|
from langchain.schema.runnable.config import (
|
||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
get_config_list,
|
get_config_list,
|
||||||
@ -36,7 +41,7 @@ class RouterInput(TypedDict):
|
|||||||
input: Any
|
input: Any
|
||||||
|
|
||||||
|
|
||||||
class RouterRunnable(Serializable, Runnable[RouterInput, Output]):
|
class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
||||||
"""
|
"""
|
||||||
A runnable that routes to a set of runnables based on Input['key'].
|
A runnable that routes to a set of runnables based on Input['key'].
|
||||||
Returns the output of the selected runnable.
|
Returns the output of the selected runnable.
|
||||||
|
@ -25,7 +25,7 @@ from langchain.pydantic_v1 import (
|
|||||||
root_validator,
|
root_validator,
|
||||||
validate_arguments,
|
validate_arguments,
|
||||||
)
|
)
|
||||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
from langchain.schema.runnable import Runnable, RunnableConfig, RunnableSerializable
|
||||||
|
|
||||||
|
|
||||||
class SchemaAnnotationError(TypeError):
|
class SchemaAnnotationError(TypeError):
|
||||||
@ -97,7 +97,7 @@ class ToolException(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BaseTool(BaseModel, Runnable[Union[str, Dict], Any]):
|
class BaseTool(RunnableSerializable[Union[str, Dict], Any]):
|
||||||
"""Interface LangChain tools must implement."""
|
"""Interface LangChain tools must implement."""
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||||
@ -168,7 +168,6 @@ class ChildTool(BaseTool):
|
|||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
extra = Extra.forbid
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -21,14 +21,6 @@ class BaseSparkSQLTool(BaseModel):
|
|||||||
|
|
||||||
db: SparkSQL = Field(exclude=True)
|
db: SparkSQL = Field(exclude=True)
|
||||||
|
|
||||||
# Override BaseTool.Config to appease mypy
|
|
||||||
# See https://github.com/pydantic/pydantic/issues/4173
|
|
||||||
class Config(BaseTool.Config):
|
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
extra = Extra.forbid
|
|
||||||
|
|
||||||
|
|
||||||
class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):
|
class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):
|
||||||
"""Tool for querying a Spark SQL."""
|
"""Tool for querying a Spark SQL."""
|
||||||
|
@ -21,14 +21,6 @@ class BaseSQLDatabaseTool(BaseModel):
|
|||||||
|
|
||||||
db: SQLDatabase = Field(exclude=True)
|
db: SQLDatabase = Field(exclude=True)
|
||||||
|
|
||||||
# Override BaseTool.Config to appease mypy
|
|
||||||
# See https://github.com/pydantic/pydantic/issues/4173
|
|
||||||
class Config(BaseTool.Config):
|
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
extra = Extra.forbid
|
|
||||||
|
|
||||||
|
|
||||||
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
|
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||||
"""Tool for querying a SQL database."""
|
"""Tool for querying a SQL database."""
|
||||||
|
@ -17,11 +17,6 @@ class BaseVectorStoreTool(BaseModel):
|
|||||||
vectorstore: VectorStore = Field(exclude=True)
|
vectorstore: VectorStore = Field(exclude=True)
|
||||||
llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
|
llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
|
||||||
|
|
||||||
class Config(BaseTool.Config):
|
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
|
|
||||||
def _create_description_from_template(values: Dict[str, Any]) -> Dict[str, Any]:
|
def _create_description_from_template(values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
values["description"] = values["template"].format(name=values["name"])
|
values["description"] = values["template"].format(name=values["name"])
|
||||||
|
Loading…
Reference in New Issue
Block a user