Allow config propagation, Add default lambda name, Improve ergonomics of config passed in (#10273)

Makes it easier to do recursion using regular python compositional
patterns

```py
def lambda_decorator(func):
    """Decorate function as a RunnableLambda"""
    return runnable.RunnableLambda(func)

@lambda_decorator
def fibonacci(a, config: runnable.RunnableConfig) -> int:
    if a <= 1:
        return a
    else:
        return fibonacci.invoke(
            a - 1, config
        ) + fibonacci.invoke(a - 2, config)

fibonacci.invoke(10)
```

https://smith.langchain.com/public/cb98edb4-3a09-4798-9c22-a930037faf88/r

Also makes it more natural to do things like error handle and call other
langchain objects in ways we probably don't want to support in
`with_fallbacks()`

```py
@lambda_decorator
def handle_errors(a, config: runnable.RunnableConfig) -> int:
    try:
        return my_chain.invoke(a, config)
    except MyExceptionType as exc:
        return my_other_chain.invoke({"original": a, "error": exc}, config)
```

In this case, the next chain takes in the exception object. Maybe this
could be something we toggle in `with_fallbacks` but I fear we'll get
into uglier APIs + heavier cognitive load if we try to do too much there

---------

Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
William FH 2023-09-06 05:54:38 -07:00 committed by GitHub
parent c732d8fffd
commit ffca5e7eea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 151 additions and 91 deletions

View File

@ -39,6 +39,8 @@ from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field from langchain.pydantic_v1 import Field
from langchain.schema.runnable.config import ( from langchain.schema.runnable.config import (
RunnableConfig, RunnableConfig,
acall_func_with_variable_args,
call_func_with_variable_args,
ensure_config, ensure_config,
get_async_callback_manager_for_config, get_async_callback_manager_for_config,
get_callback_manager_for_config, get_callback_manager_for_config,
@ -47,16 +49,15 @@ from langchain.schema.runnable.config import (
patch_config, patch_config,
) )
from langchain.schema.runnable.utils import ( from langchain.schema.runnable.utils import (
Input,
Output,
accepts_config,
accepts_run_manager, accepts_run_manager,
accepts_run_manager_and_config,
gather_with_concurrency, gather_with_concurrency,
) )
from langchain.utils.aiter import atee, py_anext from langchain.utils.aiter import atee, py_anext
from langchain.utils.iter import safetee from langchain.utils.iter import safetee
Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do
Output = TypeVar("Output")
Other = TypeVar("Other") Other = TypeVar("Other")
@ -311,16 +312,7 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"), name=config.get("run_name"),
) )
try: try:
if accepts_run_manager_and_config(func): output = call_func_with_variable_args(func, input, run_manager, config)
output = func(
input,
run_manager=run_manager,
config=config,
) # type: ignore[call-arg]
elif accepts_run_manager(func):
output = func(input, run_manager=run_manager) # type: ignore[call-arg]
else:
output = func(input) # type: ignore[call-arg]
except Exception as e: except Exception as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
raise raise
@ -353,19 +345,9 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"), name=config.get("run_name"),
) )
try: try:
if accepts_run_manager_and_config(func): output = await acall_func_with_variable_args(
output = await func( func, input, run_manager, config
input, )
run_manager=run_manager,
config=config,
) # type: ignore[call-arg]
elif accepts_run_manager(func):
output = await func(
input,
run_manager=run_manager,
) # type: ignore[call-arg]
else:
output = await func(input) # type: ignore[call-arg]
except Exception as e: except Exception as e:
await run_manager.on_chain_error(e) await run_manager.on_chain_error(e)
raise raise
@ -408,16 +390,15 @@ class Runnable(Generic[Input, Output], ABC):
) )
] ]
try: try:
if accepts_run_manager_and_config(func): kwargs: Dict[str, Any] = {}
output = func( if accepts_config(func):
input, kwargs["config"] = [
run_manager=run_managers, patch_config(c, callbacks=rm.get_child())
config=configs, for c, rm in zip(configs, run_managers)
) # type: ignore[call-arg] ]
elif accepts_run_manager(func): if accepts_run_manager(func):
output = func(input, run_manager=run_managers) # type: ignore[call-arg] kwargs["run_manager"] = run_managers
else: output = func(input, **kwargs) # type: ignore[call-arg]
output = func(input) # type: ignore[call-arg]
except Exception as e: except Exception as e:
for run_manager in run_managers: for run_manager in run_managers:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
@ -479,16 +460,15 @@ class Runnable(Generic[Input, Output], ABC):
) )
) )
try: try:
if accepts_run_manager_and_config(func): kwargs: Dict[str, Any] = {}
output = await func( if accepts_config(func):
input, kwargs["config"] = [
run_manager=run_managers, patch_config(c, callbacks=rm.get_child())
config=configs, for c, rm in zip(configs, run_managers)
) # type: ignore[call-arg] ]
elif accepts_run_manager(func): if accepts_run_manager(func):
output = await func(input, run_manager=run_managers) # type: ignore kwargs["run_manager"] = run_managers
else: output = await func(input, **kwargs) # type: ignore[call-arg]
output = await func(input) # type: ignore[call-arg]
except Exception as e: except Exception as e:
await asyncio.gather( await asyncio.gather(
*(run_manager.on_chain_error(e) for run_manager in run_managers) *(run_manager.on_chain_error(e) for run_manager in run_managers)
@ -550,19 +530,14 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"), name=config.get("run_name"),
) )
try: try:
if accepts_run_manager_and_config(transformer): kwargs: Dict[str, Any] = {}
iterator = transformer( if accepts_config(transformer):
input_for_transform, kwargs["config"] = patch_config(
run_manager=run_manager, config, callbacks=run_manager.get_child()
config=config, )
) # type: ignore[call-arg] if accepts_run_manager(transformer):
elif accepts_run_manager(transformer): kwargs["run_manager"] = run_manager
iterator = transformer( iterator = transformer(input_for_transform, **kwargs) # type: ignore[call-arg]
input_for_transform,
run_manager=run_manager,
) # type: ignore[call-arg]
else:
iterator = transformer(input_for_transform) # type: ignore[call-arg]
for chunk in iterator: for chunk in iterator:
yield chunk yield chunk
if final_output_supported: if final_output_supported:
@ -631,21 +606,14 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"), name=config.get("run_name"),
) )
try: try:
# mypy can't quite work out thew type guard here, but this is safe, kwargs: Dict[str, Any] = {}
# check implementations of the accepts_* functions if accepts_config(transformer):
if accepts_run_manager_and_config(transformer): kwargs["config"] = patch_config(
iterator = transformer( config, callbacks=run_manager.get_child()
input_for_transform, )
run_manager=run_manager, if accepts_run_manager(transformer):
config=config, kwargs["run_manager"] = run_manager
) # type: ignore[call-arg] iterator = transformer(input_for_transform, **kwargs) # type: ignore[call-arg]
elif accepts_run_manager(transformer):
iterator = transformer(
input_for_transform,
run_manager=run_manager,
) # type: ignore[call-arg]
else:
iterator = transformer(input_for_transform) # type: ignore[call-arg]
async for chunk in iterator: async for chunk in iterator:
yield chunk yield chunk
if final_output_supported: if final_output_supported:
@ -1756,7 +1724,7 @@ class RunnableLambda(Runnable[Input, Output]):
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
) -> Output: ) -> Output:
output = self.func(input) output = call_func_with_variable_args(self.func, input, run_manager, config)
# If the output is a runnable, invoke it # If the output is a runnable, invoke it
if isinstance(output, Runnable): if isinstance(output, Runnable):
recursion_limit = config["recursion_limit"] recursion_limit = config["recursion_limit"]
@ -1780,7 +1748,9 @@ class RunnableLambda(Runnable[Input, Output]):
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
) -> Output: ) -> Output:
output = await self.afunc(input) output = await acall_func_with_variable_args(
self.afunc, input, run_manager, config
)
# If the output is a runnable, invoke it # If the output is a runnable, invoke it
if isinstance(output, Runnable): if isinstance(output, Runnable):
recursion_limit = config["recursion_limit"] recursion_limit = config["recursion_limit"]
@ -1798,6 +1768,21 @@ class RunnableLambda(Runnable[Input, Output]):
) )
return output return output
def _config(
self, config: Optional[RunnableConfig], callable: Callable[..., Any]
) -> RunnableConfig:
config = config or {}
if config.get("run_name") is None:
try:
run_name = callable.__name__
except AttributeError:
run_name = None
if run_name is not None:
return patch_config(config, run_name=run_name)
return config
def invoke( def invoke(
self, self,
input: Input, input: Input,
@ -1805,7 +1790,11 @@ class RunnableLambda(Runnable[Input, Output]):
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Output: ) -> Output:
if hasattr(self, "func"): if hasattr(self, "func"):
return self._call_with_config(self._invoke, input, config) return self._call_with_config(
self._invoke,
input,
self._config(config, self.func),
)
else: else:
raise TypeError( raise TypeError(
"Cannot invoke a coroutine function synchronously." "Cannot invoke a coroutine function synchronously."
@ -1819,7 +1808,11 @@ class RunnableLambda(Runnable[Input, Output]):
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Output: ) -> Output:
if hasattr(self, "afunc"): if hasattr(self, "afunc"):
return await self._acall_with_config(self._ainvoke, input, config) return await self._acall_with_config(
self._ainvoke,
input,
self._config(config, self.afunc),
)
else: else:
# Delegating to super implementation of ainvoke. # Delegating to super implementation of ainvoke.
# Uses asyncio executor to run the sync version (invoke) # Uses asyncio executor to run the sync version (invoke)

View File

@ -3,13 +3,35 @@ from __future__ import annotations
from concurrent.futures import Executor, ThreadPoolExecutor from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Generator,
List,
Optional,
Union,
)
from typing_extensions import TypedDict from typing_extensions import TypedDict
from langchain.schema.runnable.utils import (
Input,
Output,
accepts_config,
accepts_run_manager,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain.callbacks.base import BaseCallbackManager, Callbacks from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
CallbackManager,
CallbackManagerForChainRun,
)
class RunnableConfig(TypedDict, total=False): class RunnableConfig(TypedDict, total=False):
@ -117,6 +139,47 @@ def patch_config(
return config return config
def call_func_with_variable_args(
func: Union[
Callable[[Input], Output],
Callable[[Input, CallbackManagerForChainRun], Output],
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
],
input: Input,
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Output:
"""Call function that may optionally accept a run_manager and/or config."""
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
if accepts_run_manager(func):
kwargs["run_manager"] = run_manager
return func(input, **kwargs) # type: ignore[call-arg]
async def acall_func_with_variable_args(
func: Union[
Callable[[Input], Awaitable[Output]],
Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]],
Callable[
[Input, AsyncCallbackManagerForChainRun, RunnableConfig],
Awaitable[Output],
],
],
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> Output:
"""Call function that may optionally accept a run_manager and/or config."""
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
if accepts_run_manager(func):
kwargs["run_manager"] = run_manager
return await func(input, **kwargs) # type: ignore[call-arg]
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager: def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
from langchain.callbacks.manager import CallbackManager from langchain.callbacks.manager import CallbackManager

View File

@ -2,7 +2,11 @@ from __future__ import annotations
import asyncio import asyncio
from inspect import signature from inspect import signature
from typing import Any, Callable, Coroutine, Union from typing import Any, Callable, Coroutine, TypeVar, Union
Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do
Output = TypeVar("Output")
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any: async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
@ -26,8 +30,8 @@ def accepts_run_manager(callable: Callable[..., Any]) -> bool:
return False return False
def accepts_run_manager_and_config(callable: Callable[..., Any]) -> bool: def accepts_config(callable: Callable[..., Any]) -> bool:
return ( try:
accepts_run_manager(callable) return signature(callable).parameters.get("config") is not None
and signature(callable).parameters.get("config") is not None except ValueError:
) return False

File diff suppressed because one or more lines are too long

View File

@ -948,7 +948,7 @@ async def test_higher_order_lambda_runnable(
parent_run = next(r for r in tracer.runs if r.parent_run_id is None) parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 2 assert len(parent_run.child_runs) == 2
router_run = parent_run.child_runs[1] router_run = parent_run.child_runs[1]
assert router_run.name == "RunnableLambda" assert router_run.name == "router"
assert len(router_run.child_runs) == 1 assert len(router_run.child_runs) == 1
math_run = router_run.child_runs[0] math_run = router_run.child_runs[0]
assert math_run.name == "RunnableSequence" assert math_run.name == "RunnableSequence"
@ -980,7 +980,7 @@ async def test_higher_order_lambda_runnable(
parent_run = next(r for r in tracer.runs if r.parent_run_id is None) parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 2 assert len(parent_run.child_runs) == 2
router_run = parent_run.child_runs[1] router_run = parent_run.child_runs[1]
assert router_run.name == "RunnableLambda" assert router_run.name == "arouter"
assert len(router_run.child_runs) == 1 assert len(router_run.child_runs) == 1
math_run = router_run.child_runs[0] math_run = router_run.child_runs[0]
assert math_run.name == "RunnableSequence" assert math_run.name == "RunnableSequence"