mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-25 20:43:23 +00:00
make runnable dir (#9016)
Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
parent
c7a489ae0d
commit
434a96415b
24
libs/langchain/langchain/schema/runnable/__init__.py
Normal file
24
libs/langchain/langchain/schema/runnable/__init__.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from langchain.schema.runnable.base import (
|
||||||
|
Runnable,
|
||||||
|
RunnableBinding,
|
||||||
|
RunnableLambda,
|
||||||
|
RunnableMap,
|
||||||
|
RunnableSequence,
|
||||||
|
RunnableWithFallbacks,
|
||||||
|
)
|
||||||
|
from langchain.schema.runnable.config import RunnableConfig
|
||||||
|
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
||||||
|
from langchain.schema.runnable.router import RouterInput, RouterRunnable
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"RouterInput",
|
||||||
|
"RouterRunnable",
|
||||||
|
"Runnable",
|
||||||
|
"RunnableBinding",
|
||||||
|
"RunnableConfig",
|
||||||
|
"RunnableMap",
|
||||||
|
"RunnableLambda",
|
||||||
|
"RunnablePassthrough",
|
||||||
|
"RunnableSequence",
|
||||||
|
"RunnableWithFallbacks",
|
||||||
|
]
|
@ -9,7 +9,6 @@ from typing import (
|
|||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Coroutine,
|
|
||||||
Dict,
|
Dict,
|
||||||
Generic,
|
Generic,
|
||||||
Iterator,
|
Iterator,
|
||||||
@ -19,7 +18,6 @@ from typing import (
|
|||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
TypedDict,
|
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
@ -27,48 +25,15 @@ from typing import (
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackManager, Callbacks
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
from langchain.load.dump import dumpd
|
from langchain.load.dump import dumpd
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
|
from langchain.schema.runnable.config import RunnableConfig
|
||||||
|
from langchain.schema.runnable.utils import (
|
||||||
|
gather_with_concurrency,
|
||||||
|
)
|
||||||
from langchain.utils.aiter import atee, py_anext
|
from langchain.utils.aiter import atee, py_anext
|
||||||
|
|
||||||
|
|
||||||
async def _gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
|
|
||||||
async with semaphore:
|
|
||||||
return await coro
|
|
||||||
|
|
||||||
|
|
||||||
async def _gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
|
|
||||||
if n is None:
|
|
||||||
return await asyncio.gather(*coros)
|
|
||||||
|
|
||||||
semaphore = asyncio.Semaphore(n)
|
|
||||||
|
|
||||||
return await asyncio.gather(*(_gated_coro(semaphore, c) for c in coros))
|
|
||||||
|
|
||||||
|
|
||||||
class RunnableConfig(TypedDict, total=False):
|
|
||||||
"""Configuration for a Runnable."""
|
|
||||||
|
|
||||||
tags: List[str]
|
|
||||||
"""
|
|
||||||
Tags for this call and any sub-calls (eg. a Chain calling an LLM).
|
|
||||||
You can use these to filter calls.
|
|
||||||
"""
|
|
||||||
|
|
||||||
metadata: Dict[str, Any]
|
|
||||||
"""
|
|
||||||
Metadata for this call and any sub-calls (eg. a Chain calling an LLM).
|
|
||||||
Keys should be strings, values should be JSON-serializable.
|
|
||||||
"""
|
|
||||||
|
|
||||||
callbacks: Callbacks
|
|
||||||
"""
|
|
||||||
Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
|
|
||||||
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
Input = TypeVar("Input")
|
Input = TypeVar("Input")
|
||||||
# Output type should implement __concat__, as eg str, list, dict do
|
# Output type should implement __concat__, as eg str, list, dict do
|
||||||
Output = TypeVar("Output")
|
Output = TypeVar("Output")
|
||||||
@ -87,7 +52,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
||||||
],
|
],
|
||||||
) -> RunnableSequence[Input, Other]:
|
) -> RunnableSequence[Input, Other]:
|
||||||
return RunnableSequence(first=self, last=_coerce_to_runnable(other))
|
return RunnableSequence(first=self, last=coerce_to_runnable(other))
|
||||||
|
|
||||||
def __ror__(
|
def __ror__(
|
||||||
self,
|
self,
|
||||||
@ -97,7 +62,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
||||||
],
|
],
|
||||||
) -> RunnableSequence[Other, Output]:
|
) -> RunnableSequence[Other, Output]:
|
||||||
return RunnableSequence(first=_coerce_to_runnable(other), last=self)
|
return RunnableSequence(first=coerce_to_runnable(other), last=self)
|
||||||
|
|
||||||
""" --- Public API --- """
|
""" --- Public API --- """
|
||||||
|
|
||||||
@ -150,7 +115,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
configs = self._get_config_list(config, len(inputs))
|
configs = self._get_config_list(config, len(inputs))
|
||||||
coros = map(self.ainvoke, inputs, configs)
|
coros = map(self.ainvoke, inputs, configs)
|
||||||
|
|
||||||
return await _gather_with_concurrency(max_concurrency, *coros)
|
return await gather_with_concurrency(max_concurrency, *coros)
|
||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
@ -478,6 +443,14 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_serializable(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_namespace(self) -> List[str]:
|
||||||
|
return self.__class__.__module__.split(".")[:-1]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def runnables(self) -> Iterator[Runnable[Input, Output]]:
|
def runnables(self) -> Iterator[Runnable[Input, Output]]:
|
||||||
yield self.runnable
|
yield self.runnable
|
||||||
@ -506,7 +479,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
try:
|
try:
|
||||||
output = runnable.invoke(
|
output = runnable.invoke(
|
||||||
input,
|
input,
|
||||||
_patch_config(config, run_manager.get_child()),
|
patch_config(config, run_manager.get_child()),
|
||||||
)
|
)
|
||||||
except self.exceptions_to_handle as e:
|
except self.exceptions_to_handle as e:
|
||||||
if first_error is None:
|
if first_error is None:
|
||||||
@ -550,7 +523,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
try:
|
try:
|
||||||
output = await runnable.ainvoke(
|
output = await runnable.ainvoke(
|
||||||
input,
|
input,
|
||||||
_patch_config(config, run_manager.get_child()),
|
patch_config(config, run_manager.get_child()),
|
||||||
)
|
)
|
||||||
except self.exceptions_to_handle as e:
|
except self.exceptions_to_handle as e:
|
||||||
if first_error is None:
|
if first_error is None:
|
||||||
@ -606,7 +579,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
inputs,
|
inputs,
|
||||||
[
|
[
|
||||||
# each step a child run of the corresponding root run
|
# each step a child run of the corresponding root run
|
||||||
_patch_config(config, rm.get_child())
|
patch_config(config, rm.get_child())
|
||||||
for rm, config in zip(run_managers, configs)
|
for rm, config in zip(run_managers, configs)
|
||||||
],
|
],
|
||||||
max_concurrency=max_concurrency,
|
max_concurrency=max_concurrency,
|
||||||
@ -673,7 +646,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
inputs,
|
inputs,
|
||||||
[
|
[
|
||||||
# each step a child run of the corresponding root run
|
# each step a child run of the corresponding root run
|
||||||
_patch_config(config, rm.get_child())
|
patch_config(config, rm.get_child())
|
||||||
for rm, config in zip(run_managers, configs)
|
for rm, config in zip(run_managers, configs)
|
||||||
],
|
],
|
||||||
max_concurrency=max_concurrency,
|
max_concurrency=max_concurrency,
|
||||||
@ -716,6 +689,10 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
def lc_serializable(self) -> bool:
|
def lc_serializable(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_namespace(self) -> List[str]:
|
||||||
|
return self.__class__.__module__.split(".")[:-1]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@ -737,7 +714,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
return RunnableSequence(
|
return RunnableSequence(
|
||||||
first=self.first,
|
first=self.first,
|
||||||
middle=self.middle + [self.last],
|
middle=self.middle + [self.last],
|
||||||
last=_coerce_to_runnable(other),
|
last=coerce_to_runnable(other),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __ror__(
|
def __ror__(
|
||||||
@ -756,7 +733,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return RunnableSequence(
|
return RunnableSequence(
|
||||||
first=_coerce_to_runnable(other),
|
first=coerce_to_runnable(other),
|
||||||
middle=[self.first] + self.middle,
|
middle=[self.first] + self.middle,
|
||||||
last=self.last,
|
last=self.last,
|
||||||
)
|
)
|
||||||
@ -786,7 +763,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
input = step.invoke(
|
input = step.invoke(
|
||||||
input,
|
input,
|
||||||
# mark each step as a child run
|
# mark each step as a child run
|
||||||
_patch_config(config, run_manager.get_child()),
|
patch_config(config, run_manager.get_child()),
|
||||||
)
|
)
|
||||||
# finish the root run
|
# finish the root run
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except (KeyboardInterrupt, Exception) as e:
|
||||||
@ -825,7 +802,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
input = await step.ainvoke(
|
input = await step.ainvoke(
|
||||||
input,
|
input,
|
||||||
# mark each step as a child run
|
# mark each step as a child run
|
||||||
_patch_config(config, run_manager.get_child()),
|
patch_config(config, run_manager.get_child()),
|
||||||
)
|
)
|
||||||
# finish the root run
|
# finish the root run
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except (KeyboardInterrupt, Exception) as e:
|
||||||
@ -875,7 +852,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
inputs,
|
inputs,
|
||||||
[
|
[
|
||||||
# each step a child run of the corresponding root run
|
# each step a child run of the corresponding root run
|
||||||
_patch_config(config, rm.get_child())
|
patch_config(config, rm.get_child())
|
||||||
for rm, config in zip(run_managers, configs)
|
for rm, config in zip(run_managers, configs)
|
||||||
],
|
],
|
||||||
max_concurrency=max_concurrency,
|
max_concurrency=max_concurrency,
|
||||||
@ -934,7 +911,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
inputs,
|
inputs,
|
||||||
[
|
[
|
||||||
# each step a child run of the corresponding root run
|
# each step a child run of the corresponding root run
|
||||||
_patch_config(config, rm.get_child())
|
patch_config(config, rm.get_child())
|
||||||
for rm, config in zip(run_managers, configs)
|
for rm, config in zip(run_managers, configs)
|
||||||
],
|
],
|
||||||
max_concurrency=max_concurrency,
|
max_concurrency=max_concurrency,
|
||||||
@ -990,7 +967,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
input = step.invoke(
|
input = step.invoke(
|
||||||
input,
|
input,
|
||||||
# mark each step as a child run
|
# mark each step as a child run
|
||||||
_patch_config(config, run_manager.get_child()),
|
patch_config(config, run_manager.get_child()),
|
||||||
)
|
)
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except (KeyboardInterrupt, Exception) as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
@ -1002,12 +979,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
try:
|
try:
|
||||||
# stream the first of the last steps with non-streaming input
|
# stream the first of the last steps with non-streaming input
|
||||||
final_pipeline = steps[streaming_start_index].stream(
|
final_pipeline = steps[streaming_start_index].stream(
|
||||||
input, _patch_config(config, run_manager.get_child())
|
input, patch_config(config, run_manager.get_child())
|
||||||
)
|
)
|
||||||
# stream the rest of the last steps with streaming input
|
# stream the rest of the last steps with streaming input
|
||||||
for step in steps[streaming_start_index + 1 :]:
|
for step in steps[streaming_start_index + 1 :]:
|
||||||
final_pipeline = step.transform(
|
final_pipeline = step.transform(
|
||||||
final_pipeline, _patch_config(config, run_manager.get_child())
|
final_pipeline, patch_config(config, run_manager.get_child())
|
||||||
)
|
)
|
||||||
for output in final_pipeline:
|
for output in final_pipeline:
|
||||||
yield output
|
yield output
|
||||||
@ -1067,7 +1044,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
input = await step.ainvoke(
|
input = await step.ainvoke(
|
||||||
input,
|
input,
|
||||||
# mark each step as a child run
|
# mark each step as a child run
|
||||||
_patch_config(config, run_manager.get_child()),
|
patch_config(config, run_manager.get_child()),
|
||||||
)
|
)
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except (KeyboardInterrupt, Exception) as e:
|
||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
@ -1079,12 +1056,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
try:
|
try:
|
||||||
# stream the first of the last steps with non-streaming input
|
# stream the first of the last steps with non-streaming input
|
||||||
final_pipeline = steps[streaming_start_index].astream(
|
final_pipeline = steps[streaming_start_index].astream(
|
||||||
input, _patch_config(config, run_manager.get_child())
|
input, patch_config(config, run_manager.get_child())
|
||||||
)
|
)
|
||||||
# stream the rest of the last steps with streaming input
|
# stream the rest of the last steps with streaming input
|
||||||
for step in steps[streaming_start_index + 1 :]:
|
for step in steps[streaming_start_index + 1 :]:
|
||||||
final_pipeline = step.atransform(
|
final_pipeline = step.atransform(
|
||||||
final_pipeline, _patch_config(config, run_manager.get_child())
|
final_pipeline, patch_config(config, run_manager.get_child())
|
||||||
)
|
)
|
||||||
async for output in final_pipeline:
|
async for output in final_pipeline:
|
||||||
yield output
|
yield output
|
||||||
@ -1128,14 +1105,16 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
],
|
],
|
||||||
],
|
],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(steps={key: coerce_to_runnable(r) for key, r in steps.items()})
|
||||||
steps={key: _coerce_to_runnable(r) for key, r in steps.items()}
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def lc_serializable(self) -> bool:
|
def lc_serializable(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_namespace(self) -> List[str]:
|
||||||
|
return self.__class__.__module__.split(".")[:-1]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@ -1168,7 +1147,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
step.invoke,
|
step.invoke,
|
||||||
input,
|
input,
|
||||||
# mark each step as a child run
|
# mark each step as a child run
|
||||||
_patch_config(config, run_manager.get_child()),
|
patch_config(config, run_manager.get_child()),
|
||||||
)
|
)
|
||||||
for step in steps.values()
|
for step in steps.values()
|
||||||
]
|
]
|
||||||
@ -1211,7 +1190,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
step.ainvoke(
|
step.ainvoke(
|
||||||
input,
|
input,
|
||||||
# mark each step as a child run
|
# mark each step as a child run
|
||||||
_patch_config(config, run_manager.get_child()),
|
patch_config(config, run_manager.get_child()),
|
||||||
)
|
)
|
||||||
for step in steps.values()
|
for step in steps.values()
|
||||||
)
|
)
|
||||||
@ -1250,19 +1229,6 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
return self._call_with_config(self.func, input, config)
|
return self._call_with_config(self.func, input, config)
|
||||||
|
|
||||||
|
|
||||||
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
|
||||||
"""
|
|
||||||
A runnable that passes through the input.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def lc_serializable(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
|
||||||
return self._call_with_config(lambda x: x, input, config)
|
|
||||||
|
|
||||||
|
|
||||||
class RunnableBinding(Serializable, Runnable[Input, Output]):
|
class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||||
"""
|
"""
|
||||||
A runnable that delegates calls to another runnable with a set of kwargs.
|
A runnable that delegates calls to another runnable with a set of kwargs.
|
||||||
@ -1279,6 +1245,10 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
def lc_serializable(self) -> bool:
|
def lc_serializable(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_namespace(self) -> List[str]:
|
||||||
|
return self.__class__.__module__.split(".")[:-1]
|
||||||
|
|
||||||
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||||
return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs})
|
return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs})
|
||||||
|
|
||||||
@ -1335,160 +1305,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
yield item
|
yield item
|
||||||
|
|
||||||
|
|
||||||
class RouterInput(TypedDict):
|
def patch_config(
|
||||||
"""A Router input.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
key: The key to route on.
|
|
||||||
input: The input to pass to the selected runnable.
|
|
||||||
"""
|
|
||||||
|
|
||||||
key: str
|
|
||||||
input: Any
|
|
||||||
|
|
||||||
|
|
||||||
class RouterRunnable(
|
|
||||||
Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
A runnable that routes to a set of runnables based on Input['key'].
|
|
||||||
Returns the output of the selected runnable.
|
|
||||||
"""
|
|
||||||
|
|
||||||
runnables: Mapping[str, Runnable[Input, Output]]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
runnables: Mapping[
|
|
||||||
str, Union[Runnable[Input, Output], Callable[[Input], Output]]
|
|
||||||
],
|
|
||||||
) -> None:
|
|
||||||
super().__init__(
|
|
||||||
runnables={key: _coerce_to_runnable(r) for key, r in runnables.items()}
|
|
||||||
)
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def lc_serializable(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def __or__(
|
|
||||||
self,
|
|
||||||
other: Union[
|
|
||||||
Runnable[Any, Other],
|
|
||||||
Callable[[Any], Other],
|
|
||||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
|
||||||
Mapping[str, Any],
|
|
||||||
],
|
|
||||||
) -> RunnableSequence[RouterInput, Other]:
|
|
||||||
return RunnableSequence(first=self, last=_coerce_to_runnable(other))
|
|
||||||
|
|
||||||
def __ror__(
|
|
||||||
self,
|
|
||||||
other: Union[
|
|
||||||
Runnable[Other, Any],
|
|
||||||
Callable[[Any], Other],
|
|
||||||
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
|
||||||
Mapping[str, Any],
|
|
||||||
],
|
|
||||||
) -> RunnableSequence[Other, Output]:
|
|
||||||
return RunnableSequence(first=_coerce_to_runnable(other), last=self)
|
|
||||||
|
|
||||||
def invoke(
|
|
||||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
|
||||||
) -> Output:
|
|
||||||
key = input["key"]
|
|
||||||
actual_input = input["input"]
|
|
||||||
if key not in self.runnables:
|
|
||||||
raise ValueError(f"No runnable associated with key '{key}'")
|
|
||||||
|
|
||||||
runnable = self.runnables[key]
|
|
||||||
return runnable.invoke(actual_input, config)
|
|
||||||
|
|
||||||
async def ainvoke(
|
|
||||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
|
||||||
) -> Output:
|
|
||||||
key = input["key"]
|
|
||||||
actual_input = input["input"]
|
|
||||||
if key not in self.runnables:
|
|
||||||
raise ValueError(f"No runnable associated with key '{key}'")
|
|
||||||
|
|
||||||
runnable = self.runnables[key]
|
|
||||||
return await runnable.ainvoke(actual_input, config)
|
|
||||||
|
|
||||||
def batch(
|
|
||||||
self,
|
|
||||||
inputs: List[RouterInput],
|
|
||||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
|
||||||
*,
|
|
||||||
max_concurrency: Optional[int] = None,
|
|
||||||
) -> List[Output]:
|
|
||||||
keys = [input["key"] for input in inputs]
|
|
||||||
actual_inputs = [input["input"] for input in inputs]
|
|
||||||
if any(key not in self.runnables for key in keys):
|
|
||||||
raise ValueError("One or more keys do not have a corresponding runnable")
|
|
||||||
|
|
||||||
runnables = [self.runnables[key] for key in keys]
|
|
||||||
configs = self._get_config_list(config, len(inputs))
|
|
||||||
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
|
||||||
return list(
|
|
||||||
executor.map(
|
|
||||||
lambda runnable, input, config: runnable.invoke(input, config),
|
|
||||||
runnables,
|
|
||||||
actual_inputs,
|
|
||||||
configs,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def abatch(
|
|
||||||
self,
|
|
||||||
inputs: List[RouterInput],
|
|
||||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
|
||||||
*,
|
|
||||||
max_concurrency: Optional[int] = None,
|
|
||||||
) -> List[Output]:
|
|
||||||
keys = [input["key"] for input in inputs]
|
|
||||||
actual_inputs = [input["input"] for input in inputs]
|
|
||||||
if any(key not in self.runnables for key in keys):
|
|
||||||
raise ValueError("One or more keys do not have a corresponding runnable")
|
|
||||||
|
|
||||||
runnables = [self.runnables[key] for key in keys]
|
|
||||||
configs = self._get_config_list(config, len(inputs))
|
|
||||||
return await _gather_with_concurrency(
|
|
||||||
max_concurrency,
|
|
||||||
*(
|
|
||||||
runnable.ainvoke(input, config)
|
|
||||||
for runnable, input, config in zip(runnables, actual_inputs, configs)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def stream(
|
|
||||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
|
||||||
) -> Iterator[Output]:
|
|
||||||
key = input["key"]
|
|
||||||
actual_input = input["input"]
|
|
||||||
if key not in self.runnables:
|
|
||||||
raise ValueError(f"No runnable associated with key '{key}'")
|
|
||||||
|
|
||||||
runnable = self.runnables[key]
|
|
||||||
yield from runnable.stream(actual_input, config)
|
|
||||||
|
|
||||||
async def astream(
|
|
||||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
|
||||||
) -> AsyncIterator[Output]:
|
|
||||||
key = input["key"]
|
|
||||||
actual_input = input["input"]
|
|
||||||
if key not in self.runnables:
|
|
||||||
raise ValueError(f"No runnable associated with key '{key}'")
|
|
||||||
|
|
||||||
runnable = self.runnables[key]
|
|
||||||
async for output in runnable.astream(actual_input, config):
|
|
||||||
yield output
|
|
||||||
|
|
||||||
|
|
||||||
def _patch_config(
|
|
||||||
config: RunnableConfig, callback_manager: BaseCallbackManager
|
config: RunnableConfig, callback_manager: BaseCallbackManager
|
||||||
) -> RunnableConfig:
|
) -> RunnableConfig:
|
||||||
config = config.copy()
|
config = config.copy()
|
||||||
@ -1496,7 +1313,7 @@ def _patch_config(
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def _coerce_to_runnable(
|
def coerce_to_runnable(
|
||||||
thing: Union[
|
thing: Union[
|
||||||
Runnable[Input, Output],
|
Runnable[Input, Output],
|
||||||
Callable[[Input], Output],
|
Callable[[Input], Output],
|
||||||
@ -1508,7 +1325,7 @@ def _coerce_to_runnable(
|
|||||||
elif callable(thing):
|
elif callable(thing):
|
||||||
return RunnableLambda(thing)
|
return RunnableLambda(thing)
|
||||||
elif isinstance(thing, dict):
|
elif isinstance(thing, dict):
|
||||||
runnables = {key: _coerce_to_runnable(r) for key, r in thing.items()}
|
runnables = {key: coerce_to_runnable(r) for key, r in thing.items()}
|
||||||
return cast(Runnable[Input, Output], RunnableMap(steps=runnables))
|
return cast(Runnable[Input, Output], RunnableMap(steps=runnables))
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
27
libs/langchain/langchain/schema/runnable/config.py
Normal file
27
libs/langchain/langchain/schema/runnable/config.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, TypedDict
|
||||||
|
|
||||||
|
from langchain.callbacks.base import Callbacks
|
||||||
|
|
||||||
|
|
||||||
|
class RunnableConfig(TypedDict, total=False):
|
||||||
|
"""Configuration for a Runnable."""
|
||||||
|
|
||||||
|
tags: List[str]
|
||||||
|
"""
|
||||||
|
Tags for this call and any sub-calls (eg. a Chain calling an LLM).
|
||||||
|
You can use these to filter calls.
|
||||||
|
"""
|
||||||
|
|
||||||
|
metadata: Dict[str, Any]
|
||||||
|
"""
|
||||||
|
Metadata for this call and any sub-calls (eg. a Chain calling an LLM).
|
||||||
|
Keys should be strings, values should be JSON-serializable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
callbacks: Callbacks
|
||||||
|
"""
|
||||||
|
Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
|
||||||
|
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
|
||||||
|
"""
|
23
libs/langchain/langchain/schema/runnable/passthrough.py
Normal file
23
libs/langchain/langchain/schema/runnable/passthrough.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from langchain.load.serializable import Serializable
|
||||||
|
from langchain.schema.runnable.base import Input, Runnable, RunnableConfig
|
||||||
|
|
||||||
|
|
||||||
|
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
||||||
|
"""
|
||||||
|
A runnable that passes through the input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_serializable(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_namespace(self) -> List[str]:
|
||||||
|
return self.__class__.__module__.split(".")[:-1]
|
||||||
|
|
||||||
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
||||||
|
return self._call_with_config(lambda x: x, input, config)
|
184
libs/langchain/langchain/schema/runnable/router.py
Normal file
184
libs/langchain/langchain/schema/runnable/router.py
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
AsyncIterator,
|
||||||
|
Callable,
|
||||||
|
Generic,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
TypedDict,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.load.serializable import Serializable
|
||||||
|
from langchain.schema.runnable.base import (
|
||||||
|
Input,
|
||||||
|
Other,
|
||||||
|
Output,
|
||||||
|
Runnable,
|
||||||
|
RunnableSequence,
|
||||||
|
coerce_to_runnable,
|
||||||
|
)
|
||||||
|
from langchain.schema.runnable.config import RunnableConfig
|
||||||
|
from langchain.schema.runnable.utils import gather_with_concurrency
|
||||||
|
|
||||||
|
|
||||||
|
class RouterInput(TypedDict):
|
||||||
|
"""A Router input.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
key: The key to route on.
|
||||||
|
input: The input to pass to the selected runnable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
key: str
|
||||||
|
input: Any
|
||||||
|
|
||||||
|
|
||||||
|
class RouterRunnable(
|
||||||
|
Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
A runnable that routes to a set of runnables based on Input['key'].
|
||||||
|
Returns the output of the selected runnable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
runnables: Mapping[str, Runnable[Input, Output]]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
runnables: Mapping[
|
||||||
|
str, Union[Runnable[Input, Output], Callable[[Input], Output]]
|
||||||
|
],
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
|
||||||
|
)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_serializable(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_namespace(self) -> List[str]:
|
||||||
|
return self.__class__.__module__.split(".")[:-1]
|
||||||
|
|
||||||
|
def __or__(
|
||||||
|
self,
|
||||||
|
other: Union[
|
||||||
|
Runnable[Any, Other],
|
||||||
|
Callable[[Any], Other],
|
||||||
|
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
||||||
|
Mapping[str, Any],
|
||||||
|
],
|
||||||
|
) -> RunnableSequence[RouterInput, Other]:
|
||||||
|
return RunnableSequence(first=self, last=coerce_to_runnable(other))
|
||||||
|
|
||||||
|
def __ror__(
|
||||||
|
self,
|
||||||
|
other: Union[
|
||||||
|
Runnable[Other, Any],
|
||||||
|
Callable[[Any], Other],
|
||||||
|
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
||||||
|
Mapping[str, Any],
|
||||||
|
],
|
||||||
|
) -> RunnableSequence[Other, Output]:
|
||||||
|
return RunnableSequence(first=coerce_to_runnable(other), last=self)
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||||
|
) -> Output:
|
||||||
|
key = input["key"]
|
||||||
|
actual_input = input["input"]
|
||||||
|
if key not in self.runnables:
|
||||||
|
raise ValueError(f"No runnable associated with key '{key}'")
|
||||||
|
|
||||||
|
runnable = self.runnables[key]
|
||||||
|
return runnable.invoke(actual_input, config)
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||||
|
) -> Output:
|
||||||
|
key = input["key"]
|
||||||
|
actual_input = input["input"]
|
||||||
|
if key not in self.runnables:
|
||||||
|
raise ValueError(f"No runnable associated with key '{key}'")
|
||||||
|
|
||||||
|
runnable = self.runnables[key]
|
||||||
|
return await runnable.ainvoke(actual_input, config)
|
||||||
|
|
||||||
|
def batch(
|
||||||
|
self,
|
||||||
|
inputs: List[RouterInput],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
max_concurrency: Optional[int] = None,
|
||||||
|
) -> List[Output]:
|
||||||
|
keys = [input["key"] for input in inputs]
|
||||||
|
actual_inputs = [input["input"] for input in inputs]
|
||||||
|
if any(key not in self.runnables for key in keys):
|
||||||
|
raise ValueError("One or more keys do not have a corresponding runnable")
|
||||||
|
|
||||||
|
runnables = [self.runnables[key] for key in keys]
|
||||||
|
configs = self._get_config_list(config, len(inputs))
|
||||||
|
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
||||||
|
return list(
|
||||||
|
executor.map(
|
||||||
|
lambda runnable, input, config: runnable.invoke(input, config),
|
||||||
|
runnables,
|
||||||
|
actual_inputs,
|
||||||
|
configs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def abatch(
|
||||||
|
self,
|
||||||
|
inputs: List[RouterInput],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
max_concurrency: Optional[int] = None,
|
||||||
|
) -> List[Output]:
|
||||||
|
keys = [input["key"] for input in inputs]
|
||||||
|
actual_inputs = [input["input"] for input in inputs]
|
||||||
|
if any(key not in self.runnables for key in keys):
|
||||||
|
raise ValueError("One or more keys do not have a corresponding runnable")
|
||||||
|
|
||||||
|
runnables = [self.runnables[key] for key in keys]
|
||||||
|
configs = self._get_config_list(config, len(inputs))
|
||||||
|
return await gather_with_concurrency(
|
||||||
|
max_concurrency,
|
||||||
|
*(
|
||||||
|
runnable.ainvoke(input, config)
|
||||||
|
for runnable, input, config in zip(runnables, actual_inputs, configs)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def stream(
|
||||||
|
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||||
|
) -> Iterator[Output]:
|
||||||
|
key = input["key"]
|
||||||
|
actual_input = input["input"]
|
||||||
|
if key not in self.runnables:
|
||||||
|
raise ValueError(f"No runnable associated with key '{key}'")
|
||||||
|
|
||||||
|
runnable = self.runnables[key]
|
||||||
|
yield from runnable.stream(actual_input, config)
|
||||||
|
|
||||||
|
async def astream(
|
||||||
|
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||||
|
) -> AsyncIterator[Output]:
|
||||||
|
key = input["key"]
|
||||||
|
actual_input = input["input"]
|
||||||
|
if key not in self.runnables:
|
||||||
|
raise ValueError(f"No runnable associated with key '{key}'")
|
||||||
|
|
||||||
|
runnable = self.runnables[key]
|
||||||
|
async for output in runnable.astream(actual_input, config):
|
||||||
|
yield output
|
18
libs/langchain/langchain/schema/runnable/utils.py
Normal file
18
libs/langchain/langchain/schema/runnable/utils.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Any, Coroutine, Union
|
||||||
|
|
||||||
|
|
||||||
|
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
|
||||||
|
async with semaphore:
|
||||||
|
return await coro
|
||||||
|
|
||||||
|
|
||||||
|
async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
|
||||||
|
if n is None:
|
||||||
|
return await asyncio.gather(*coros)
|
||||||
|
|
||||||
|
semaphore = asyncio.Semaphore(n)
|
||||||
|
|
||||||
|
return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros))
|
File diff suppressed because one or more lines are too long
@ -839,7 +839,7 @@ def llm_chain_with_fallbacks() -> RunnableSequence:
|
|||||||
)
|
)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_llm_with_fallbacks(
|
async def test_llm_with_fallbacks(
|
||||||
runnable: RunnableWithFallbacks, request: Any
|
runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion
|
||||||
) -> None:
|
) -> None:
|
||||||
runnable = request.getfixturevalue(runnable)
|
runnable = request.getfixturevalue(runnable)
|
||||||
assert runnable.invoke("hello") == "bar"
|
assert runnable.invoke("hello") == "bar"
|
||||||
@ -848,3 +848,4 @@ async def test_llm_with_fallbacks(
|
|||||||
assert await runnable.ainvoke("hello") == "bar"
|
assert await runnable.ainvoke("hello") == "bar"
|
||||||
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
|
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
|
||||||
assert list(await runnable.ainvoke("hello")) == list("bar")
|
assert list(await runnable.ainvoke("hello")) == list("bar")
|
||||||
|
assert dumps(runnable, pretty=True) == snapshot
|
||||||
|
Loading…
Reference in New Issue
Block a user