RunnableLambda, if func returns a Runnable, run it

This commit is contained in:
Nuno Campos 2023-08-17 15:04:36 +01:00
parent 677da6a0fd
commit 6d19709b65
4 changed files with 245 additions and 9 deletions

View File

@ -13,7 +13,6 @@ from typing import (
AsyncIterator, AsyncIterator,
Awaitable, Awaitable,
Callable, Callable,
Coroutine,
Dict, Dict,
Generic, Generic,
Iterator, Iterator,
@ -1347,8 +1346,8 @@ class RunnableLambda(Runnable[Input, Output]):
def __init__( def __init__(
self, self,
func: Union[Callable[[Input], Output], Coroutine[Input, Any, Output]], func: Union[Callable[[Input], Output], Callable[[Input], Awaitable[Output]]],
afunc: Optional[Coroutine[Input, Any, Output]] = None, afunc: Optional[Callable[[Input], Awaitable[Output]]] = None,
) -> None: ) -> None:
if afunc is not None: if afunc is not None:
self.afunc = afunc self.afunc = afunc
@ -1356,7 +1355,7 @@ class RunnableLambda(Runnable[Input, Output]):
if inspect.iscoroutinefunction(func): if inspect.iscoroutinefunction(func):
self.afunc = func self.afunc = func
elif callable(func): elif callable(func):
self.func = func self.func = cast(Callable[[Input], Output], func)
else: else:
raise TypeError( raise TypeError(
"Expected a callable type for `func`." "Expected a callable type for `func`."
@ -1374,6 +1373,54 @@ class RunnableLambda(Runnable[Input, Output]):
else: else:
return False return False
def _invoke(
self,
input: Input,
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Output:
output = self.func(input)
# If the output is a runnable, invoke it
if isinstance(output, Runnable):
recursion_limit = config["recursion_limit"]
if recursion_limit <= 0:
raise RecursionError(
f"Recursion limit reached when invoking {self} with input {input}."
)
output = output.invoke(
input,
patch_config(
config,
callbacks=run_manager.get_child(),
recursion_limit=recursion_limit - 1,
),
)
return output
async def _ainvoke(
self,
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> Output:
output = await self.afunc(input)
# If the output is a runnable, invoke it
if isinstance(output, Runnable):
recursion_limit = config["recursion_limit"]
if recursion_limit <= 0:
raise RecursionError(
f"Recursion limit reached when invoking {self} with input {input}."
)
output = await output.ainvoke(
input,
patch_config(
config,
callbacks=run_manager.get_child(),
recursion_limit=recursion_limit - 1,
),
)
return output
def invoke( def invoke(
self, self,
input: Input, input: Input,
@ -1381,7 +1428,7 @@ class RunnableLambda(Runnable[Input, Output]):
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Output: ) -> Output:
if hasattr(self, "func"): if hasattr(self, "func"):
return self._call_with_config(self.func, input, config) return self._call_with_config(self._invoke, input, config)
else: else:
raise TypeError( raise TypeError(
"Cannot invoke a coroutine function synchronously." "Cannot invoke a coroutine function synchronously."
@ -1395,7 +1442,7 @@ class RunnableLambda(Runnable[Input, Output]):
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Output: ) -> Output:
if hasattr(self, "afunc"): if hasattr(self, "afunc"):
return await self._acall_with_config(self.afunc, input, config) return await self._acall_with_config(self._ainvoke, input, config)
else: else:
return await super().ainvoke(input, config) return await super().ainvoke(input, config)

View File

@ -47,6 +47,11 @@ class RunnableConfig(TypedDict, total=False):
ThreadPoolExecutor will be created. ThreadPoolExecutor will be created.
""" """
recursion_limit: int
"""
Maximum number of times a call can recurse. If not provided, defaults to 10.
"""
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
empty = RunnableConfig( empty = RunnableConfig(
@ -54,6 +59,7 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
metadata={}, metadata={},
callbacks=None, callbacks=None,
_locals={}, _locals={},
recursion_limit=10,
) )
if config is not None: if config is not None:
empty.update(config) empty.update(config)
@ -66,6 +72,7 @@ def patch_config(
deep_copy_locals: bool = False, deep_copy_locals: bool = False,
callbacks: Optional[BaseCallbackManager] = None, callbacks: Optional[BaseCallbackManager] = None,
executor: Optional[Executor] = None, executor: Optional[Executor] = None,
recursion_limit: Optional[int] = None,
) -> RunnableConfig: ) -> RunnableConfig:
config = ensure_config(config) config = ensure_config(config)
if deep_copy_locals: if deep_copy_locals:
@ -74,6 +81,8 @@ def patch_config(
config["callbacks"] = callbacks config["callbacks"] = callbacks
if executor is not None: if executor is not None:
config["executor"] = executor config["executor"] = executor
if recursion_limit is not None:
config["recursion_limit"] = recursion_limit
return config return config

File diff suppressed because one or more lines are too long

View File

@ -1,4 +1,4 @@
from ast import Not from operator import itemgetter
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from uuid import UUID from uuid import UUID
@ -39,7 +39,6 @@ from langchain.schema.runnable import (
RunnablePassthrough, RunnablePassthrough,
RunnableSequence, RunnableSequence,
RunnableWithFallbacks, RunnableWithFallbacks,
passthrough,
) )
@ -178,6 +177,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
tags=[], tags=[],
callbacks=None, callbacks=None,
_locals={}, _locals={},
recursion_limit=10,
), ),
), ),
mocker.call( mocker.call(
@ -187,6 +187,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
tags=[], tags=[],
callbacks=None, callbacks=None,
_locals={}, _locals={},
recursion_limit=10,
), ),
), ),
] ]
@ -768,6 +769,105 @@ async def test_router_runnable(
assert len(router_run.child_runs) == 2 assert len(router_run.child_runs) == 2
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_higher_order_lambda_runnable(
mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
math_chain = ChatPromptTemplate.from_template(
"You are a math genius. Answer the question: {question}"
) | FakeListLLM(responses=["4"])
english_chain = ChatPromptTemplate.from_template(
"You are an english major. Answer the question: {question}"
) | FakeListLLM(responses=["2"])
input_map = RunnableMap(
{ # type: ignore[arg-type]
"key": lambda x: x["key"],
"input": {"question": lambda x: x["question"]},
}
)
def router(input: Dict[str, Any]) -> Runnable:
if input["key"] == "math":
return itemgetter("input") | math_chain
elif input["key"] == "english":
return itemgetter("input") | english_chain
else:
raise ValueError(f"Unknown key: {input['key']}")
chain: Runnable = input_map | router
assert dumps(chain, pretty=True) == snapshot
result = chain.invoke({"key": "math", "question": "2 + 2"})
assert result == "4"
result2 = chain.batch(
[{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
)
assert result2 == ["4", "2"]
result = await chain.ainvoke({"key": "math", "question": "2 + 2"})
assert result == "4"
result2 = await chain.abatch(
[{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
)
assert result2 == ["4", "2"]
# Test invoke
math_spy = mocker.spy(math_chain.__class__, "invoke")
tracer = FakeTracer()
assert (
chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]))
== "4"
)
assert math_spy.call_args.args[1] == {
"key": "math",
"input": {"question": "2 + 2"},
}
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 2
router_run = parent_run.child_runs[1]
assert router_run.name == "RunnableLambda"
assert len(router_run.child_runs) == 1
math_run = router_run.child_runs[0]
assert math_run.name == "RunnableSequence"
assert len(math_run.child_runs) == 3
# Test ainvoke
async def arouter(input: Dict[str, Any]) -> Runnable:
if input["key"] == "math":
return itemgetter("input") | math_chain
elif input["key"] == "english":
return itemgetter("input") | english_chain
else:
raise ValueError(f"Unknown key: {input['key']}")
achain: Runnable = input_map | arouter
math_spy = mocker.spy(math_chain.__class__, "ainvoke")
tracer = FakeTracer()
assert (
await achain.ainvoke(
{"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer])
)
== "4"
)
assert math_spy.call_args.args[1] == {
"key": "math",
"input": {"question": "2 + 2"},
}
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 2
router_run = parent_run.child_runs[1]
assert router_run.name == "RunnableLambda"
assert len(router_run.child_runs) == 1
math_run = router_run.child_runs[0]
assert math_run.name == "RunnableSequence"
assert len(math_run.child_runs) == 3
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None: def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
passthrough = mocker.Mock(side_effect=lambda x: x) passthrough = mocker.Mock(side_effect=lambda x: x)