RunnableLambda, if func returns a Runnable, run it

This commit is contained in:
Nuno Campos 2023-08-17 15:04:36 +01:00
parent 677da6a0fd
commit 6d19709b65
4 changed files with 245 additions and 9 deletions

View File

@ -13,7 +13,6 @@ from typing import (
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Dict,
Generic,
Iterator,
@ -1347,8 +1346,8 @@ class RunnableLambda(Runnable[Input, Output]):
def __init__(
self,
func: Union[Callable[[Input], Output], Coroutine[Input, Any, Output]],
afunc: Optional[Coroutine[Input, Any, Output]] = None,
func: Union[Callable[[Input], Output], Callable[[Input], Awaitable[Output]]],
afunc: Optional[Callable[[Input], Awaitable[Output]]] = None,
) -> None:
if afunc is not None:
self.afunc = afunc
@ -1356,7 +1355,7 @@ class RunnableLambda(Runnable[Input, Output]):
if inspect.iscoroutinefunction(func):
self.afunc = func
elif callable(func):
self.func = func
self.func = cast(Callable[[Input], Output], func)
else:
raise TypeError(
"Expected a callable type for `func`."
@ -1374,6 +1373,54 @@ class RunnableLambda(Runnable[Input, Output]):
else:
return False
def _invoke(
self,
input: Input,
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Output:
output = self.func(input)
# If the output is a runnable, invoke it
if isinstance(output, Runnable):
recursion_limit = config["recursion_limit"]
if recursion_limit <= 0:
raise RecursionError(
f"Recursion limit reached when invoking {self} with input {input}."
)
output = output.invoke(
input,
patch_config(
config,
callbacks=run_manager.get_child(),
recursion_limit=recursion_limit - 1,
),
)
return output
async def _ainvoke(
self,
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> Output:
output = await self.afunc(input)
# If the output is a runnable, invoke it
if isinstance(output, Runnable):
recursion_limit = config["recursion_limit"]
if recursion_limit <= 0:
raise RecursionError(
f"Recursion limit reached when invoking {self} with input {input}."
)
output = await output.ainvoke(
input,
patch_config(
config,
callbacks=run_manager.get_child(),
recursion_limit=recursion_limit - 1,
),
)
return output
def invoke(
self,
input: Input,
@ -1381,7 +1428,7 @@ class RunnableLambda(Runnable[Input, Output]):
**kwargs: Optional[Any],
) -> Output:
if hasattr(self, "func"):
return self._call_with_config(self.func, input, config)
return self._call_with_config(self._invoke, input, config)
else:
raise TypeError(
"Cannot invoke a coroutine function synchronously."
@ -1395,7 +1442,7 @@ class RunnableLambda(Runnable[Input, Output]):
**kwargs: Optional[Any],
) -> Output:
if hasattr(self, "afunc"):
return await self._acall_with_config(self.afunc, input, config)
return await self._acall_with_config(self._ainvoke, input, config)
else:
return await super().ainvoke(input, config)

View File

@ -47,6 +47,11 @@ class RunnableConfig(TypedDict, total=False):
ThreadPoolExecutor will be created.
"""
recursion_limit: int
"""
Maximum number of times a call can recurse. If not provided, defaults to 10.
"""
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
empty = RunnableConfig(
@ -54,6 +59,7 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
metadata={},
callbacks=None,
_locals={},
recursion_limit=10,
)
if config is not None:
empty.update(config)
@ -66,6 +72,7 @@ def patch_config(
deep_copy_locals: bool = False,
callbacks: Optional[BaseCallbackManager] = None,
executor: Optional[Executor] = None,
recursion_limit: Optional[int] = None,
) -> RunnableConfig:
config = ensure_config(config)
if deep_copy_locals:
@ -74,6 +81,8 @@ def patch_config(
config["callbacks"] = callbacks
if executor is not None:
config["executor"] = executor
if recursion_limit is not None:
config["recursion_limit"] = recursion_limit
return config

File diff suppressed because one or more lines are too long

View File

@ -1,4 +1,4 @@
from ast import Not
from operator import itemgetter
from typing import Any, Dict, List, Optional
from uuid import UUID
@ -39,7 +39,6 @@ from langchain.schema.runnable import (
RunnablePassthrough,
RunnableSequence,
RunnableWithFallbacks,
passthrough,
)
@ -178,6 +177,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
tags=[],
callbacks=None,
_locals={},
recursion_limit=10,
),
),
mocker.call(
@ -187,6 +187,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
tags=[],
callbacks=None,
_locals={},
recursion_limit=10,
),
),
]
@ -768,6 +769,105 @@ async def test_router_runnable(
assert len(router_run.child_runs) == 2
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_higher_order_lambda_runnable(
mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
math_chain = ChatPromptTemplate.from_template(
"You are a math genius. Answer the question: {question}"
) | FakeListLLM(responses=["4"])
english_chain = ChatPromptTemplate.from_template(
"You are an english major. Answer the question: {question}"
) | FakeListLLM(responses=["2"])
input_map = RunnableMap(
{ # type: ignore[arg-type]
"key": lambda x: x["key"],
"input": {"question": lambda x: x["question"]},
}
)
def router(input: Dict[str, Any]) -> Runnable:
if input["key"] == "math":
return itemgetter("input") | math_chain
elif input["key"] == "english":
return itemgetter("input") | english_chain
else:
raise ValueError(f"Unknown key: {input['key']}")
chain: Runnable = input_map | router
assert dumps(chain, pretty=True) == snapshot
result = chain.invoke({"key": "math", "question": "2 + 2"})
assert result == "4"
result2 = chain.batch(
[{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
)
assert result2 == ["4", "2"]
result = await chain.ainvoke({"key": "math", "question": "2 + 2"})
assert result == "4"
result2 = await chain.abatch(
[{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
)
assert result2 == ["4", "2"]
# Test invoke
math_spy = mocker.spy(math_chain.__class__, "invoke")
tracer = FakeTracer()
assert (
chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]))
== "4"
)
assert math_spy.call_args.args[1] == {
"key": "math",
"input": {"question": "2 + 2"},
}
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 2
router_run = parent_run.child_runs[1]
assert router_run.name == "RunnableLambda"
assert len(router_run.child_runs) == 1
math_run = router_run.child_runs[0]
assert math_run.name == "RunnableSequence"
assert len(math_run.child_runs) == 3
# Test ainvoke
async def arouter(input: Dict[str, Any]) -> Runnable:
if input["key"] == "math":
return itemgetter("input") | math_chain
elif input["key"] == "english":
return itemgetter("input") | english_chain
else:
raise ValueError(f"Unknown key: {input['key']}")
achain: Runnable = input_map | arouter
math_spy = mocker.spy(math_chain.__class__, "ainvoke")
tracer = FakeTracer()
assert (
await achain.ainvoke(
{"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer])
)
== "4"
)
assert math_spy.call_args.args[1] == {
"key": "math",
"input": {"question": "2 + 2"},
}
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 2
router_run = parent_run.child_runs[1]
assert router_run.name == "RunnableLambda"
assert len(router_run.child_runs) == 1
math_run = router_run.child_runs[0]
assert math_run.name == "RunnableSequence"
assert len(math_run.child_runs) == 3
@freeze_time("2023-01-01")
def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
passthrough = mocker.Mock(side_effect=lambda x: x)