mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 03:01:29 +00:00
Allow config propagation, Add default lambda name, Improve ergonomics of config passed in (#10273)
Makes it easier to do recursion using regular python compositional patterns ```py def lambda_decorator(func): """Decorate function as a RunnableLambda""" return runnable.RunnableLambda(func) @lambda_decorator def fibonacci(a, config: runnable.RunnableConfig) -> int: if a <= 1: return a else: return fibonacci.invoke( a - 1, config ) + fibonacci.invoke(a - 2, config) fibonacci.invoke(10) ``` https://smith.langchain.com/public/cb98edb4-3a09-4798-9c22-a930037faf88/r Also makes it more natural to do things like error handle and call other langchain objects in ways we probably don't want to support in `with_fallbacks()` ```py @lambda_decorator def handle_errors(a, config: runnable.RunnableConfig) -> int: try: return my_chain.invoke(a, config) except MyExceptionType as exc: return my_other_chain.invoke({"original": a, "error": exc}, config) ``` In this case, the next chain takes in the exception object. Maybe this could be something we toggle in `with_fallbacks` but I fear we'll get into uglier APIs + heavier cognitive load if we try to do too much there --------- Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
parent
c732d8fffd
commit
ffca5e7eea
@ -39,6 +39,8 @@ from langchain.load.serializable import Serializable
|
||||
from langchain.pydantic_v1 import Field
|
||||
from langchain.schema.runnable.config import (
|
||||
RunnableConfig,
|
||||
acall_func_with_variable_args,
|
||||
call_func_with_variable_args,
|
||||
ensure_config,
|
||||
get_async_callback_manager_for_config,
|
||||
get_callback_manager_for_config,
|
||||
@ -47,16 +49,15 @@ from langchain.schema.runnable.config import (
|
||||
patch_config,
|
||||
)
|
||||
from langchain.schema.runnable.utils import (
|
||||
Input,
|
||||
Output,
|
||||
accepts_config,
|
||||
accepts_run_manager,
|
||||
accepts_run_manager_and_config,
|
||||
gather_with_concurrency,
|
||||
)
|
||||
from langchain.utils.aiter import atee, py_anext
|
||||
from langchain.utils.iter import safetee
|
||||
|
||||
Input = TypeVar("Input")
|
||||
# Output type should implement __concat__, as eg str, list, dict do
|
||||
Output = TypeVar("Output")
|
||||
Other = TypeVar("Other")
|
||||
|
||||
|
||||
@ -311,16 +312,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
if accepts_run_manager_and_config(func):
|
||||
output = func(
|
||||
input,
|
||||
run_manager=run_manager,
|
||||
config=config,
|
||||
) # type: ignore[call-arg]
|
||||
elif accepts_run_manager(func):
|
||||
output = func(input, run_manager=run_manager) # type: ignore[call-arg]
|
||||
else:
|
||||
output = func(input) # type: ignore[call-arg]
|
||||
output = call_func_with_variable_args(func, input, run_manager, config)
|
||||
except Exception as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
@ -353,19 +345,9 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
if accepts_run_manager_and_config(func):
|
||||
output = await func(
|
||||
input,
|
||||
run_manager=run_manager,
|
||||
config=config,
|
||||
) # type: ignore[call-arg]
|
||||
elif accepts_run_manager(func):
|
||||
output = await func(
|
||||
input,
|
||||
run_manager=run_manager,
|
||||
) # type: ignore[call-arg]
|
||||
else:
|
||||
output = await func(input) # type: ignore[call-arg]
|
||||
output = await acall_func_with_variable_args(
|
||||
func, input, run_manager, config
|
||||
)
|
||||
except Exception as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
@ -408,16 +390,15 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
)
|
||||
]
|
||||
try:
|
||||
if accepts_run_manager_and_config(func):
|
||||
output = func(
|
||||
input,
|
||||
run_manager=run_managers,
|
||||
config=configs,
|
||||
) # type: ignore[call-arg]
|
||||
elif accepts_run_manager(func):
|
||||
output = func(input, run_manager=run_managers) # type: ignore[call-arg]
|
||||
else:
|
||||
output = func(input) # type: ignore[call-arg]
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if accepts_config(func):
|
||||
kwargs["config"] = [
|
||||
patch_config(c, callbacks=rm.get_child())
|
||||
for c, rm in zip(configs, run_managers)
|
||||
]
|
||||
if accepts_run_manager(func):
|
||||
kwargs["run_manager"] = run_managers
|
||||
output = func(input, **kwargs) # type: ignore[call-arg]
|
||||
except Exception as e:
|
||||
for run_manager in run_managers:
|
||||
run_manager.on_chain_error(e)
|
||||
@ -479,16 +460,15 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
)
|
||||
)
|
||||
try:
|
||||
if accepts_run_manager_and_config(func):
|
||||
output = await func(
|
||||
input,
|
||||
run_manager=run_managers,
|
||||
config=configs,
|
||||
) # type: ignore[call-arg]
|
||||
elif accepts_run_manager(func):
|
||||
output = await func(input, run_manager=run_managers) # type: ignore
|
||||
else:
|
||||
output = await func(input) # type: ignore[call-arg]
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if accepts_config(func):
|
||||
kwargs["config"] = [
|
||||
patch_config(c, callbacks=rm.get_child())
|
||||
for c, rm in zip(configs, run_managers)
|
||||
]
|
||||
if accepts_run_manager(func):
|
||||
kwargs["run_manager"] = run_managers
|
||||
output = await func(input, **kwargs) # type: ignore[call-arg]
|
||||
except Exception as e:
|
||||
await asyncio.gather(
|
||||
*(run_manager.on_chain_error(e) for run_manager in run_managers)
|
||||
@ -550,19 +530,14 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
if accepts_run_manager_and_config(transformer):
|
||||
iterator = transformer(
|
||||
input_for_transform,
|
||||
run_manager=run_manager,
|
||||
config=config,
|
||||
) # type: ignore[call-arg]
|
||||
elif accepts_run_manager(transformer):
|
||||
iterator = transformer(
|
||||
input_for_transform,
|
||||
run_manager=run_manager,
|
||||
) # type: ignore[call-arg]
|
||||
else:
|
||||
iterator = transformer(input_for_transform) # type: ignore[call-arg]
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if accepts_config(transformer):
|
||||
kwargs["config"] = patch_config(
|
||||
config, callbacks=run_manager.get_child()
|
||||
)
|
||||
if accepts_run_manager(transformer):
|
||||
kwargs["run_manager"] = run_manager
|
||||
iterator = transformer(input_for_transform, **kwargs) # type: ignore[call-arg]
|
||||
for chunk in iterator:
|
||||
yield chunk
|
||||
if final_output_supported:
|
||||
@ -631,21 +606,14 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
try:
|
||||
# mypy can't quite work out thew type guard here, but this is safe,
|
||||
# check implementations of the accepts_* functions
|
||||
if accepts_run_manager_and_config(transformer):
|
||||
iterator = transformer(
|
||||
input_for_transform,
|
||||
run_manager=run_manager,
|
||||
config=config,
|
||||
) # type: ignore[call-arg]
|
||||
elif accepts_run_manager(transformer):
|
||||
iterator = transformer(
|
||||
input_for_transform,
|
||||
run_manager=run_manager,
|
||||
) # type: ignore[call-arg]
|
||||
else:
|
||||
iterator = transformer(input_for_transform) # type: ignore[call-arg]
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if accepts_config(transformer):
|
||||
kwargs["config"] = patch_config(
|
||||
config, callbacks=run_manager.get_child()
|
||||
)
|
||||
if accepts_run_manager(transformer):
|
||||
kwargs["run_manager"] = run_manager
|
||||
iterator = transformer(input_for_transform, **kwargs) # type: ignore[call-arg]
|
||||
async for chunk in iterator:
|
||||
yield chunk
|
||||
if final_output_supported:
|
||||
@ -1756,7 +1724,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> Output:
|
||||
output = self.func(input)
|
||||
output = call_func_with_variable_args(self.func, input, run_manager, config)
|
||||
# If the output is a runnable, invoke it
|
||||
if isinstance(output, Runnable):
|
||||
recursion_limit = config["recursion_limit"]
|
||||
@ -1780,7 +1748,9 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> Output:
|
||||
output = await self.afunc(input)
|
||||
output = await acall_func_with_variable_args(
|
||||
self.afunc, input, run_manager, config
|
||||
)
|
||||
# If the output is a runnable, invoke it
|
||||
if isinstance(output, Runnable):
|
||||
recursion_limit = config["recursion_limit"]
|
||||
@ -1798,6 +1768,21 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
)
|
||||
return output
|
||||
|
||||
def _config(
|
||||
self, config: Optional[RunnableConfig], callable: Callable[..., Any]
|
||||
) -> RunnableConfig:
|
||||
config = config or {}
|
||||
|
||||
if config.get("run_name") is None:
|
||||
try:
|
||||
run_name = callable.__name__
|
||||
except AttributeError:
|
||||
run_name = None
|
||||
if run_name is not None:
|
||||
return patch_config(config, run_name=run_name)
|
||||
|
||||
return config
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Input,
|
||||
@ -1805,7 +1790,11 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
if hasattr(self, "func"):
|
||||
return self._call_with_config(self._invoke, input, config)
|
||||
return self._call_with_config(
|
||||
self._invoke,
|
||||
input,
|
||||
self._config(config, self.func),
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
"Cannot invoke a coroutine function synchronously."
|
||||
@ -1819,7 +1808,11 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
if hasattr(self, "afunc"):
|
||||
return await self._acall_with_config(self._ainvoke, input, config)
|
||||
return await self._acall_with_config(
|
||||
self._ainvoke,
|
||||
input,
|
||||
self._config(config, self.afunc),
|
||||
)
|
||||
else:
|
||||
# Delegating to super implementation of ainvoke.
|
||||
# Uses asyncio executor to run the sync version (invoke)
|
||||
|
@ -3,13 +3,35 @@ from __future__ import annotations
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain.schema.runnable.utils import (
|
||||
Input,
|
||||
Output,
|
||||
accepts_config,
|
||||
accepts_run_manager,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.callbacks.base import BaseCallbackManager, Callbacks
|
||||
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManager,
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
|
||||
|
||||
class RunnableConfig(TypedDict, total=False):
|
||||
@ -117,6 +139,47 @@ def patch_config(
|
||||
return config
|
||||
|
||||
|
||||
def call_func_with_variable_args(
|
||||
func: Union[
|
||||
Callable[[Input], Output],
|
||||
Callable[[Input, CallbackManagerForChainRun], Output],
|
||||
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
|
||||
],
|
||||
input: Input,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> Output:
|
||||
"""Call function that may optionally accept a run_manager and/or config."""
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if accepts_config(func):
|
||||
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
|
||||
if accepts_run_manager(func):
|
||||
kwargs["run_manager"] = run_manager
|
||||
return func(input, **kwargs) # type: ignore[call-arg]
|
||||
|
||||
|
||||
async def acall_func_with_variable_args(
|
||||
func: Union[
|
||||
Callable[[Input], Awaitable[Output]],
|
||||
Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]],
|
||||
Callable[
|
||||
[Input, AsyncCallbackManagerForChainRun, RunnableConfig],
|
||||
Awaitable[Output],
|
||||
],
|
||||
],
|
||||
input: Input,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> Output:
|
||||
"""Call function that may optionally accept a run_manager and/or config."""
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if accepts_config(func):
|
||||
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
|
||||
if accepts_run_manager(func):
|
||||
kwargs["run_manager"] = run_manager
|
||||
return await func(input, **kwargs) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
|
@ -2,7 +2,11 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from inspect import signature
|
||||
from typing import Any, Callable, Coroutine, Union
|
||||
from typing import Any, Callable, Coroutine, TypeVar, Union
|
||||
|
||||
Input = TypeVar("Input")
|
||||
# Output type should implement __concat__, as eg str, list, dict do
|
||||
Output = TypeVar("Output")
|
||||
|
||||
|
||||
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
|
||||
@ -26,8 +30,8 @@ def accepts_run_manager(callable: Callable[..., Any]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def accepts_run_manager_and_config(callable: Callable[..., Any]) -> bool:
|
||||
return (
|
||||
accepts_run_manager(callable)
|
||||
and signature(callable).parameters.get("config") is not None
|
||||
)
|
||||
def accepts_config(callable: Callable[..., Any]) -> bool:
|
||||
try:
|
||||
return signature(callable).parameters.get("config") is not None
|
||||
except ValueError:
|
||||
return False
|
||||
|
File diff suppressed because one or more lines are too long
@ -948,7 +948,7 @@ async def test_higher_order_lambda_runnable(
|
||||
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
|
||||
assert len(parent_run.child_runs) == 2
|
||||
router_run = parent_run.child_runs[1]
|
||||
assert router_run.name == "RunnableLambda"
|
||||
assert router_run.name == "router"
|
||||
assert len(router_run.child_runs) == 1
|
||||
math_run = router_run.child_runs[0]
|
||||
assert math_run.name == "RunnableSequence"
|
||||
@ -980,7 +980,7 @@ async def test_higher_order_lambda_runnable(
|
||||
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
|
||||
assert len(parent_run.child_runs) == 2
|
||||
router_run = parent_run.child_runs[1]
|
||||
assert router_run.name == "RunnableLambda"
|
||||
assert router_run.name == "arouter"
|
||||
assert len(router_run.child_runs) == 1
|
||||
math_run = router_run.child_runs[0]
|
||||
assert math_run.name == "RunnableSequence"
|
||||
|
Loading…
Reference in New Issue
Block a user