mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 20:58:25 +00:00
RunnableLambda, if func returns a Runnable, run it
This commit is contained in:
parent
677da6a0fd
commit
6d19709b65
@ -13,7 +13,6 @@ from typing import (
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterator,
|
||||
@ -1347,8 +1346,8 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Union[Callable[[Input], Output], Coroutine[Input, Any, Output]],
|
||||
afunc: Optional[Coroutine[Input, Any, Output]] = None,
|
||||
func: Union[Callable[[Input], Output], Callable[[Input], Awaitable[Output]]],
|
||||
afunc: Optional[Callable[[Input], Awaitable[Output]]] = None,
|
||||
) -> None:
|
||||
if afunc is not None:
|
||||
self.afunc = afunc
|
||||
@ -1356,7 +1355,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
if inspect.iscoroutinefunction(func):
|
||||
self.afunc = func
|
||||
elif callable(func):
|
||||
self.func = func
|
||||
self.func = cast(Callable[[Input], Output], func)
|
||||
else:
|
||||
raise TypeError(
|
||||
"Expected a callable type for `func`."
|
||||
@ -1374,6 +1373,54 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
else:
|
||||
return False
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
input: Input,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> Output:
|
||||
output = self.func(input)
|
||||
# If the output is a runnable, invoke it
|
||||
if isinstance(output, Runnable):
|
||||
recursion_limit = config["recursion_limit"]
|
||||
if recursion_limit <= 0:
|
||||
raise RecursionError(
|
||||
f"Recursion limit reached when invoking {self} with input {input}."
|
||||
)
|
||||
output = output.invoke(
|
||||
input,
|
||||
patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(),
|
||||
recursion_limit=recursion_limit - 1,
|
||||
),
|
||||
)
|
||||
return output
|
||||
|
||||
async def _ainvoke(
|
||||
self,
|
||||
input: Input,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> Output:
|
||||
output = await self.afunc(input)
|
||||
# If the output is a runnable, invoke it
|
||||
if isinstance(output, Runnable):
|
||||
recursion_limit = config["recursion_limit"]
|
||||
if recursion_limit <= 0:
|
||||
raise RecursionError(
|
||||
f"Recursion limit reached when invoking {self} with input {input}."
|
||||
)
|
||||
output = await output.ainvoke(
|
||||
input,
|
||||
patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(),
|
||||
recursion_limit=recursion_limit - 1,
|
||||
),
|
||||
)
|
||||
return output
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Input,
|
||||
@ -1381,7 +1428,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
if hasattr(self, "func"):
|
||||
return self._call_with_config(self.func, input, config)
|
||||
return self._call_with_config(self._invoke, input, config)
|
||||
else:
|
||||
raise TypeError(
|
||||
"Cannot invoke a coroutine function synchronously."
|
||||
@ -1395,7 +1442,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
if hasattr(self, "afunc"):
|
||||
return await self._acall_with_config(self.afunc, input, config)
|
||||
return await self._acall_with_config(self._ainvoke, input, config)
|
||||
else:
|
||||
return await super().ainvoke(input, config)
|
||||
|
||||
|
@ -47,6 +47,11 @@ class RunnableConfig(TypedDict, total=False):
|
||||
ThreadPoolExecutor will be created.
|
||||
"""
|
||||
|
||||
recursion_limit: int
|
||||
"""
|
||||
Maximum number of times a call can recurse. If not provided, defaults to 10.
|
||||
"""
|
||||
|
||||
|
||||
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
empty = RunnableConfig(
|
||||
@ -54,6 +59,7 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
metadata={},
|
||||
callbacks=None,
|
||||
_locals={},
|
||||
recursion_limit=10,
|
||||
)
|
||||
if config is not None:
|
||||
empty.update(config)
|
||||
@ -66,6 +72,7 @@ def patch_config(
|
||||
deep_copy_locals: bool = False,
|
||||
callbacks: Optional[BaseCallbackManager] = None,
|
||||
executor: Optional[Executor] = None,
|
||||
recursion_limit: Optional[int] = None,
|
||||
) -> RunnableConfig:
|
||||
config = ensure_config(config)
|
||||
if deep_copy_locals:
|
||||
@ -74,6 +81,8 @@ def patch_config(
|
||||
config["callbacks"] = callbacks
|
||||
if executor is not None:
|
||||
config["executor"] = executor
|
||||
if recursion_limit is not None:
|
||||
config["recursion_limit"] = recursion_limit
|
||||
return config
|
||||
|
||||
|
||||
|
File diff suppressed because one or more lines are too long
@ -1,4 +1,4 @@
|
||||
from ast import Not
|
||||
from operator import itemgetter
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
@ -39,7 +39,6 @@ from langchain.schema.runnable import (
|
||||
RunnablePassthrough,
|
||||
RunnableSequence,
|
||||
RunnableWithFallbacks,
|
||||
passthrough,
|
||||
)
|
||||
|
||||
|
||||
@ -178,6 +177,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||
tags=[],
|
||||
callbacks=None,
|
||||
_locals={},
|
||||
recursion_limit=10,
|
||||
),
|
||||
),
|
||||
mocker.call(
|
||||
@ -187,6 +187,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||
tags=[],
|
||||
callbacks=None,
|
||||
_locals={},
|
||||
recursion_limit=10,
|
||||
),
|
||||
),
|
||||
]
|
||||
@ -768,6 +769,105 @@ async def test_router_runnable(
|
||||
assert len(router_run.child_runs) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_higher_order_lambda_runnable(
|
||||
mocker: MockerFixture, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
math_chain = ChatPromptTemplate.from_template(
|
||||
"You are a math genius. Answer the question: {question}"
|
||||
) | FakeListLLM(responses=["4"])
|
||||
english_chain = ChatPromptTemplate.from_template(
|
||||
"You are an english major. Answer the question: {question}"
|
||||
) | FakeListLLM(responses=["2"])
|
||||
input_map = RunnableMap(
|
||||
{ # type: ignore[arg-type]
|
||||
"key": lambda x: x["key"],
|
||||
"input": {"question": lambda x: x["question"]},
|
||||
}
|
||||
)
|
||||
|
||||
def router(input: Dict[str, Any]) -> Runnable:
|
||||
if input["key"] == "math":
|
||||
return itemgetter("input") | math_chain
|
||||
elif input["key"] == "english":
|
||||
return itemgetter("input") | english_chain
|
||||
else:
|
||||
raise ValueError(f"Unknown key: {input['key']}")
|
||||
|
||||
chain: Runnable = input_map | router
|
||||
assert dumps(chain, pretty=True) == snapshot
|
||||
|
||||
result = chain.invoke({"key": "math", "question": "2 + 2"})
|
||||
assert result == "4"
|
||||
|
||||
result2 = chain.batch(
|
||||
[{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
|
||||
)
|
||||
assert result2 == ["4", "2"]
|
||||
|
||||
result = await chain.ainvoke({"key": "math", "question": "2 + 2"})
|
||||
assert result == "4"
|
||||
|
||||
result2 = await chain.abatch(
|
||||
[{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}]
|
||||
)
|
||||
assert result2 == ["4", "2"]
|
||||
|
||||
# Test invoke
|
||||
math_spy = mocker.spy(math_chain.__class__, "invoke")
|
||||
tracer = FakeTracer()
|
||||
assert (
|
||||
chain.invoke({"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer]))
|
||||
== "4"
|
||||
)
|
||||
assert math_spy.call_args.args[1] == {
|
||||
"key": "math",
|
||||
"input": {"question": "2 + 2"},
|
||||
}
|
||||
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
|
||||
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
|
||||
assert len(parent_run.child_runs) == 2
|
||||
router_run = parent_run.child_runs[1]
|
||||
assert router_run.name == "RunnableLambda"
|
||||
assert len(router_run.child_runs) == 1
|
||||
math_run = router_run.child_runs[0]
|
||||
assert math_run.name == "RunnableSequence"
|
||||
assert len(math_run.child_runs) == 3
|
||||
|
||||
# Test ainvoke
|
||||
async def arouter(input: Dict[str, Any]) -> Runnable:
|
||||
if input["key"] == "math":
|
||||
return itemgetter("input") | math_chain
|
||||
elif input["key"] == "english":
|
||||
return itemgetter("input") | english_chain
|
||||
else:
|
||||
raise ValueError(f"Unknown key: {input['key']}")
|
||||
|
||||
achain: Runnable = input_map | arouter
|
||||
math_spy = mocker.spy(math_chain.__class__, "ainvoke")
|
||||
tracer = FakeTracer()
|
||||
assert (
|
||||
await achain.ainvoke(
|
||||
{"key": "math", "question": "2 + 2"}, dict(callbacks=[tracer])
|
||||
)
|
||||
== "4"
|
||||
)
|
||||
assert math_spy.call_args.args[1] == {
|
||||
"key": "math",
|
||||
"input": {"question": "2 + 2"},
|
||||
}
|
||||
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
|
||||
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
|
||||
assert len(parent_run.child_runs) == 2
|
||||
router_run = parent_run.child_runs[1]
|
||||
assert router_run.name == "RunnableLambda"
|
||||
assert len(router_run.child_runs) == 1
|
||||
math_run = router_run.child_runs[0]
|
||||
assert math_run.name == "RunnableSequence"
|
||||
assert len(math_run.child_runs) == 3
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
|
||||
passthrough = mocker.Mock(side_effect=lambda x: x)
|
||||
|
Loading…
Reference in New Issue
Block a user